{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Database.Persist.Migration.Internal where
import Control.Monad (unless, when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Reader (mapReaderT)
import Data.Data (Data)
import Data.List (nub)
import Data.Maybe (fromMaybe, isNothing)
import Data.Monoid ((<>))
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Time.Clock (getCurrentTime)
import Database.Persist.Migration.Utils.Data (hasDuplicateConstrs)
import Database.Persist.Migration.Utils.Plan (getPath)
import Database.Persist.Sql
(PersistValue(..), Single(..), SqlPersistT, rawExecute, rawSql)
import Database.Persist.Types (SqlType(..))
-- | A schema version number.
type Version = Int

-- | The @(from, to)@ pair of versions an operation migrates between.
type OperationPath = (Version, Version)

-- | Build an 'OperationPath': @from ~> to@ is the pair @(from, to)@.
(~>) :: Version -> Version -> OperationPath
from ~> to = (from, to)
-- | One step of a migration: a value of any 'Migrateable' operation type,
-- tagged with the version path it moves the schema along.
data Operation = forall op. Migrateable op => Operation
  { opPath :: OperationPath
    -- ^ @(from, to)@ versions this operation migrates between
  , opOp :: op
    -- ^ the operation itself
  }

-- Derived through the 'Show' superclass of 'Migrateable'.
deriving instance Show Operation

-- | A migration is the list of its operations. The subset actually run is
-- chosen by 'getMigratePlan'.
type Migration = [Operation]
-- | Backend-specific SQL generation for the built-in operations. Each
-- function yields the list of SQL statements to execute, in order.
data MigrateBackend = MigrateBackend
  { createTable :: Bool -> CreateTable -> SqlPersistT IO [Text]
    -- ^ 'getCurrVersion' passes 'True' for the bookkeeping table and
    -- 'CreateTable' operations pass 'False' (presumably an
    -- "if not exists" flag; confirm against the backends)
  , dropTable :: DropTable -> SqlPersistT IO [Text]
  , addColumn :: AddColumn -> SqlPersistT IO [Text]
  , dropColumn :: DropColumn -> SqlPersistT IO [Text]
  }
-- | Anything that can appear as the payload of an 'Operation'.
class Show op => Migrateable op where
  -- | Check the operation for errors before any SQL is generated.
  -- The default accepts every value.
  validateOperation :: op -> Either String ()
  validateOperation = const (Right ())

  -- | Produce the SQL statements that carry out the operation.
  getMigrationText :: MigrateBackend -> op -> SqlPersistT IO [Text]
-- | Return the most recently recorded schema version, or 'Nothing' when no
-- version has been recorded yet.
--
-- First executes the statements the backend's 'createTable' produces for the
-- @persistent_migration@ bookkeeping table, called with 'True' as its flag.
getCurrVersion :: MonadIO m => MigrateBackend -> SqlPersistT m (Maybe Version)
getCurrVersion backend = do
  mapReaderT liftIO (createTable backend True migrationSchema) >>= rawExecute'
  extractVersion <$> rawSql queryVersion []
  where
    migrationSchema :: CreateTable
    migrationSchema = CreateTable
      { ctName = "persistent_migration"
      , ctSchema =
          [ Column "id" SqlInt32 [NotNull, AutoIncrement]
          , Column "version" SqlInt32 [NotNull]
          , Column "label" SqlString []
          , Column "timestamp" SqlDayTime [NotNull]
          ]
      , ctConstraints =
          [ PrimaryKey ["id"]
          ]
      }

    -- Rows with equal timestamps (e.g. two runs within the resolution the
    -- database stores for the timestamp column) would otherwise come back in
    -- an unspecified order. The auto-incremented id breaks the tie in favour
    -- of the row inserted last.
    queryVersion :: Text
    queryVersion = "SELECT version FROM persistent_migration ORDER BY timestamp DESC, id DESC LIMIT 1"

    -- LIMIT 1 means the query yields at most one row.
    extractVersion :: [Single Version] -> Maybe Version
    extractVersion = \case
      [] -> Nothing
      [Single v] -> Just v
      _ -> error "Invalid response from the database."
-- | Find the sequence of operations leading from the current version to the
-- latest version in the migration.
--
-- With no current version (a fresh database), planning starts from the
-- smallest starting version of any operation. When 'getPath' finds no path,
-- the @(start, end)@ pair that could not be connected is returned in 'Left'.
getMigratePlan :: Migration -> Maybe Version -> Either (Version, Version) Migration
getMigratePlan migration mVersion =
  maybe (Left (start, end)) Right (getPath edges start end)
  where
    edges = map (\op -> (opPath op, op)) migration
    start = fromMaybe (getFirstVersion migration) mVersion
    end = getLatestVersion migration
-- | The smallest starting version among all operations.
--
-- Partial: calls 'minimum', so it errors on an empty migration.
getFirstVersion :: Migration -> Version
getFirstVersion migration = minimum [fst (opPath op) | op <- migration]

-- | The largest target version among all operations.
--
-- Partial: calls 'maximum', so it errors on an empty migration.
getLatestVersion :: Migration -> Version
getLatestVersion migration = maximum [snd (opPath op) | op <- migration]
-- | Settings that control how a completed migration is recorded.
newtype MigrateSettings = MigrateSettings
  { versionToLabel :: Version -> Maybe String
    -- ^ label stored with a recorded version; 'runMigration' falls back to
    -- the version rendered with 'show' when this returns 'Nothing'
  }

-- | Settings that supply no custom labels.
defaultSettings :: MigrateSettings
defaultSettings = MigrateSettings (\_ -> Nothing)
-- | Check the structural invariants of a migration before it is planned.
--
-- Rejects a migration that:
--
-- * has no operations at all. Version bounds are computed with 'minimum' and
--   'maximum' ('getFirstVersion', 'getLatestVersion'), which crash on an
--   empty list, so rejecting it here gives a clear message instead;
-- * contains an operation whose path does not go from a lower to a strictly
--   higher version;
-- * contains two operations with the same @(from, to)@ path.
validateMigration :: Migration -> Either String ()
validateMigration migration = do
  when (null migration) $
    Left "Migration must contain at least one operation"
  unless (allIncreasing opVersions) $
    Left "Operation versions must be monotonically increasing"
  when (hasDuplicates opVersions) $
    Left "There may only be one operation per pair of versions"
  where
    opVersions = map opPath migration
    allIncreasing = all (uncurry (<))
    hasDuplicates l = length (nub l) < length l
-- | Execute every pending statement of the migration, then record the
-- migration's latest version in @persistent_migration@.
--
-- The version row is inserted unconditionally, so a call on an up-to-date
-- database still adds a row. The stored label is @versionToLabel@ applied
-- to that version, or the version rendered with 'show' when it gives
-- 'Nothing'.
runMigration :: MonadIO m => MigrateBackend -> MigrateSettings -> Migration -> SqlPersistT m ()
runMigration backend settings migration = do
  rawExecute' =<< getMigration backend settings migration
  recordedAt <- liftIO getCurrentTime
  rawExecute insertVersion
    [ PersistInt64 (fromIntegral latest)
    , PersistText (Text.pack label)
    , PersistUTCTime recordedAt
    ]
  where
    latest = getLatestVersion migration
    label = fromMaybe (show latest) (versionToLabel settings latest)
    insertVersion :: Text
    insertVersion = "INSERT INTO persistent_migration(version, label, timestamp) VALUES (?, ?, ?)"
-- | Compute the SQL statements that bring the database from its currently
-- recorded version (as reported by 'getCurrVersion') up to the latest
-- version of the migration.
--
-- The migration as a whole ('validateMigration') and each operation
-- ('validateOperation') are validated first. A validation error, or the
-- absence of a path from the current to the latest version, is thrown as
-- an 'IOError' built with 'userError'.
--
-- The 'MigrateSettings' argument is currently ignored.
getMigration :: MonadIO m => MigrateBackend -> MigrateSettings -> Migration -> SqlPersistT m [Text]
getMigration backend _ migration = do
  either throwUser return $ validateMigration migration
  either throwUser return $ mapM_ (\Operation{opOp} -> validateOperation opOp) migration
  currVersion <- getCurrVersion backend
  migratePlan <- either badPath return $ getMigratePlan migration currVersion
  concat <$> mapM getMigrationText' migratePlan
  where
    -- Calling 'fail' directly in 'SqlPersistT m' requires 'MonadFail m' on
    -- GHC >= 8.8, which the 'MonadIO m' constraint does not provide. Raising
    -- the 'userError' in IO matches what 'fail' did for stacks whose 'fail'
    -- bottoms out in IO, and compiles on every GHC version.
    throwUser :: MonadIO n => String -> n a
    throwUser = liftIO . ioError . userError
    badPath (start, end) =
      throwUser $ "Could not find path: " ++ show start ++ " ~> " ++ show end
    getMigrationText' Operation{opOp} =
      mapReaderT liftIO $ getMigrationText backend opOp
-- | Execute each SQL statement in order, without bound parameters.
rawExecute' :: MonadIO m => [Text] -> SqlPersistT m ()
rawExecute' = mapM_ (`rawExecute` [])
-- | Operation that creates a table from a list of columns and table-level
-- constraints. The SQL is generated by the backend's 'createTable'.
data CreateTable = CreateTable
  { ctName :: Text
  , ctSchema :: [Column]
  , ctConstraints :: [TableConstraint]
  } deriving (Show)

-- | Validation rejects a table that has:
--
-- * a column with invalid properties (see 'validateColumn');
-- * duplicate table constraints (as judged by 'hasDuplicateConstrs');
-- * a constraint naming a column that is not in the schema;
-- * two columns with the same name.
instance Migrateable CreateTable where
  validateOperation ct@CreateTable{..} = do
    mapM_ validateColumn ctSchema
    when (hasDuplicateConstrs ctConstraints) $
      Left $ "Duplicate table constraints detected: " ++ show ct
    when (any (`notElem` schemaCols) constraintCols) $
      Left $ "Table constraint references non-existent column: " ++ show ct
    -- SQL databases reject a table with two columns of the same name, so
    -- catch it here. This check runs last so that inputs which failed before
    -- still report the same error.
    when (length (nub schemaCols) < length schemaCols) $
      Left $ "Duplicate column names detected: " ++ show ct
    where
      schemaCols = map colName ctSchema
      constraintCols = concatMap getConstraintColumns ctConstraints

  getMigrationText backend = createTable backend False
-- | Operation that drops a table. The SQL is generated by the backend's
-- 'dropTable'.
newtype DropTable = DropTable
  { dtName :: Text
  } deriving (Show)

instance Migrateable DropTable where
  getMigrationText backend op = dropTable backend op
-- | Operation that adds a column. The SQL is generated by the backend's
-- 'addColumn'.
data AddColumn = AddColumn
  { acTable :: Text
  , acColumn :: Column
  , acDefault :: Maybe Text
    -- ^ must be present when the column is 'NotNull' (checked in
    -- 'validateOperation')
  } deriving (Show)

instance Migrateable AddColumn where
  -- Validates the column itself, then requires a default for a NOT NULL
  -- column.
  validateOperation ac@AddColumn{acColumn = column, acDefault = defaultValue} = do
    validateColumn column
    if NotNull `elem` colProps column && isNothing defaultValue
      then Left $ "Adding a non-nullable column requires a default: " ++ show ac
      else Right ()

  getMigrationText backend op = addColumn backend op
-- | Operation that drops a column, identified by a @(table, column)@ pair.
-- The SQL is generated by the backend's 'dropColumn'.
newtype DropColumn = DropColumn
  { dcColumn :: ColumnIdentifier
  } deriving (Show)

instance Migrateable DropColumn where
  getMigrationText backend op = dropColumn backend op
-- | Escape hatch: an arbitrary action that produces the SQL statements to
-- run. In this module, 'message' is used only for the 'Show' output.
data RawOperation = RawOperation
  { message :: Text
  , rawOp :: SqlPersistT IO [Text]
  }

instance Show RawOperation where
  show op = "RawOperation: " ++ Text.unpack (message op)

instance Migrateable RawOperation where
  -- The backend is ignored; the wrapped action supplies the SQL itself.
  getMigrationText _ = rawOp
-- | An operation that generates no SQL statements.
data NoOp = NoOp
  deriving Show

instance Migrateable NoOp where
  getMigrationText _ _ = pure []
-- | A column qualified by its table: @(table, column)@.
type ColumnIdentifier = (Text, Text)

-- | Render a 'ColumnIdentifier' as @table.column@.
dotted :: ColumnIdentifier -> Text
dotted (tab, col) = Text.intercalate (Text.singleton '.') [tab, col]
-- | A column definition: name, SQL type and extra properties.
data Column = Column
  { colName :: Text
  , colType :: SqlType
  , colProps :: [ColumnProp]
  } deriving (Show)

-- | Reject a column whose property list 'hasDuplicateConstrs' flags as
-- containing duplicates.
validateColumn :: Column -> Either String ()
validateColumn col
  | hasDuplicateConstrs (colProps col) =
      Left $ "Duplicate column properties detected: " ++ show col
  | otherwise = Right ()
-- | Properties that can be attached to a 'Column'.
data ColumnProp
  = NotNull
  | References ColumnIdentifier
    -- ^ presumably a foreign key to the given @(table, column)@; confirm
    -- against the backends
  | AutoIncrement
  deriving (Show, Eq, Data)
-- | Constraints that apply to a table as a whole.
data TableConstraint
  = PrimaryKey [Text]
    -- ^ primary key over the listed column names
  | Unique Text [Text]
    -- ^ unique constraint over the listed column names; the first field is
    -- presumably the constraint's name (TODO confirm)
  deriving (Show, Data)

-- | The column names a table constraint refers to.
getConstraintColumns :: TableConstraint -> [Text]
getConstraintColumns (PrimaryKey cols) = cols
getConstraintColumns (Unique _ cols) = cols