module Language.SMTLib2.Internals.Monad where

import Language.SMTLib2.Internals.Backend as B
import Language.SMTLib2.Internals.Type
import Language.SMTLib2.Internals.Type.List (List(..))
import qualified Language.SMTLib2.Internals.Type.List as List

import Control.Monad.State.Strict
import Data.Typeable
import Data.GADT.Compare
import Data.GADT.Show
import Data.Dependent.Map (DMap)
import qualified Data.Dependent.Map as Map
import Control.Exception (onException)
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif

newtype Backend b => SMT b a = SMT { runSMT :: StateT (SMTState b) (SMTMonad b) a }

data SMTState b = SMTState { backend :: !b
                           , datatypes :: !(DatatypeInfo (B.Constr b) (B.Field b)) }

instance Backend b => Functor (SMT b) where
  fmap f (SMT act) = SMT (fmap f act)

instance Backend b => Applicative (SMT b) where
  pure x = SMT (pure x)
  (<*>) (SMT fun) (SMT arg) = SMT (fun <*> arg)

instance Backend b => Monad (SMT b) where
  (>>=) (SMT act) app = SMT (act >>= (\res -> case app res of
                                                  SMT p -> p))

instance Backend b => MonadState (SMTState b) (SMT b) where
  get = SMT get
  put x = SMT (put x)
  state act = SMT (state act)

instance (Backend b,MonadIO (SMTMonad b)) => MonadIO (SMT b) where
  liftIO act = SMT (liftIO act)

withBackend :: Backend b => SMTMonad b b
            -> SMT b a
            -> SMTMonad b a
withBackend constr act = do
  b <- constr
  (res,nb) <- runStateT (runSMT act) (SMTState b emptyDatatypeInfo)
  exit (backend nb)
  return res

withBackendExitCleanly :: (Backend b,SMTMonad b ~ IO) => IO b -> SMT b a -> IO a
withBackendExitCleanly constr (SMT act) = do
  b <- constr
  (do
      (res,nb) <- runStateT act (SMTState b emptyDatatypeInfo)
      exit (backend nb)
      return res) `onException` (exit b)

liftSMT :: Backend b => SMTMonad b a -> SMT b a
liftSMT act = SMT (lift act)

embedSMT :: Backend b => (b -> SMTMonad b (a,b)) -> SMT b a
embedSMT act = SMT $ do
  b <- get
  (res,nb) <- lift $ act (backend b)
  put (b { backend = nb })
  return res

embedSMT' :: Backend b => (b -> SMTMonad b b) -> SMT b ()
embedSMT' act = SMT $ do
  b <- get
  nb <- lift $ act (backend b)
  put (b { backend = nb })

data DTProxy dt where
  DTProxy :: IsDatatype dt => DTProxy dt

instance GEq DTProxy where
  geq DTProxy DTProxy = eqT

instance GCompare DTProxy where
  gcompare x@(DTProxy::DTProxy a) y@(DTProxy::DTProxy b) = case (eqT :: Maybe (a :~: b)) of
    Just Refl -> GEQ
    Nothing -> case compare (typeRep x) (typeRep y) of
      LT -> GLT
      GT -> GGT

instance GShow DTProxy where
  gshowsPrec p pr@DTProxy = showsPrec p (typeRep pr)

instance Show (DTProxy dt) where
  showsPrec = gshowsPrec

type DatatypeInfo con field = DMap DTProxy (RegisteredDT con field)

newtype RegisteredDT con field dt
  = RegisteredDT (B.BackendDatatype con field '(DatatypeSig dt,dt))
  deriving (Typeable)

emptyDatatypeInfo :: DatatypeInfo con field
emptyDatatypeInfo = Map.empty

reproxyDT :: IsDatatype dt => Proxy dt -> DTProxy dt
reproxyDT _ = DTProxy

registerDatatype :: (Backend b,IsDatatype dt) => Proxy dt -> SMT b ()
registerDatatype pr = do
  st <- get
  if Map.member (reproxyDT pr) (datatypes st)
    then return ()
    else do
      (dts,nb) <- liftSMT $ B.declareDatatypes (getTypeCollection pr) (backend st)
      put $ st { backend = nb
               , datatypes = insertTypes dts (datatypes st) }
  where
    insertTypes :: B.BackendTypeCollection con field sigs -> DatatypeInfo con field -> DatatypeInfo con field
    insertTypes NoDts mp = mp
    insertTypes (ConsDts (dt::B.BackendDatatype con field '(DatatypeSig dt,dt)) dts) mp
      = let nmp =  Map.insert (DTProxy::DTProxy dt) (RegisteredDT dt) mp
         in insertTypes dts nmp

lookupDatatype :: DTProxy dt -> DatatypeInfo con field
               -> B.BackendDatatype con field '(DatatypeSig dt,dt)
lookupDatatype pr dts = case Map.lookup pr dts of
  Just (RegisteredDT dt) -> dt
  Nothing -> error $ "smtlib2: Datatype "++show pr++" is not registered."

lookupConstructor :: String -> B.BackendDatatype con field '(DatatypeSig dt,dt)
                  -> (forall arg. B.BackendConstr con field '(arg,dt) -> a)
                  -> a
lookupConstructor name dt f = lookup (bconstructors dt) f
  where
    lookup :: Constrs (B.BackendConstr con field) sigs dt
           -> (forall arg. B.BackendConstr con field '(arg,dt) -> a)
           -> a
    lookup NoCon _ = error $ "smtlib2: "++name++" is not a constructor."
    lookup (ConsCon con cons) f = if bconName con==name
                                  then f con
                                  else lookup cons f

constructDatatype :: GEq con => con '(arg,ret)
                  -> List ConcreteValue arg
                  -> B.BackendDatatype con field '(cons,ret)
                  -> ret
constructDatatype con args dt = get con args (bconstructors dt)
  where
    get :: GEq con => con '(arg,ret) -> List ConcreteValue arg
        -> Constrs (BackendConstr con field) sigs ret -> ret
    get con args (ConsCon x xs)
      = case geq con (bconRepr x) of
      Just Refl -> bconstruct x args
      Nothing -> get con args xs

lookupField :: String -> B.BackendConstr con field '(arg,dt)
            -> (forall tp. B.BackendField field dt tp -> a)
            -> a
lookupField name con f = lookup (bconFields con) f
  where
    lookup :: List (B.BackendField field dt) arg
           -> (forall tp. B.BackendField field dt tp -> a)
           -> a
    lookup Nil _ = error $ "smtlib2: "++name++" is not a field."
    lookup (x ::: xs) f = if bfieldName x==name
                          then f x
                          else lookup xs f

lookupDatatypeCon :: (IsDatatype dt,Typeable con,Typeable field)
                  => DTProxy dt -> String -> DatatypeInfo con field
                  -> (forall arg. B.BackendConstr con field '(arg,dt) -> a)
                  -> a
lookupDatatypeCon pr name info f
  = lookupConstructor name (lookupDatatype pr info) f

lookupDatatypeField :: (IsDatatype dt,Typeable con,Typeable field)
                  => DTProxy dt -> String -> String -> DatatypeInfo con field
                  -> (forall tp. B.BackendField field dt tp -> a)
                  -> a
lookupDatatypeField pr con field info f
  = lookupDatatypeCon pr con info $
    \con' -> lookupField field con' f

mkConcr :: B.Backend b => Value (B.Constr b) t -> SMT b (ConcreteValue t)
mkConcr (BoolValue v) = return (BoolValueC v)
mkConcr (IntValue v) = return (IntValueC v)
mkConcr (RealValue v) = return (RealValueC v)
mkConcr (BitVecValue v bw) = return (BitVecValueC v bw)
mkConcr (ConstrValue con args) = do
  args' <- List.mapM mkConcr args
  st <- get
  return $ ConstrValueC $
    constructDatatype con args' $
    lookupDatatype DTProxy (datatypes st)

mkAbstr :: (B.Backend b) => ConcreteValue t -> SMT b (Value (B.Constr b) t)
mkAbstr (BoolValueC v) = return (BoolValue v)
mkAbstr (IntValueC v) = return (IntValue v)
mkAbstr (RealValueC v) = return (RealValue v)
mkAbstr (BitVecValueC v bw) = return (BitVecValue v bw)
mkAbstr (ConstrValueC v) = do
  st <- get
  getConstructor v (bconstructors $ lookupDatatype DTProxy (datatypes st)) $
    \con args -> do
      rargs <- List.mapM mkAbstr args
      return $ ConstrValue (bconRepr con) rargs

defineVar' :: (B.Backend b) => B.Expr b t -> SMT b (B.Var b t)
defineVar' e = embedSMT $ B.defineVar Nothing e

defineVarNamed' :: (B.Backend b) => String -> B.Expr b t -> SMT b (B.Var b t)
defineVarNamed' name e = embedSMT $ B.defineVar (Just name) e

declareVar' :: B.Backend b => Repr t -> SMT b (B.Var b t)
declareVar' tp = embedSMT $ B.declareVar tp Nothing

declareVarNamed' :: B.Backend b => Repr t -> String -> SMT b (B.Var b t)
declareVarNamed' tp name = embedSMT $ B.declareVar tp (Just name)