module Database.PostgreSQL.PQTypes.Checks (
checkDatabase
, checkDatabaseAllowUnknownTables
, createTable
, createDomain
, ExtrasOptions(..)
, migrateDatabase
) where
import Control.Applicative ((<$>))
import Data.Function (on)
import Data.Int
import qualified Data.List as L
import Data.Maybe
import Data.Monoid
import Data.Ord (comparing)
import qualified Data.String
import Prelude
import Text.Read (readMaybe)

import Control.Monad.Catch
import Control.Monad.Reader
import Data.Monoid.Utils
import Data.Text (Text)
import qualified Data.Text as T
import Database.PostgreSQL.PQTypes hiding (def)
import Log
import TextShow

import Database.PostgreSQL.PQTypes.Checks.Util
import Database.PostgreSQL.PQTypes.ExtrasOptions
import Database.PostgreSQL.PQTypes.Migrate
import Database.PostgreSQL.PQTypes.Model
import Database.PostgreSQL.PQTypes.SQL.Builder
import Database.PostgreSQL.PQTypes.Versions
-- | Return the first element of a list, aborting with the supplied
-- message when the list is empty.
headExc :: String -> [a] -> a
headExc msg xs = case xs of
  (x:_) -> x
  []    -> error msg
-- | Bring the database in line with the given definitions.
--
-- The steps run in this order:
--
-- 1. Make sure the database's time zone setting is UTC.
-- 2. Create any missing extensions.
-- 3. Create or migrate the schema.
-- 4. Run the structural checks (domains, tables, dropped tables, unknown
--    tables, @table_versions@ entries), feeding each result to
--    'resultCheck'.
-- 5. Commit.
--
-- 'resultCheck' is defined elsewhere and presumably aborts on a failed
-- validation; TODO confirm.
migrateDatabase
  :: (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Extension] -> [Domain] -> [Table] -> [Migration m]
  -> m ()
migrateDatabase options@ExtrasOptions{} extensions domains tables migrations = do
  setDBTimeZoneToUTC
  forM_ extensions checkExtension
  checkDBConsistency options domains allTables migrations
  mapM_ (resultCheck =<<)
    [ checkDomainsStructure domains
    , checkDBStructure options allTables
    , checkTablesWereDropped migrations
    , checkUnknownTables tables
    , checkExistenceOfVersionsForTables allTables
    ]
  commit
  where
    -- The user tables plus the library's own version-tracking table.
    allTables = tableVersions : tables
-- | Validate the database against the given domains and tables.
-- Tables that exist in the database but are missing from the given list
-- are reported as errors.
checkDatabase
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Domain] -> [Table] -> m ()
checkDatabase options domains tables =
  checkDatabase_ options False domains tables
-- | Like 'checkDatabase', but skips both the unknown-tables check and the
-- @table_versions@ consistency check.
checkDatabaseAllowUnknownTables
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Domain] -> [Table] -> m ()
checkDatabaseAllowUnknownTables options domains tables =
  checkDatabase_ options True domains tables
-- | Shared implementation of 'checkDatabase' and
-- 'checkDatabaseAllowUnknownTables'.
--
-- The checks run in this order:
--
-- 1. Every table is at its declared version.
-- 2. Domains and table structures match their definitions.
-- 3. Unless @allowUnknownTables@ is set: there are no unknown tables, and
--    @table_versions@ agrees with the definitions.
-- 4. Every table's initial setup is valid.
checkDatabase_
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> Bool -> [Domain] -> [Table] -> m ()
checkDatabase_ options allowUnknownTables domains tables = do
  versions <- getTableVersions tables
  resultCheck $ foldMap versionMismatch versions
  resultCheck =<< checkDomainsStructure domains
  resultCheck =<< checkDBStructure options (tableVersions : tables)
  unless allowUnknownTables $ do
    resultCheck =<< checkUnknownTables tables
    resultCheck =<< checkExistenceOfVersionsForTables (tableVersions : tables)
  resultCheck =<< initialSetupProblems
  where
    -- Report a table whose database version differs from its declared
    -- version. A database version of 0 means the table does not exist.
    versionMismatch :: (Table, Int32) -> ValidationResult
    versionMismatch (t@Table{..}, dbVersion)
      | dbVersion == tblVersion = ValidationResult []
      | dbVersion == 0 = ValidationResult
          ["Table '" <> tblNameText t <> "' must be created"]
      | otherwise = ValidationResult
          ["Table '" <> tblNameText t
            <> "' must be migrated" <+> showt dbVersion <+> "->"
            <+> showt tblVersion]

    -- Combined result of validating every table's initial setup.
    initialSetupProblems :: m ValidationResult
    initialSetupProblems = mconcat <$> mapM initialSetupProblem tables

    -- Tables without an initial setup are always fine.
    initialSetupProblem :: Table -> m ValidationResult
    initialSetupProblem t@Table{..} = case tblInitialSetup of
      Nothing -> return (ValidationResult [])
      Just is -> do
        ok <- checkInitialSetup is
        return . ValidationResult $
          if ok
            then []
            else ["Initial setup for table '"
                   <> tblNameText t <> "' is not valid"]
-- | Return the name of the current database as a double-quoted SQL
-- identifier. The result is suitable for splicing into statements such as
-- @ALTER DATABASE@.
--
-- PostgreSQL requires a double quote inside a quoted identifier to be
-- written twice. Without that escaping, a database name containing @"@
-- would end the identifier early and yield malformed SQL. It could also
-- inject SQL.
currentCatalog :: (MonadDB m, MonadThrow m) => m (RawSQL ())
currentCatalog = do
  runSQL_ "SELECT current_catalog::text"
  dbname :: String <- fetchOne runIdentity
  return . unsafeSQL $ "\"" ++ concatMap escapeQuote dbname ++ "\""
  where
    -- Double every embedded double quote; keep other characters as-is.
    escapeQuote :: Char -> String
    escapeQuote '"' = "\"\""
    escapeQuote c   = [c]
-- | Make sure a PostgreSQL extension is installed. If @pg_extension@ has
-- no row with this name, run @CREATE EXTENSION IF NOT EXISTS@.
checkExtension :: (MonadDB m, MonadLog m, MonadThrow m) => Extension -> m ()
checkExtension (Extension extension) = do
  logInfo_ $ "Checking for extension '" <> name <> "'"
  present <- runQuery01 . sqlSelect "pg_extension" $ do
    sqlResult "TRUE"
    sqlWhereEq "extname" name
  if present
    then logInfo_ $ "Extension '" <> name <> "' exists"
    else do
      logInfo_ $ "Creating extension '" <> name <> "'"
      runSQL_ $ "CREATE EXTENSION IF NOT EXISTS" <+> raw extension
  where
    -- Textual form of the extension name, used in queries and log lines.
    name = unRawSQL extension
-- | Make sure the current database's time zone setting is UTC.
-- @ALTER DATABASE ... SET TIMEZONE = 'UTC'@ is issued only when
-- @SHOW timezone@ reports a different value.
setDBTimeZoneToUTC :: (MonadDB m, MonadLog m, MonadThrow m) => m ()
setDBTimeZoneToUTC = do
  runSQL_ "SHOW timezone"
  tz :: String <- fetchOne runIdentity
  unless (tz == "UTC") $ do
    db <- currentCatalog
    logInfo_ $ "Setting '" <> unRawSQL db
      <> "' database to return timestamps in UTC"
    runQuery_ $ "ALTER DATABASE" <+> db <+> "SET TIMEZONE = 'UTC'"
-- | Names of the base tables in the schemas returned by
-- @current_schemas(false)@, excluding @table_versions@.
getDBTableNames :: (MonadDB m) => m [Text]
getDBTableNames = do
  runQuery_ . sqlSelect "information_schema.tables" $ do
    sqlResult "table_name::text"
    sqlWhere "table_name <> 'table_versions'"
    sqlWhere "table_type = 'BASE TABLE'"
    -- Keep only tables whose schema is on the current search path.
    sqlWhereExists . sqlSelect "unnest(current_schemas(false)) as cs" $ do
      sqlResult "TRUE"
      sqlWhere "cs = table_schema"
  fetchMany runIdentity
-- | Compare the base tables found in the database with the given
-- definitions. Two kinds of mismatch are logged and reported:
--
-- * tables that exist only in the database (reported as unknown), and
-- * tables that exist only in the definitions (reported as not present
--   in the database).
checkUnknownTables :: (MonadDB m, MonadLog m) => [Table] -> m ValidationResult
checkUnknownTables tables = do
  inDB <- getDBTableNames
  let defined     = map (unRawSQL . tblName) tables
      onlyInDB    = inDB L.\\ defined
      onlyDefined = defined L.\\ inDB
  if null onlyInDB && null onlyDefined
    then return mempty
    else do
      forM_ onlyInDB $ logInfo_ . ("Unknown table:" <+>)
      forM_ onlyDefined $ logInfo_ . ("Table not present in the database:" <+>)
      return . ValidationResult $
        summarize "Unknown tables:" onlyInDB
        ++ summarize "Tables not present in the database:" onlyDefined
  where
    -- One summary line per non-empty category, or nothing at all.
    summarize :: Text -> [Text] -> [Text]
    summarize _ [] = []
    summarize heading names = [heading <+> T.intercalate ", " names]
-- | Compare the names recorded in @table_versions@ with the given table
-- definitions. Two kinds of mismatch are logged and reported:
--
-- * entries with no matching definition, and
-- * definitions with no entry.
checkExistenceOfVersionsForTables :: (MonadDB m, MonadLog m) => [Table] -> m ValidationResult
checkExistenceOfVersionsForTables tables = do
  runQuery_ $ sqlSelect "table_versions" $ do
    sqlResult "name::text"
  (existingTableNames :: [Text]) <- fetchMany runIdentity
  let tableNames = map (unRawSQL . tblName) tables
      -- Entries in table_versions without a table definition.
      absent = existingTableNames L.\\ tableNames
      -- Table definitions without an entry in table_versions.
      notPresent = tableNames L.\\ existingTableNames
  if (not . null $ absent) || (not . null $ notPresent)
    then do
      mapM_ (logInfo_ . (<+>) "Unknown entry in 'table_versions':") absent
      mapM_ (logInfo_ . (<+>) "Table not present in the 'table_versions':") notPresent
      -- The heading now quotes 'table_versions' consistently with the log
      -- line above (it previously read "table_versions':").
      return . ValidationResult $
        (joinedResult "Unknown entry in 'table_versions':" absent) ++
        (joinedResult "Tables not present in the 'table_versions':" notPresent)
    else return mempty
  where
    -- One summary line per non-empty category, or nothing at all.
    joinedResult :: Text -> [Text] -> [Text]
    joinedResult _ [] = []
    joinedResult t ts = [ t <+> T.intercalate ", " ts]
-- | Compare each domain definition with the type of the same name in
-- @pg_catalog.pg_type@.
--
-- For a domain that exists, every differing attribute is reported: name,
-- base type, nullability, default and checks. A domain with no matching
-- type is reported as missing.
checkDomainsStructure :: (MonadDB m, MonadThrow m)
                      => [Domain] -> m ValidationResult
checkDomainsStructure defs = mconcat <$> mapM checkOne defs
  where
    -- Fetch the domain named like the definition and compare the two.
    checkOne expected = do
      runQuery_ . sqlSelect "pg_catalog.pg_type t1" $ do
        sqlResult "t1.typname::text"
        sqlResult "(SELECT pg_catalog.format_type(t2.oid, t2.typtypmod) FROM pg_catalog.pg_type t2 WHERE t2.oid = t1.typbasetype)"
        sqlResult "NOT t1.typnotnull"
        sqlResult "t1.typdefault"
        sqlResult "ARRAY(SELECT c.conname::text FROM pg_catalog.pg_constraint c WHERE c.contypid = t1.oid ORDER by c.oid)"
        sqlResult "ARRAY(SELECT regexp_replace(pg_get_constraintdef(c.oid, true), 'CHECK \\((.*)\\)', '\\1') FROM pg_catalog.pg_constraint c WHERE c.contypid = t1.oid ORDER by c.oid)"
        sqlWhereEq "t1.typname" . unRawSQL $ domName expected
      maybe (missing expected) (compareWith expected) <$> fetchMaybe toDomain

    -- Build a Domain from one result row. Constraint names and conditions
    -- are paired up positionally; both arrays are ordered by constraint oid.
    toDomain (dname, dtype, nullable, defval, cnames, conds) = Domain
      { domName = unsafeSQL dname
      , domType = dtype
      , domNullable = nullable
      , domDefault = unsafeSQL <$> defval
      , domChecks = mkChecks $ zipWith toCheck (unArray1 cnames) (unArray1 conds)
      }

    toCheck cname cond = Check
      { chkName = unsafeSQL cname
      , chkCondition = unsafeSQL cond
      }

    missing expected = ValidationResult
      [ "Domain '" <> unRawSQL (domName expected)
        <> "' doesn't exist in the database" ]

    -- Per-attribute report, emitted only when the domains differ.
    compareWith expected actual
      | actual /= expected = topMessage "domain" (unRawSQL $ domName actual) $ mconcat
          [ attr "name" domName
          , attr "type" domType
          , attr "nullable" domNullable
          , attr "default" domDefault
          , attr "checks" domChecks
          ]
      | otherwise = mempty
      where
        attr :: (Eq a, Show a) => Text -> (Domain -> a) -> ValidationResult
        attr label field
          | field actual == field expected = ValidationResult []
          | otherwise = ValidationResult
              [ "Attribute '" <> label
                <> "' does not match (database:" <+> T.pack (show $ field actual)
                <> ", definition:" <+> T.pack (show $ field expected) <> ")" ]
-- | Report every table that is the target of a drop-table migration but
-- still exists in the database, according to 'checkTableVersion'.
checkTablesWereDropped :: (MonadDB m, MonadThrow m) =>
                          [Migration m] -> m ValidationResult
checkTablesWereDropped mgrs =
  mconcat <$> mapM stillPresent
    [ mgrTableName mgr | mgr <- mgrs, isDropTableMigration mgr ]
  where
    stillPresent name = do
      mver <- checkTableVersion (T.unpack . unRawSQL $ name)
      return $ case mver of
        Nothing -> mempty
        Just _  -> ValidationResult
          [ "The table '" <> unRawSQL name
            <> "' that must have been dropped"
            <> " is still present in the database." ]
-- | Compare each table's structure, as read from the PostgreSQL system
-- catalogs, with its definition. The parts compared are columns, primary
-- key, CHECK constraints, indexes and foreign keys. Each table's findings
-- are wrapped by 'topMessage' under the table's name.
checkDBStructure :: forall m. (MonadDB m, MonadThrow m)
=> ExtrasOptions -> [Table] -> m ValidationResult
checkDBStructure options tables = fmap mconcat . forM tables $ \table ->
topMessage "table" (tblNameText table) <$> checkTableStructure table
where
-- Read one table's actual structure from the catalogs and compare each
-- part with the definition.
checkTableStructure :: Table -> m ValidationResult
checkTableStructure table@Table{..} = do
-- Column rows: name, formatted type, nullability, and the default
-- expression taken from pg_attrdef. Rows are ordered by attnum.
runQuery_ $ sqlSelect "pg_catalog.pg_attribute a" $ do
sqlResult "a.attname::text"
sqlResult "pg_catalog.format_type(a.atttypid, a.atttypmod)"
sqlResult "NOT a.attnotnull"
sqlResult . parenthesize . toSQLCommand $
sqlSelect "pg_catalog.pg_attrdef d" $ do
sqlResult "pg_catalog.pg_get_expr(d.adbin, d.adrelid)"
sqlWhere "d.adrelid = a.attrelid"
sqlWhere "d.adnum = a.attnum"
sqlWhere "a.atthasdef"
sqlWhere "a.attnum > 0"
sqlWhere "NOT a.attisdropped"
sqlWhereEqSql "a.attrelid" $ sqlGetTableID table
sqlOrderBy "a.attnum"
desc <- fetchMany fetchTableColumn
pk <- sqlGetPrimaryKey table
runQuery_ $ sqlGetChecks table
checks <- fetchMany fetchTableCheck
runQuery_ $ sqlGetIndexes table
indexes <- fetchMany fetchTableIndex
runQuery_ $ sqlGetForeignKeys table
fkeys <- fetchMany fetchForeignKey
return $ mconcat [
checkColumns 1 tblColumns desc
, checkPrimaryKey tblPrimaryKey pk
, checkChecks tblChecks checks
, checkIndexes tblIndexes indexes
, checkForeignKeys tblForeignKeys fkeys
]
where
-- Build a TableColumn from one row of the column query.
fetchTableColumn :: (String, ColumnType, Bool, Maybe String) -> TableColumn
fetchTableColumn (name, ctype, nullable, mdefault) = TableColumn {
colName = unsafeSQL name
, colType = ctype
, colNullable = nullable
, colDefault = unsafeSQL `liftM` mdefault
}
-- Compare definition columns (first list) with database columns
-- (second list) position by position. @n@ is the 1-based position and
-- appears in the name-mismatch message.
-- Special cases:
--   * a BigSerialT definition accepts a BigIntT column;
--   * a definition without a default accepts a database default that
--     starts with "nextval('".
-- NOTE(review): the nextval case presumably covers serial columns
-- (sequence-backed defaults); confirm.
checkColumns :: Int -> [TableColumn] -> [TableColumn] -> ValidationResult
checkColumns _ [] [] = mempty
checkColumns _ rest [] = ValidationResult [tableHasLess "columns" rest]
checkColumns _ [] rest = ValidationResult [tableHasMore "columns" rest]
checkColumns !n (d:defs) (c:cols) = mconcat [
validateNames $ colName d == colName c
, validateTypes $ colType d == colType c ||
(colType d == BigSerialT && colType c == BigIntT)
, validateDefaults $ colDefault d == colDefault c ||
(colDefault d == Nothing
&& ((T.isPrefixOf "nextval('" . unRawSQL) `liftM` colDefault c)
== Just True)
, validateNullables $ colNullable d == colNullable c
, checkColumns (n+1) defs cols
]
where
validateNames True = mempty
validateNames False = ValidationResult
[ errorMsg ("no. " <> showt n) "names" (unRawSQL . colName) ]
validateTypes True = mempty
validateTypes False = ValidationResult
[ errorMsg cname "types" (T.pack . show . colType)
<+> sqlHint ("TYPE" <+> columnTypeToSQL (colType d)) ]
validateNullables True = mempty
validateNullables False = ValidationResult
[ errorMsg cname "nullables" (showt . colNullable)
<+> sqlHint ((if colNullable d then "DROP" else "SET")
<+> "NOT NULL") ]
validateDefaults True = mempty
validateDefaults False = ValidationResult
[ (errorMsg cname "defaults" (showt . fmap unRawSQL . colDefault))
<+> sqlHint set_default ]
where
set_default = case colDefault d of
Just v -> "SET DEFAULT" <+> v
Nothing -> "DROP DEFAULT"
cname = unRawSQL $ colName d
-- Message showing the attribute as rendered by @f@ for the
-- database column (c) and for the definition column (d).
errorMsg ident attr f =
"Column '" <> ident <> "' differs in"
<+> attr <+> "(table:" <+> f c <> ", definition:" <+> f d <> ")."
-- Suggested ALTER TABLE ... ALTER COLUMN statement for the fix.
sqlHint sql =
"(HINT: SQL for making the change is: ALTER TABLE"
<+> tblNameText table <+> "ALTER COLUMN" <+> unRawSQL (colName d)
<+> unRawSQL sql <> ")"
-- Compare the primary key with the definition and check the
-- constraint's name against 'pkName'. Presence is checked with
-- 'checkPKPresence' only when 'eoEnforcePKs' is set.
checkPrimaryKey :: Maybe PrimaryKey -> Maybe (PrimaryKey, RawSQL ())
-> ValidationResult
checkPrimaryKey mdef mpk = mconcat [
checkEquality "PRIMARY KEY" def (map fst pk)
, checkNames (const (pkName tblName)) pk
, if (eoEnforcePKs options)
then checkPKPresence tblName mdef mpk
else mempty
]
where
def = maybeToList mdef
pk = maybeToList mpk
-- Compare CHECK constraints. On any mismatch, append a hint about
-- differences in parentheses or whitespace.
checkChecks :: [Check] -> [Check] -> ValidationResult
checkChecks defs checks = case checkEquality "CHECKs" defs checks of
ValidationResult [] -> ValidationResult []
ValidationResult errmsgs -> ValidationResult $
errmsgs ++ [" (HINT: If checks are equal modulo number of parentheses/whitespaces used in conditions, just copy and paste expected output into source code)"]
-- Compare indexes and check their names against 'indexName'.
checkIndexes :: [TableIndex] -> [(TableIndex, RawSQL ())]
-> ValidationResult
checkIndexes defs indexes = mconcat [
checkEquality "INDEXes" defs (map fst indexes)
, checkNames (indexName tblName) indexes
]
-- Compare foreign keys and check their names against 'fkName'.
checkForeignKeys :: [ForeignKey] -> [(ForeignKey, RawSQL ())]
-> ValidationResult
checkForeignKeys defs fkeys = mconcat [
checkEquality "FOREIGN KEYs" defs (map fst fkeys)
, checkNames (fkName tblName) fkeys
]
-- | Create or migrate the database schema so that it matches the given
-- definitions.
--
-- First the migration list is validated on its own:
--
-- * each defined table's migrations, in list order, start from versions
--   that form the tail of @[0 .. tblVersion - 1]@;
-- * no drop-table migration targets a table that is still defined;
-- * each table has at most one drop-table migration, and it is that
--   table's last migration.
--
-- Then one of two paths is taken:
--
-- * If every defined table is at version 0, meaning none exists yet,
--   domains, tables and constraints are created from scratch and the
--   tables' initial setups are run.
-- * Otherwise, the migrations still needed are selected, validated
--   against the database state, and executed in order.
--
-- Invalid migration lists abort with 'error' after the details are logged.
checkDBConsistency
  :: forall m. (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Domain] -> [Table] -> [Migration m]
  -> m ()
checkDBConsistency options domains tables migrations = do
  validateMigrations
  validateDropTableMigrations
  tablesWithVersions <- getTableVersions $ tables
  dbTablesWithVersions <- getDBTableVersions
  if all ((==) 0 . snd) tablesWithVersions
    then do
      createDBSchema
      initializeDB
    else do
      validateMigrationsAgainstDB [ (tblName table, tblVersion table, actualVer)
                                  | (table, actualVer) <- tablesWithVersions ]
      validateDropTableMigrationsAgainstDB dbTablesWithVersions
      runMigrations dbTablesWithVersions
  where
    -- Abort, naming the tables whose migrations are invalid.
    errorInvalidMigrations :: [RawSQL ()] -> a
    errorInvalidMigrations tblNames =
      error $ "checkDBConsistency: invalid migrations for tables"
      <+> (L.intercalate ", " $ map (T.unpack . unRawSQL) tblNames)

    -- Log and abort when a table's migration versions differ from the
    -- expected ones.
    checkMigrationsListValidity :: Table -> [Int32] -> [Int32] -> m ()
    checkMigrationsListValidity table presentMigrationVersions expectedMigrationVersions =
      when (presentMigrationVersions /= expectedMigrationVersions) $ do
        logAttention "Migrations are invalid" $ object [
            "table" .= tblNameText table
          , "migration_versions" .= presentMigrationVersions
          , "expected_migration_versions" .= expectedMigrationVersions
          ]
        errorInvalidMigrations [tblName $ table]

    -- A table's migrations, in list order, must start from consecutive
    -- versions that end at tblVersion - 1. That is, they form a gap-free
    -- chain reaching the declared version.
    validateMigrations :: m ()
    validateMigrations = forM_ tables $ \table -> do
      let presentMigrationVersions
            = [ mgrFrom | Migration{..} <- migrations
                        , mgrTableName == tblName table ]
          expectedMigrationVersions
            = reverse $ take (length presentMigrationVersions) $
              reverse [0 .. tblVersion table - 1]
      checkMigrationsListValidity table presentMigrationVersions
        expectedMigrationVersions

    -- Drop-table migrations must not target tables that are still
    -- defined. Within a table's migrations there may be at most one drop,
    -- and it must come last.
    validateDropTableMigrations :: m ()
    validateDropTableMigrations = do
      let droppedTableNames =
            [ mgrTableName $ mgr | mgr <- migrations
                                 , isDropTableMigration mgr ]
          tableNames =
            [ tblName tbl | tbl <- tables ]
      let intersection = L.intersect droppedTableNames tableNames
      when (not . null $ intersection) $ do
        logAttention ("The intersection between tables "
                      <> "and dropped tables is not empty")
          $ object
          [ "intersection" .= map unRawSQL intersection ]
        errorInvalidMigrations [ tblName tbl | tbl <- tables
                                             , tblName tbl `elem` intersection ]
      -- Gather all migrations of a table into one group. 'L.groupBy' only
      -- merges adjacent elements, and the migration list interleaves
      -- tables, so sort by table name first. 'L.sortBy' is stable, so each
      -- table's migrations keep their original relative order.
      let migrationsByTable = L.groupBy ((==) `on` mgrTableName)
                              . L.sortBy (comparing mgrTableName)
                              $ migrations
          dropMigrationLists = [ mgrs | mgrs <- migrationsByTable
                                      , any isDropTableMigration mgrs ]
          invalidMigrationLists =
            [ mgrs | mgrs <- dropMigrationLists
                   , (not . isDropTableMigration . last $ mgrs) ||
                     (length . filter isDropTableMigration $ mgrs) > 1 ]
      when (not . null $ invalidMigrationLists) $ do
        let tablesWithInvalidMigrationLists =
              [ mgrTableName mgr | (mgr:_) <- invalidMigrationLists ]
        logAttention ("Migration lists for some tables contain "
                      <> "either multiple drop table migrations or "
                      <> "a drop table migration in non-tail position.")
          $ object [ "tables" .= [ unRawSQL tblName
                                 | tblName <- tablesWithInvalidMigrationLists ] ]
        errorInvalidMigrations tablesWithInvalidMigrationLists

    -- Create all domains, then all tables, then all table constraints.
    createDBSchema :: m ()
    createDBSchema = do
      logInfo_ "Creating domains..."
      mapM_ createDomain domains
      logInfo_ "Creating tables..."
      mapM_ (createTable False) tables
      logInfo_ "Creating table constraints..."
      mapM_ createTableConstraints tables
      logInfo_ "Done."

    -- Run the initial setup of every table that has one.
    initializeDB :: m ()
    initializeDB = do
      logInfo_ "Running initial setup for tables..."
      forM_ tables $ \t -> case tblInitialSetup t of
        Nothing -> return ()
        Just tis -> do
          logInfo_ $ "Initializing" <+> tblNameText t <> "..."
          initialSetup tis
      logInfo_ "Done."

    -- Check each (table, expected version, database version) triple whose
    -- versions differ. The table must have at least one migration, and the
    -- earliest one must not start after the database version.
    validateMigrationsAgainstDB :: [(RawSQL (), Int32, Int32)] -> m ()
    validateMigrationsAgainstDB tablesWithVersions
      = forM_ tablesWithVersions $ \(tableName, expectedVer, actualVer) ->
        when (expectedVer /= actualVer) $
        case [ m | m@Migration{..} <- migrations
                 , mgrTableName == tableName ] of
          [] ->
            error $ "checkDBConsistency: no migrations found for table '"
              ++ (T.unpack . unRawSQL $ tableName) ++ "', cannot migrate "
              ++ show actualVer ++ " -> " ++ show expectedVer
          (m:_) | mgrFrom m > actualVer ->
                  error $ "checkDBConsistency: earliest migration for table '"
                    ++ (T.unpack . unRawSQL $ tableName) ++ "' is from version "
                    ++ show (mgrFrom m) ++ ", cannot migrate "
                    ++ show actualVer ++ " -> " ++ show expectedVer
                | otherwise -> return ()

    -- Consider each drop-table migration whose table still exists in the
    -- database at a version other than the drop's source version. For each
    -- such table, check that its migrations can bring it from the
    -- database version to the drop's source version.
    validateDropTableMigrationsAgainstDB :: [(Text, Int32)] -> m ()
    validateDropTableMigrationsAgainstDB dbTablesWithVersions = do
      let dbTablesToDropWithVersions =
            [ (tblName, mgrFrom mgr, ver)
            | mgr <- migrations
            , isDropTableMigration mgr
            , let tblName = mgrTableName mgr
            , Just ver <- [lookup (unRawSQL tblName) dbTablesWithVersions] ]
      forM_ dbTablesToDropWithVersions $ \(tblName, fromVer, ver) ->
        when (fromVer /= ver) $
        validateMigrationsAgainstDB [(tblName, fromVer, ver)]

    -- Select the part of the migration list that still has to run.
    --
    -- Leading migrations are skipped while either:
    --
    -- * they start below their table's database version (already
    --   applied), or
    -- * their table is absent from the database and they are not a
    --   creation (from 0) of a table that is never dropped.
    --
    -- Then the trailing run of the skipped prefix is added back, as long
    -- as those migrations belong to tables that are eventually dropped and
    -- are absent from the database.
    findMigrationsToRun :: [(Text, Int32)] -> [Migration m]
    findMigrationsToRun dbTablesWithVersions =
      let tableNamesToDrop = [ mgrTableName mgr | mgr <- migrations
                                                , isDropTableMigration mgr ]
          droppedEventually :: Migration m -> Bool
          droppedEventually mgr = mgrTableName mgr `elem` tableNamesToDrop

          lookupVer :: Migration m -> Maybe Int32
          lookupVer mgr = lookup (unRawSQL $ mgrTableName mgr) dbTablesWithVersions

          tableDoesNotExist = isNothing . lookupVer

          migrationsToRun' = dropWhile
            (\mgr ->
               case lookupVer mgr of
                 Nothing -> not $
                   (mgrFrom mgr == 0) && (not . droppedEventually $ mgr)
                 Just ver -> not $
                   mgrFrom mgr >= ver)
            migrations

          l = length migrationsToRun'
          initialMigrations = drop l $ reverse migrations
          additionalMigrations = takeWhile
            (\mgr -> droppedEventually mgr && tableDoesNotExist mgr)
            initialMigrations
          migrationsToRun = (reverse additionalMigrations) ++ migrationsToRun'
      in migrationsToRun

    -- Run one migration and keep table_versions in sync:
    --   * a standard migration runs its action and bumps the version by one;
    --   * a drop-table migration drops the table and deletes its row.
    runMigration :: (Migration m) -> m ()
    runMigration Migration{..} = do
      case mgrAction of
        StandardMigration mgrDo -> do
          logInfo_ $ arrListTable mgrTableName <> showt mgrFrom <+> "->"
            <+> showt (succ mgrFrom)
          mgrDo
          runQuery_ $ sqlUpdate "table_versions" $ do
            sqlSet "version" (succ mgrFrom)
            sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName)
        DropTableMigration mgrDropTableMode -> do
          logInfo_ $ arrListTable mgrTableName <> "drop table"
          runQuery_ $ sqlDropTable mgrTableName
            mgrDropTableMode
          runQuery_ $ sqlDelete "table_versions" $ do
            sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName)

    -- Validate, then run the selected migrations in order. When
    -- 'eoForceCommit' is set, commit and begin a new transaction after
    -- each migration.
    runMigrations :: [(Text, Int32)] -> m ()
    runMigrations dbTablesWithVersions = do
      let migrationsToRun = findMigrationsToRun dbTablesWithVersions
      validateMigrationsToRun migrationsToRun dbTablesWithVersions
      when (not . null $ migrationsToRun) $ do
        logInfo_ "Running migrations..."
        forM_ migrationsToRun $ \mgr -> do
          runMigration mgr
          when (eoForceCommit options) $ do
            logInfo_ $ "Forcing commit after migration"
              <> " and starting new transaction..."
            commit
            begin
            logInfo_ $ "Forcing commit after migration"
              <> " and starting new transaction... done."
            logInfo_ "!IMPORTANT! Database has been permanently changed"
        logInfo_ "Running migrations... done."

    -- Sanity-check the selected migrations, grouped per table, against the
    -- database:
    --
    -- * each group must start at the table's database version (0 when the
    --   table is absent);
    -- * for absent tables, a group must not start with a drop;
    -- * for absent tables, a group must start from version 0.
    validateMigrationsToRun :: [Migration m] -> [(Text, Int32)] -> m ()
    validateMigrationsToRun migrationsToRun dbTablesWithVersions = do
      let migrationsToRunGrouped :: [[Migration m]]
          migrationsToRunGrouped =
            L.groupBy ((==) `on` mgrTableName) .
            L.sortBy (comparing mgrTableName) $
            migrationsToRun

          loc_common = "Database.PostgreSQL.PQTypes.Checks."
            ++ "checkDBConsistency.validateMigrationsToRun"

          lookupDBTableVer :: [Migration m] -> Maybe Int32
          lookupDBTableVer mgrGroup =
            lookup (unRawSQL . mgrTableName . headExc head_err
                    $ mgrGroup) dbTablesWithVersions
            where
              head_err = loc_common ++ ".lookupDBTableVer: broken invariant"

          groupsWithWrongDBTableVersions :: [([Migration m], Int32)]
          groupsWithWrongDBTableVersions =
            [ (mgrGroup, dbTableVer)
            | mgrGroup <- migrationsToRunGrouped
            , let dbTableVer = fromMaybe 0 $ lookupDBTableVer mgrGroup
            , dbTableVer /= (mgrFrom . headExc head_err $ mgrGroup)
            ]
            where
              head_err = loc_common
                ++ ".groupsWithWrongDBTableVersions: broken invariant"

          mgrGroupsNotInDB :: [[Migration m]]
          mgrGroupsNotInDB =
            [ mgrGroup
            | mgrGroup <- migrationsToRunGrouped
            , isNothing $ lookupDBTableVer mgrGroup
            ]

          groupsStartingWithDropTable :: [[Migration m]]
          groupsStartingWithDropTable =
            [ mgrGroup
            | mgrGroup <- mgrGroupsNotInDB
            , isDropTableMigration . headExc head_err $ mgrGroup
            ]
            where
              head_err = loc_common
                ++ ".groupsStartingWithDropTable: broken invariant"

          groupsNotStartingWithCreateTable :: [[Migration m]]
          groupsNotStartingWithCreateTable =
            [ mgrGroup
            | mgrGroup <- mgrGroupsNotInDB
            , mgrFrom (headExc head_err mgrGroup) /= 0
            ]
            where
              head_err = loc_common
                ++ ".groupsNotStartingWithCreateTable: broken invariant"

          tblNames :: [[Migration m]] -> [RawSQL ()]
          tblNames grps =
            [ mgrTableName . headExc head_err $ grp | grp <- grps ]
            where
              head_err = loc_common ++ ".tblNames: broken invariant"

      when (not . null $ groupsWithWrongDBTableVersions) $ do
        let tnms = tblNames . map fst $ groupsWithWrongDBTableVersions
        logAttention
          ("There are migration chains selected for execution "
            <> "that expect a different starting table version number "
            <> "from the one in the database. "
            <> "This likely means that the order of migrations is wrong.")
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms

      when (not . null $ groupsStartingWithDropTable) $ do
        let tnms = tblNames groupsStartingWithDropTable
        logAttention "There are drop table migrations for non-existing tables."
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms

      when (not . null $ groupsNotStartingWithCreateTable) $ do
        let tnms = tblNames groupsNotStartingWithCreateTable
        logAttention
          ("Some tables haven't been created yet, but " <>
            "their migration lists don't start with a create table migration.")
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms
-- | Pair each table with its version from @table_versions@. Tables that
-- do not exist in the database get version 0.
getTableVersions :: (MonadDB m, MonadThrow m) => [Table] -> m [(Table, Int32)]
getTableVersions tbls = forM tbls $ \tbl -> do
  mver <- checkTableVersion (tblNameString tbl)
  return (tbl, fromMaybe 0 mver)
-- | Pair the name of every base table found by 'getDBTableNames' with
-- its version from @table_versions@. The version is 0 when
-- 'checkTableVersion' finds no visible relation.
getDBTableVersions :: (MonadDB m, MonadThrow m) => m [(Text, Int32)]
getDBTableVersions = do
  names <- getDBTableNames
  forM names $ \name ->
    (,) name . fromMaybe 0 <$> checkTableVersion (T.unpack name)
-- | Look up the version recorded in @table_versions@ for the named table.
--
-- Returns 'Nothing' when @pg_class@ has no relation of that name that is
-- visible on the search path. Aborts with 'error' when the relation exists
-- but has no @table_versions@ row.
checkTableVersion :: (MonadDB m, MonadThrow m) => String -> m (Maybe Int32)
checkTableVersion tblName = do
  exists <- runQuery01 . sqlSelect "pg_catalog.pg_class c" $ do
    sqlResult "TRUE"
    sqlLeftJoinOn "pg_catalog.pg_namespace n" "n.oid = c.relnamespace"
    sqlWhereEq "c.relname" $ tblName
    sqlWhere "pg_catalog.pg_table_is_visible(c.oid)"
  if not exists
    then return Nothing
    else do
      runQuery_ $ "SELECT version FROM table_versions WHERE name =" <?> tblName
      fetchMaybe runIdentity >>= maybe missingVersion (return . Just)
  where
    missingVersion = error $ "checkTableVersion: table '"
      ++ tblName
      ++ "' is present in the database, "
      ++ "but there is no corresponding version info in 'table_versions'."
-- | Parenthesized subquery selecting the OID of the table from
-- @pg_catalog.pg_class@. The relation is matched by name and must be
-- visible on the current search path. Callers use it as a value, for
-- example with 'sqlWhereEqSql'.
-- NOTE(review): the join on pg_namespace is not referenced by any
-- condition here.
sqlGetTableID :: Table -> SQL
sqlGetTableID table = parenthesize . toSQLCommand $
sqlSelect "pg_catalog.pg_class c" $ do
sqlResult "c.oid"
sqlLeftJoinOn "pg_catalog.pg_namespace n" "n.oid = c.relnamespace"
sqlWhereEq "c.relname" $ tblNameString table
sqlWhere "pg_catalog.pg_table_is_visible(c.oid)"
-- | Read the primary key of a table from @pg_catalog.pg_constraint@.
--
-- Returns 'Nothing' when the table has no primary key constraint.
-- Otherwise returns the key and its constraint name, or 'Nothing' if
-- 'pkOnColumns' produces no key for the column list.
--
-- Column names are resolved one at a time, following the @conkey@ order,
-- and are then embedded in a @text[]@ literal. Single quotes in a column
-- name are doubled so that the literal stays well-formed.
sqlGetPrimaryKey :: (MonadDB m, MonadThrow m) => Table -> m (Maybe (PrimaryKey, RawSQL ()))
sqlGetPrimaryKey table = do
  (mColumnNumbers :: Maybe [Int16]) <- do
    runQuery_ . sqlSelect "pg_catalog.pg_constraint" $ do
      sqlResult "conkey"
      sqlWhereEqSql "conrelid" (sqlGetTableID table)
      sqlWhereEq "contype" 'p'
    fetchMaybe $ unArray1 . runIdentity
  case mColumnNumbers of
    Nothing -> do return Nothing
    Just columnNumbers -> do
      -- Name of each primary key column, in key order.
      columnNames <- do
        forM columnNumbers $ \k -> do
          runQuery_ . sqlSelect "pk_columns" $ do
            sqlWith "key_series" . sqlSelect "pg_constraint as c2" $ do
              sqlResult "unnest(c2.conkey) as k"
              sqlWhereEqSql "c2.conrelid" $ sqlGetTableID table
              sqlWhereEq "c2.contype" 'p'
            sqlWith "pk_columns" . sqlSelect "key_series" $ do
              sqlJoinOn "pg_catalog.pg_attribute as a" "a.attnum = key_series.k"
              sqlResult "a.attname::text as column_name"
              sqlResult "key_series.k as column_order"
              sqlWhereEqSql "a.attrelid" $ sqlGetTableID table
            sqlResult "pk_columns.column_name"
            sqlWhereEq "pk_columns.column_order" k
          fetchOne (\(Identity t) -> t :: String)
      runQuery_ . sqlSelect "pg_catalog.pg_constraint as c" $ do
        sqlWhereEq "c.contype" 'p'
        sqlWhereEqSql "c.conrelid" $ sqlGetTableID table
        sqlResult "c.conname::text"
        sqlResult $ Data.String.fromString
          ("array['" <> (mintercalate "', '" $ map escapeLiteral columnNames) <> "']::text[]")
      join <$> fetchMaybe fetchPrimaryKey
  where
    -- Double single quotes so a name can sit inside a '...' SQL literal.
    escapeLiteral :: String -> String
    escapeLiteral = concatMap (\c -> if c == '\'' then "''" else [c])
-- | Build a primary key, paired with its constraint name, from a row of
-- constraint name and key column names. Yields 'Nothing' when
-- 'pkOnColumns' does.
fetchPrimaryKey :: (String, Array1 String) -> Maybe (PrimaryKey, RawSQL ())
fetchPrimaryKey (name, Array1 columns) =
  fmap (\pk -> (pk, unsafeSQL name)) . pkOnColumns . map unsafeSQL $ columns
-- | Query listing the CHECK constraints (contype = 'c') of a table. Each
-- row holds the constraint name and its definition. A regular expression
-- strips the surrounding "CHECK (...)" from the definition.
sqlGetChecks :: Table -> SQL
sqlGetChecks table = toSQLCommand . sqlSelect "pg_catalog.pg_constraint c" $ do
sqlResult "c.conname::text"
sqlResult "regexp_replace(pg_get_constraintdef(c.oid, true), 'CHECK \\((.*)\\)', '\\1') AS body"
sqlWhereEq "c.contype" 'c'
sqlWhereEqSql "c.conrelid" $ sqlGetTableID table
-- | Build a 'Check' from a row of constraint name and condition text.
fetchTableCheck :: (String, String) -> Check
fetchTableCheck (name, condition) =
  Check { chkCondition = unsafeSQL condition, chkName = unsafeSQL name }
-- | Query listing a table's indexes that do not back a constraint (the
-- left join on pg_constraint finds no row). Each row holds:
--
-- * the index name;
-- * an array of per-column index definitions;
-- * the access method name;
-- * the uniqueness flag;
-- * the partial-index predicate.
sqlGetIndexes :: Table -> SQL
sqlGetIndexes table = toSQLCommand . sqlSelect "pg_catalog.pg_class c" $ do
sqlResult "c.relname::text"
sqlResult $ "ARRAY(" <> selectCoordinates <> ")"
sqlResult "am.amname::text"
sqlResult "i.indisunique"
sqlResult "pg_catalog.pg_get_expr(i.indpred, i.indrelid, true)"
sqlJoinOn "pg_catalog.pg_index i" "c.oid = i.indexrelid"
sqlJoinOn "pg_catalog.pg_am am" "c.relam = am.oid"
sqlLeftJoinOn "pg_catalog.pg_constraint r"
"r.conrelid = i.indrelid AND r.conindid = i.indexrelid"
sqlWhereEqSql "i.indrelid" $ sqlGetTableID table
sqlWhereIsNULL "r.contype"
where
-- Recursive CTE that calls pg_get_indexdef for column 1, 2, ... and
-- stops when it returns an empty string. The result is each indexed
-- column's definition, in order.
selectCoordinates = smconcat [
"WITH RECURSIVE coordinates(k, name) AS ("
, " VALUES (0, NULL)"
, " UNION ALL"
, " SELECT k+1, pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true)"
, " FROM coordinates"
, " WHERE pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true) != ''"
, ")"
, "SELECT name FROM coordinates WHERE k > 0"
]
-- | Build a 'TableIndex', paired with its name, from one row of
-- 'sqlGetIndexes'. The row holds the index name, indexed columns or
-- expressions, access method name, uniqueness flag and optional
-- partial-index predicate.
--
-- The access method name is parsed with the index method type's 'Read'
-- instance. The accepted inputs are the same as with 'read'. An
-- unsupported method aborts with a message naming the index and the
-- method, instead of a bare "no parse".
fetchTableIndex :: (String, Array1 String, String, Bool, Maybe String)
                -> (TableIndex, RawSQL ())
fetchTableIndex (name, Array1 columns, method, unique, mconstraint) = (TableIndex {
    idxColumns = map unsafeSQL columns
  , idxMethod = fromMaybe unknownMethod (readMaybe method)
  , idxUnique = unique
  , idxWhere = unsafeSQL `liftM` mconstraint
  }, unsafeSQL name)
  where
    unknownMethod = error $ "fetchTableIndex: unsupported access method '"
      ++ method ++ "' of index '" ++ name ++ "'"
-- | Query listing a table's foreign key constraints (contype = 'f'). Each
-- row holds:
--
-- * the constraint name;
-- * the local column names, in key order;
-- * the referenced table name;
-- * the referenced column names, in key order;
-- * the ON UPDATE and ON DELETE action codes;
-- * the deferrable and initially-deferred flags.
sqlGetForeignKeys :: Table -> SQL
sqlGetForeignKeys table = toSQLCommand
. sqlSelect "pg_catalog.pg_constraint r" $ do
sqlResult "r.conname::text"
sqlResult $
"ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN ("
<> unnestWithOrdinality "r.conkey"
<> ") conkeys ON (a.attnum = conkeys.item) WHERE a.attrelid = r.conrelid ORDER BY conkeys.n)"
sqlResult "c.relname::text"
sqlResult $ "ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN ("
<> unnestWithOrdinality "r.confkey"
<> ") confkeys ON (a.attnum = confkeys.item) WHERE a.attrelid = r.confrelid ORDER BY confkeys.n)"
sqlResult "r.confupdtype"
sqlResult "r.confdeltype"
sqlResult "r.condeferrable"
sqlResult "r.condeferred"
sqlJoinOn "pg_catalog.pg_class c" "c.oid = r.confrelid"
sqlWhereEqSql "r.conrelid" $ sqlGetTableID table
sqlWhereEq "r.contype" 'f'
where
-- Rows (n, arr[n]) for every subscript n of the array, so elements can
-- be ordered by position.
-- NOTE(review): presumably emulates UNNEST ... WITH ORDINALITY for
-- older servers; confirm.
unnestWithOrdinality :: RawSQL () -> SQL
unnestWithOrdinality arr =
"SELECT n, " <> raw arr
<> "[n] AS item FROM generate_subscripts(" <> raw arr <> ", 1) AS n"
-- | Build a 'ForeignKey', paired with its constraint name, from one row of
-- 'sqlGetForeignKeys'. The ON UPDATE and ON DELETE actions arrive as
-- single-character codes; an unrecognised code aborts with 'error'.
fetchForeignKey ::
  (String, Array1 String, String, Array1 String, Char, Char, Bool, Bool)
  -> (ForeignKey, RawSQL ())
fetchForeignKey
  ( name, Array1 columns, reftable, Array1 refcolumns
  , on_update, on_delete, deferrable, deferred ) = (fk, unsafeSQL name)
  where
    fk = ForeignKey
      { fkColumns = map unsafeSQL columns
      , fkRefTable = unsafeSQL reftable
      , fkRefColumns = map unsafeSQL refcolumns
      , fkOnUpdate = toAction on_update
      , fkOnDelete = toAction on_delete
      , fkDeferrable = deferrable
      , fkDeferred = deferred
      }
    -- Decode an action code via the table below.
    toAction code = fromMaybe (invalidCode code) (lookup code actionCodes)
    actionCodes =
      [ ('a', ForeignKeyNoAction)
      , ('r', ForeignKeyRestrict)
      , ('c', ForeignKeyCascade)
      , ('n', ForeignKeySetNull)
      , ('d', ForeignKeySetDefault)
      ]
    invalidCode code =
      error $ "fetchForeignKey: invalid foreign key action code: " ++ show code