module Database.Beam.Migrate.Actions
(
DatabaseStateSource(..)
, DatabaseState(..)
, PotentialAction(..)
, ActionProvider(..)
, ActionProviderFn
, ensuringNot_
, justOne_
, createTableActionProvider
, dropTableActionProvider
, addColumnProvider
, addColumnNullProvider
, dropColumnNullProvider
, defaultActionProvider
, Solver(..), FinalSolution(..)
, finalSolution
, heuristicSolver
) where
import Database.Beam.Migrate.Types
import Database.Beam.Migrate.Checks
import Database.Beam.Migrate.SQL
import Control.Applicative
import Control.DeepSeq
import Control.Monad
import Control.Parallel.Strategies
import Data.Foldable
import qualified Data.HashMap.Strict as HM
import qualified Data.HashSet as HS
import Data.Monoid
import qualified Data.PQueue.Min as PQ
import qualified Data.Sequence as Seq
import Data.Text (Text)
import qualified Data.Text as T
import Data.Typeable
import GHC.Generics
-- | Records how a predicate came to hold in a 'DatabaseState'.
data DatabaseStateSource
  = DatabaseStateSourceOriginal -- ^ part of the initial (pre-condition) predicates
  | DatabaseStateSourceDerived  -- ^ introduced as a post-condition of an applied action
  deriving (Show, Eq, Ord, Enum, Bounded, Generic)

-- | Constructors carry no fields, so reaching the constructor is full evaluation.
instance NFData DatabaseStateSource where
  rnf src = src `seq` ()
-- | One node of the migration search: the predicates that currently hold and
-- the commands accumulated on the path that produced them.
data DatabaseState cmd
  = DatabaseState
  { dbStateCurrentState :: !(HM.HashMap SomeDatabasePredicate DatabaseStateSource)
    -- ^ predicates holding in this state, each tagged with its origin
  , dbStateKey          :: !(HS.HashSet SomeDatabasePredicate)
    -- ^ the same predicates as a plain set (used for visited checks and goal tests)
  , dbStateCmdSequence  :: !(Seq.Seq cmd)
    -- ^ commands issued so far, in order
  } deriving Show

-- | Shallow: evaluates the record to WHNF, which (because every field is
-- strict) also puts each field in WHNF. The commands inside the sequence are
-- not deep-forced, since no @NFData cmd@ constraint is available.
instance NFData (DatabaseState cmd) where
  rnf (DatabaseState _ _ _) = ()
-- | A 'DatabaseState' annotated with the two numbers the search orders on
-- (combined by 'measure').
data MeasuredDatabaseState cmd
  = MeasuredDatabaseState
      !Int                -- path length (incremented once per applied action)
      !Int                -- distance from the goal predicate set
      (DatabaseState cmd) -- the wrapped state (lazy field)
  deriving (Show, Generic)

-- | Generic-derived: forces both counters and then the wrapped state's 'rnf'.
instance NFData (MeasuredDatabaseState cmd)
-- | Equality and ordering consult only 'measure'. Two states with different
-- contents but the same measure compare as equal, which is all the priority
-- queue needs.
instance Eq (MeasuredDatabaseState cmd) where
  lhs == rhs = measure lhs == measure rhs

instance Ord (MeasuredDatabaseState cmd) where
  compare lhs rhs = measure lhs `compare` measure rhs
-- | Search priority (lower is better). It is the path length plus 100 times
-- the goal distance, so a one-step reduction in distance outweighs up to 100
-- extra actions.
measure :: MeasuredDatabaseState cmd -> Int
measure (MeasuredDatabaseState pathLen goalDist _) =
  pathLen + goalDist * 100
-- | Discards the measurements and returns the wrapped 'DatabaseState'.
measuredDbState :: MeasuredDatabaseState cmd -> DatabaseState cmd
measuredDbState measured =
  case measured of
    MeasuredDatabaseState _ _ st -> st
-- | Attaches a path length and a goal distance to a state.
--
-- The goal distance is the size of the symmetric difference between the
-- state's predicate set and the goal set. The first argument is accepted but
-- not used.
measureDb' :: HS.HashSet SomeDatabasePredicate
           -> HS.HashSet SomeDatabasePredicate
           -> Int
           -> DatabaseState cmd
           -> MeasuredDatabaseState cmd
measureDb' _unused goal pathLen dbState =
  MeasuredDatabaseState pathLen (HS.size mismatched) dbState
  where
    current    = dbStateKey dbState
    extra      = current `HS.difference` goal
    missing    = goal `HS.difference` current
    mismatched = HS.union extra missing
-- | A candidate migration step proposed by an 'ActionProvider'.
data PotentialAction cmd
  = PotentialAction
  { actionPreConditions  :: !(HS.HashSet SomeDatabasePredicate)
    -- ^ predicates the action removes from the current state
  , actionPostConditions :: !(HS.HashSet SomeDatabasePredicate)
    -- ^ predicates that hold after the action
  , actionCommands       :: !(Seq.Seq cmd)
    -- ^ commands implementing the action, in order
  , actionEnglish        :: !Text
    -- ^ human-readable description
  , actionScore          :: !Int
    -- ^ provider-assigned weight (NOTE(review): not read by 'measure' in this module)
  }
-- | Sequential composition of actions. Condition sets and command sequences
-- are concatenated, scores are summed, and non-empty descriptions are joined
-- with @"; "@.
--
-- NOTE(review): since GHC 8.4, 'Monoid' requires a 'Semigroup' superclass
-- instance and none is visible here. Confirm the targeted compiler version.
instance Monoid (PotentialAction cmd) where
  mempty = PotentialAction mempty mempty mempty "" 0
  mappend left right =
    PotentialAction
      { actionPreConditions  = actionPreConditions left <> actionPreConditions right
      , actionPostConditions = actionPostConditions left <> actionPostConditions right
      , actionCommands       = actionCommands left <> actionCommands right
      , actionEnglish        = joinEnglish (actionEnglish left) (actionEnglish right)
      , actionScore          = actionScore left + actionScore right
      }
    where
      -- Skip empty descriptions so no dangling separator is produced.
      joinEnglish l r
        | T.null l  = r
        | T.null r  = l
        | otherwise = l <> "; " <> r
-- | Proposes actions given two lookups. The first is over the current
-- predicates and the second over the goal predicates. Each lookup can be
-- instantiated at any 'Typeable' predicate type; 'heuristicSolver' passes
-- lookups returning the predicates of the requested type.
type ActionProviderFn cmd =
     (forall pre.  Typeable pre  => [ pre ])
  -> (forall post. Typeable post => [ post ])
  -> [ PotentialAction cmd ]

-- | Newtype around an 'ActionProviderFn'. Providers are combined through the
-- 'Monoid' instance, which concatenates their proposals.
newtype ActionProvider cmd
  = ActionProvider { getPotentialActions :: ActionProviderFn cmd }
-- | Runs both providers on the same lookups and concatenates their
-- proposals, left provider first.
--
-- Each result list is passed through a @rparWith (parList rseq)@ strategy and
-- forced with 'seq' before concatenation. NOTE(review): forcing the strategy
-- result immediately may evaluate the list on the current thread and so
-- limit the parallelism. Check this against the @parallel@ version in use.
--
-- NOTE(review): since GHC 8.4, 'Monoid' requires a 'Semigroup' superclass
-- instance and none is visible here.
instance Monoid (ActionProvider cmd) where
  mempty = ActionProvider (\_ _ -> [])
  mappend (ActionProvider providerA) (ActionProvider providerB) =
    ActionProvider $ \findPre findPost ->
      let spark    = withStrategy (rparWith (parList rseq))
          actionsA = providerA findPre findPost
          actionsB = providerB findPre findPost
      in spark actionsA `seq` spark actionsB `seq` (actionsA ++ actionsB)
-- | Score given to actions from 'createTableActionProvider'.
createTableWeight :: Int
createTableWeight = 500

-- | Score given to actions from 'dropTableActionProvider'.
dropTableWeight :: Int
dropTableWeight = 100

-- | Base score for adding a column; name lengths are added on top.
addColumnWeight :: Int
addColumnWeight = 1

-- | Base score for dropping a column; name lengths are added on top.
dropColumnWeight :: Int
dropColumnWeight = 1
-- | Succeeds with @()@ exactly when the list is empty and fails ('empty')
-- otherwise. Used to assert that a search for a conflicting predicate found
-- nothing.
ensuringNot_ :: Alternative m => [ a ] -> m ()
ensuringNot_ = guard . null
-- | Returns the list unchanged when it has exactly one element and @[]@
-- otherwise, rejecting both missing and ambiguous matches.
justOne_ :: [ a ] -> [ a ]
justOne_ xs =
  case xs of
    [_] -> xs
    _   -> []
-- | Proposes a CREATE TABLE action for every table that has a goal
-- 'TableExistsPredicate' but no current 'TableExistsPredicate' with the same
-- name.
--
-- The command is built only from goal predicates for that table:
--
--   * each 'TableHasColumn' becomes a column, together with every
--     'TableColumnHasConstraint' for that column;
--   * exactly one 'TableHasPrimaryKey' must exist ('justOne_'). With zero or
--     several, no action is proposed for the table.
--
-- The action has no pre-conditions. Its post-conditions are the table,
-- primary-key, column and column-constraint predicates it used, and it is
-- scored with 'createTableWeight'.
createTableActionProvider :: forall cmd
                           . ( Sql92SaneDdlCommandSyntax cmd
                             , Sql92SerializableDataTypeSyntax (Sql92DdlCommandDataTypeSyntax cmd) )
                          => ActionProvider cmd
createTableActionProvider =
  ActionProvider provider
  where
    provider :: ActionProviderFn cmd
    provider findPreConditions findPostConditions =
      -- Runs in the list monad: a failed pattern or 'guard' prunes a
      -- candidate, and each surviving path yields one action.
      do tblP@(TableExistsPredicate postTblNm) <- findPostConditions
         -- Skip tables that already exist in the current state.
         ensuringNot_ $
           do TableExistsPredicate preTblNm <- findPreConditions
              guard (preTblNm == postTblNm)

         -- For each goal column of this table, collect its predicates and
         -- its (name, type, constraints) definition.
         (columnsP, columns) <- pure . unzip $
           do columnP@(TableHasColumn tblNm colNm schema
                         :: TableHasColumn (Sql92DdlCommandColumnSchemaSyntax cmd)) <-
                findPostConditions
              guard (tblNm == postTblNm)

              (constraintsP, constraints) <-
                pure . unzip $ do
                  constraintP@(TableColumnHasConstraint tblNm' colNm' c
                                 :: TableColumnHasConstraint (Sql92DdlCommandColumnSchemaSyntax cmd)) <-
                    findPostConditions
                  guard (postTblNm == tblNm')
                  guard (colNm == colNm')

                  pure (p constraintP, c)

              pure (p columnP:constraintsP, (colNm, schema, constraints))

         -- Require exactly one primary-key predicate for the table.
         (primaryKeyP, primaryKey) <- justOne_ $ do
           primaryKeyP@(TableHasPrimaryKey tblNm primaryKey) <-
             findPostConditions
           guard (tblNm == postTblNm)
           pure (primaryKeyP, primaryKey)

         let postConditions = [ p tblP, p primaryKeyP ] ++ concat columnsP
             cmd = createTableCmd (createTableSyntax Nothing postTblNm colsSyntax tblConstraints)
             tblConstraints = [ primaryKeyConstraintSyntax primaryKey ]
             colsSyntax = map (\(colNm, type_, cs) -> (colNm, columnSchemaSyntax type_ Nothing cs Nothing)) columns
         pure (PotentialAction mempty (HS.fromList postConditions) (Seq.singleton cmd) ("Create the table " <> postTblNm) createTableWeight)
-- | Proposes a DROP TABLE action for every table that exists in the current
-- state but has no goal 'TableExistsPredicate'.
--
-- Every current predicate for which 'predicateCascadesDropOn' reports a
-- cascade from the table is also consumed. Those predicates become
-- pre-conditions of the action alongside the table itself. The action is
-- scored with 'dropTableWeight'.
dropTableActionProvider :: forall cmd
                         . ( Sql92SaneDdlCommandSyntax cmd
                           , Sql92SerializableDataTypeSyntax (Sql92DdlCommandDataTypeSyntax cmd) )
                        => ActionProvider cmd
dropTableActionProvider = ActionProvider dropUnwantedTables
  where
    dropUnwantedTables :: ActionProviderFn cmd
    dropUnwantedTables findPre findPost = do
      existingTblP@(TableExistsPredicate staleTblNm) <- findPre

      -- Only tables the goal does not mention are candidates.
      ensuringNot_ [ () | TableExistsPredicate goalTblNm <- findPost
                        , goalTblNm == staleTblNm ]

      let cascaded = [ somePred
                     | somePred@(SomeDatabasePredicate pred') <- findPre
                     , pred' `predicateCascadesDropOn` existingTblP ]
          removed  = HS.fromList (SomeDatabasePredicate existingTblP : cascaded)
          dropCmd  = dropTableCmd (dropTableSyntax staleTblNm)
      pure PotentialAction
        { actionPreConditions  = removed
        , actionPostConditions = mempty
        , actionCommands       = Seq.singleton dropCmd
        , actionEnglish        = "Drop table " <> staleTblNm
        , actionScore          = dropTableWeight
        }
-- | Proposes ALTER TABLE ... ADD COLUMN for every goal 'TableHasColumn' whose
-- table already exists in the current state but has no column with that name
-- (of any type).
--
-- The new column is created with the goal type and no constraints. The score
-- is 'addColumnWeight' plus the lengths of the table and column names.
addColumnProvider :: forall cmd
                   . ( Sql92SaneDdlCommandSyntax cmd
                     , Sql92SerializableDataTypeSyntax (Sql92DdlCommandDataTypeSyntax cmd) )
                  => ActionProvider cmd
addColumnProvider = ActionProvider addMissingColumns
  where
    addMissingColumns :: ActionProviderFn cmd
    addMissingColumns findPre findPost = do
      goalColP@(TableHasColumn targetTbl targetCol targetType
                  :: TableHasColumn (Sql92DdlCommandColumnSchemaSyntax cmd)) <- findPost

      -- The owning table must already exist...
      TableExistsPredicate existingTbl <- findPre
      guard (existingTbl == targetTbl)

      -- ...and must not yet have a column of this name.
      ensuringNot_ $ do
        (TableHasColumn presentTbl presentCol _
           :: TableHasColumn (Sql92DdlCommandColumnSchemaSyntax cmd)) <- findPre
        guard (presentTbl == targetTbl && presentCol == targetCol)

      let newSchema = columnSchemaSyntax targetType Nothing [] Nothing
          alterCmd  = alterTableCmd (alterTableSyntax targetTbl (addColumnSyntax targetCol newSchema))
      pure PotentialAction
        { actionPreConditions  = mempty
        , actionPostConditions = HS.fromList [SomeDatabasePredicate goalColP]
        , actionCommands       = Seq.singleton alterCmd
        , actionEnglish        = "Add column " <> targetCol <> " to " <> targetTbl
        , actionScore          = addColumnWeight + T.length targetTbl + T.length targetCol
        }
-- | Proposes ALTER TABLE ... DROP COLUMN for every column in the current
-- state. The goal is not consulted; the search decides whether a drop helps.
--
-- Current predicates for which 'predicateCascadesDropOn' reports a cascade
-- from the column are consumed too. The score is 'dropColumnWeight' plus the
-- lengths of the table and column names.
dropColumnProvider :: forall cmd
                    . ( Sql92SaneDdlCommandSyntax cmd
                      , Sql92SerializableDataTypeSyntax (Sql92DdlCommandDataTypeSyntax cmd) )
                   => ActionProvider cmd
dropColumnProvider = ActionProvider dropAnyColumn
  where
    dropAnyColumn :: ActionProviderFn cmd
    dropAnyColumn findPre _ = do
      presentColP@(TableHasColumn ownerTbl droppedCol _
                     :: TableHasColumn (Sql92DdlCommandColumnSchemaSyntax cmd)) <- findPre
      let cascaded = [ somePred
                     | somePred@(SomeDatabasePredicate pred') <- findPre
                     , pred' `predicateCascadesDropOn` presentColP ]
          alterCmd = alterTableCmd (alterTableSyntax ownerTbl (dropColumnSyntax droppedCol))
      pure PotentialAction
        { actionPreConditions  = HS.fromList (SomeDatabasePredicate presentColP : cascaded)
        , actionPostConditions = mempty
        , actionCommands       = Seq.singleton alterCmd
        , actionEnglish        = "Drop column " <> droppedCol <> " from " <> ownerTbl
        , actionScore          = dropColumnWeight + T.length ownerTbl + T.length droppedCol
        }
-- | For every goal 'TableColumnHasConstraint' whose table and column both
-- exist in the current state, proposes an ALTER TABLE / ALTER COLUMN command
-- built with 'setNotNullSyntax'. Its description is "Add not null
-- constraint to ...", it has no pre-conditions, its only post-condition is
-- that goal predicate, and it has a fixed score of 100.
--
-- NOTE(review): the constraint carried by the goal predicate is matched with
-- @_@ and never inspected. The action is therefore offered for any column
-- constraint, not only NOT NULL, and also when the constraint already holds
-- in the current state. Confirm whether a check on the constraint value is
-- intended.
addColumnNullProvider :: forall cmd
                       . Sql92SaneDdlCommandSyntax cmd
                      => ActionProvider cmd
addColumnNullProvider = ActionProvider provider
  where
    provider :: ActionProviderFn cmd
    provider findPreConditions findPostConditions =
      do colP@(TableColumnHasConstraint tblNm colNm _ :: TableColumnHasConstraint (Sql92DdlCommandColumnSchemaSyntax cmd))
           <- findPostConditions
         -- The table must already exist ...
         TableExistsPredicate tblNm' <- findPreConditions
         guard (tblNm == tblNm')
         -- ... and so must the column (of any type).
         TableHasColumn tblNm'' colNm' _ :: TableHasColumn (Sql92DdlCommandColumnSchemaSyntax cmd) <- findPreConditions
         guard (tblNm == tblNm'' && colNm == colNm')
         let cmd = alterTableCmd (alterTableSyntax tblNm (alterColumnSyntax colNm setNotNullSyntax))
         pure (PotentialAction mempty (HS.fromList [SomeDatabasePredicate colP]) (Seq.singleton cmd)
                               ("Add not null constraint to " <> colNm <> " on " <> tblNm) 100)
-- | For every current 'TableColumnHasConstraint' whose table and column
-- predicates also hold in the current state, proposes an ALTER TABLE /
-- ALTER COLUMN command built with 'setNullSyntax'. Its description is "Drop
-- not null constraint for ...", its only pre-condition is that constraint
-- predicate, it has no post-conditions, and it has a fixed score of 100. The
-- goal predicates are not consulted.
--
-- NOTE(review): the constraint value is matched with @_@, so this action is
-- offered for every column constraint, not only NOT NULL. Confirm this is
-- intended.
dropColumnNullProvider :: forall cmd
                        . Sql92SaneDdlCommandSyntax cmd
                       => ActionProvider cmd
dropColumnNullProvider = ActionProvider provider
  where
    provider :: ActionProviderFn cmd
    provider findPreConditions _ =
      do colP@(TableColumnHasConstraint tblNm colNm _ :: TableColumnHasConstraint (Sql92DdlCommandColumnSchemaSyntax cmd))
           <- findPreConditions
         TableExistsPredicate tblNm' <- findPreConditions
         guard (tblNm == tblNm')
         TableHasColumn tblNm'' colNm' _ :: TableHasColumn (Sql92DdlCommandColumnSchemaSyntax cmd) <- findPreConditions
         guard (tblNm == tblNm'' && colNm == colNm')
         let cmd = alterTableCmd (alterTableSyntax tblNm (alterColumnSyntax colNm setNullSyntax))
         pure (PotentialAction (HS.fromList [SomeDatabasePredicate colP]) mempty (Seq.singleton cmd)
                               ("Drop not null constraint for " <> colNm <> " on " <> tblNm) 100)
-- | The six providers of this module combined with the 'ActionProvider'
-- monoid. The list order fixes the order in which their proposals are
-- concatenated.
defaultActionProvider :: ( Sql92SaneDdlCommandSyntax cmd
                         , Sql92SerializableDataTypeSyntax (Sql92DdlCommandDataTypeSyntax cmd) )
                      => ActionProvider cmd
defaultActionProvider =
  foldr mappend mempty
        [ createTableActionProvider
        , dropTableActionProvider
        , addColumnProvider
        , dropColumnProvider
        , addColumnNullProvider
        , dropColumnNullProvider
        ]
-- | An incrementally unfolded search. Each constructor is one step of the
-- dialogue between the solver and whoever drives it.
data Solver cmd where
  -- | The goal was reached with these commands, in order.
  ProvideSolution :: [cmd] -> Solver cmd
  -- | The search gave up; these states are offered as candidates.
  SearchFailed    :: [ DatabaseState cmd ] -> Solver cmd
  -- | At 'choosingActionsAtState' the solver offers 'potentialActionChoices'
  -- (each viewed as an action through 'getPotentialActionChoice'). Pass the
  -- chosen subset to 'continueSearch' to resume.
  ChooseActions   :: { choosingActionsAtState   :: !(DatabaseState cmd)
                     , getPotentialActionChoice :: choice -> PotentialAction cmd
                     , potentialActionChoices   :: [ choice ]
                     , continueSearch           :: [ choice ] -> Solver cmd
                     } -> Solver cmd
-- | Result of driving a 'Solver' to completion with 'finalSolution'.
data FinalSolution cmd
  = Solved [ cmd ]                   -- ^ commands that reach the goal
  | Candidates [ DatabaseState cmd ] -- ^ states reported by a failed search
  deriving Show
-- | True when the state's predicate set is exactly the goal set.
solvedState :: HS.HashSet SomeDatabasePredicate -> DatabaseState cmd -> Bool
solvedState goal st = dbStateKey st == goal
-- | Runs a 'Solver' to completion, accepting every action offered at each
-- 'ChooseActions' step.
finalSolution :: Solver cmd -> FinalSolution cmd
finalSolution solver =
  case solver of
    ProvideSolution cmds              -> Solved cmds
    SearchFailed states               -> Candidates states
    ChooseActions _ _ offered resume  -> finalSolution (resume offered)
-- | Best-first search from the current predicates (@preConditionsL@) to the
-- goal predicates (@postConditionsL@), using the actions proposed by
-- @provider@.
--
-- States are ordered by 'measure': the number of actions applied so far plus
-- 100 times the size of the symmetric difference between the state's
-- predicates and the goal. States are identified by their predicate set
-- ('dbStateKey'). The search yields:
--
--   * 'ProvideSolution' with the accumulated commands once a state's
--     predicate set equals the goal;
--   * 'ChooseActions' when at least one proposed action leads to a predicate
--     set not yet explored (and different from the current one); the caller
--     picks which of those successors to enqueue;
--   * 'SearchFailed' when the queue is exhausted. It carries up to
--     @rejectedCount@ (10) lowest-measure dead-end states, i.e. states where
--     no action led anywhere new.
heuristicSolver :: ActionProvider cmd
                -> [ SomeDatabasePredicate ]
                -> [ SomeDatabasePredicate ]
                -> Solver cmd
heuristicSolver provider preConditionsL postConditionsL =
  heuristicSolver' initQueue mempty PQ.empty
  where
    -- Maximum number of dead-end states kept for 'SearchFailed'.
    rejectedCount = 10

    postConditions = HS.fromList postConditionsL
    preConditions = HS.fromList preConditionsL
    -- NOTE(review): measureDb' ignores its first argument.
    allToFalsify = preConditions `HS.difference` postConditions
    measureDb = measureDb' allToFalsify postConditions

    initQueue = PQ.singleton (measureDb 0 initDbState)
    initDbState = DatabaseState (DatabaseStateSourceOriginal <$ HS.toMap preConditions)
                                preConditions
                                mempty

    -- Conses the predicate onto the list if it has the requested type.
    -- Requests for 'SomeDatabasePredicate' itself receive every predicate.
    findPredicate :: forall predicate. Typeable predicate
                  => SomeDatabasePredicate
                  -> [ predicate ] -> [ predicate ]
    findPredicate
      | Just (Refl :: predicate :~: SomeDatabasePredicate) <- eqT =
          (:)
      | otherwise =
          \(SomeDatabasePredicate pred') ps ->
            maybe ps (:ps) (cast pred')

    findPredicates :: forall predicate f. (Typeable predicate, Foldable f)
                   => f SomeDatabasePredicate -> [ predicate ]
    findPredicates = foldr findPredicate []

    heuristicSolver' !q !visited !bestRejected =
      case PQ.minView q of
        Nothing -> SearchFailed (measuredDbState <$> PQ.toList bestRejected)
        Just (mdbState@(MeasuredDatabaseState _ _ dbState), q')
          | dbStateKey dbState `HS.member` visited -> heuristicSolver' q' visited bestRejected
          | solvedState postConditions (measuredDbState mdbState) ->
              ProvideSolution (toList (dbStateCmdSequence dbState))
          | otherwise ->
              -- The state being expanded is marked visited before its
              -- successors are filtered. This drops actions that lead straight
              -- back to it (no-ops). Successors depend only on the predicate
              -- set and 'visited' only grows, so a dead end stays a dead end;
              -- marking it visited on rejection stops it being re-expanded and
              -- duplicated in 'bestRejected' if reached again.
              let visited' = HS.insert (dbStateKey dbState) visited

                  steps = getPotentialActions
                            provider
                            (findPredicates (dbStateKey dbState))
                            (findPredicates postConditionsL)

                  steps' = filter (not . (`HS.member` visited') . dbStateKey . measuredDbState . snd) $
                           withStrategy (parList rseq) $
                           map (\step -> let dbState' = applyStep step mdbState
                                         in dbState' `seq` (step, dbState')) steps

                  -- Each applied action adds 1 to the path length.
                  applyStep step (MeasuredDatabaseState score _ dbState') =
                    let dbState'' = dbStateAfterAction dbState' step
                    in measureDb (score + 1) dbState''
              in case steps' of
                   [] -> heuristicSolver' q' visited' (reject mdbState bestRejected)
                   _  -> ChooseActions dbState fst steps' $ \chosenSteps ->
                           let q'' = foldr (\(_, dbState') -> PQ.insert dbState') q' chosenSteps
                           in withStrategy (rparWith rseq) q'' `seq` heuristicSolver' q'' visited' bestRejected

    -- Keeps only the rejectedCount lowest-measure dead ends.
    reject :: MeasuredDatabaseState cmd -> PQ.MinQueue (MeasuredDatabaseState cmd)
           -> PQ.MinQueue (MeasuredDatabaseState cmd)
    reject mdbState q =
      let q' = PQ.insert mdbState q
      in PQ.fromAscList (PQ.take rejectedCount q')

    -- Removes the action's pre-conditions, adds its post-conditions (tagged
    -- as derived unless already present) and appends its commands.
    dbStateAfterAction (DatabaseState curState _ cmds) action =
      let curState' = ((curState `HM.difference` HS.toMap (actionPreConditions action))
                      `HM.union` (DatabaseStateSourceDerived <$ HS.toMap (actionPostConditions action)))
      in DatabaseState curState' (HS.fromMap (() <$ curState'))
                       (cmds <> actionCommands action)