{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
module Type.InstanceMap.TH (
mkMap,
mkMapWithOpts,
defaultOptions,
ClassName, InputTypeName, OutputWrapperName,
Some(..),
Options(..)) where
import Control.Monad (foldM, join, when)
import Control.Monad.IO.Class (liftIO)
import Data.List (foldl', intersect, nub, partition)
import Data.Traversable (for)
import Data.Typeable
import qualified Debug.Trace as DBG
import GHC.Exts

import Control.Monad.State.Strict (StateT(..), put, get, evalStateT, modify, withStateT)
import Control.Monad.Trans (lift)
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Language.Haskell.TH
-- | Existential wrapper family, indexed by a class constraint.
--
-- 'mkMapWithOpts' (and 'mkMap') generate one @data instance Some C@ for the
-- class @C@ they are run on. Its single constructor, named @Some<C>@, packs a
-- value of any type @a@ satisfying both @Typeable a@ and @C a@.
data family Some (c :: * -> Constraint)
-- | Configuration for 'mkMapWithOpts'.
data Options = Options
  { maxDepth :: Int
    -- ^ Upper bound on the number of classes being expanded at once while
    -- following instance contexts; past it the search for a class is cut
    -- off and yields no types.
  , verbose :: Bool
    -- ^ When 'True', progress messages are printed and warnings are
    -- reported during compilation; when 'False', both are suppressed.
  , witnessGenerator :: ExpQ
    -- ^ Expression applied to @Proxy :: Proxy t@ for every discovered
    -- concrete type @t@, producing that type's lookup key.
  , witnessTypeName :: Name
    -- ^ Name of the key type produced by 'witnessGenerator'; used as the
    -- first argument type of the generated getter.
  }
-- | Default configuration: recursion bound of 2, quiet (no progress output
-- and no warnings), and lookup keys built with 'typeRep', so keys have type
-- 'TypeRep'.
defaultOptions :: Options
defaultOptions = Options
  { maxDepth = 2
  , verbose = False
  , witnessGenerator = [|typeRep|]
  , witnessTypeName = ''TypeRep
  }
-- | A type intended to be concrete (free of type variables) that has an
-- instance of the class being searched.
type ConcreteType = Type
-- | The last type argument of an instance head, possibly containing type
-- variables that still need to be instantiated.
type InstanceHead = Type
-- | Name of the class whose instances are collected.
type ClassName = Name
-- | Name of the input type consumed by the user-supplied decoding function.
type InputTypeName = Name
-- | Name of the type constructor that wraps the decoder's result.
type OutputWrapperName = Name
-- | Bookkeeping threaded through the recursive instance search.
data InstanceTraversalState = InstanceTraversalState
  { classesInProgress :: [Name]
    -- ^ Classes currently being expanded, most recently entered first; its
    -- length is compared against 'maxDepth'.
  , classesDone :: M.Map Name [ConcreteType]
    -- ^ Memoised results for classes whose instances have been resolved.
  }
-- | Starting state for a search: no class in progress, nothing memoised.
initialTraversalState :: InstanceTraversalState
initialTraversalState = InstanceTraversalState
  { classesInProgress = []
  , classesDone = M.empty
  }
-- | The 'Q' monad extended with the instance-search bookkeeping.
type StateQ = StateT InstanceTraversalState Q
-- | All instance declarations of a class, as reported by 'reify'.
--
-- Fails the splice with an explanatory message when the name does not refer
-- to a type class. Previously a refutable pattern in @do@ produced only a
-- generic pattern-match failure.
getInstances :: Name -> Q [InstanceDec]
getInstances typ = do
  reified <- reify typ
  case reified of
    ClassI _ instances -> return instances
    _ -> fail $ "getInstances: " ++ show typ ++ " is not a type class"
-- | Debugging aid: a string-literal expression holding the 'show'n instance
-- declarations of the given class.
showInstances :: Name -> Q Exp
showInstances = fmap (LitE . stringL . show) . getInstances
-- | 'mkMapWithOpts' run with 'defaultOptions'.
mkMap :: ClassName -> InputTypeName -> OutputWrapperName -> ExpQ -> Q [Dec]
mkMap className inType outWrap decoder =
  mkMapWithOpts defaultOptions className inType outWrap decoder
-- | For the class @className@ (written @C@ below), generate four declarations:
--
--   * @data instance Some C@, with a single constructor @Some<C>@ wrapping
--     any value whose type has both a 'Typeable' and a @C@ instance;
--   * @mapOf<C>@, a 'M.Map' from each discovered concrete type's key
--     (built with 'witnessGenerator') to a decoder for that type;
--   * a signature for @getSome<C>@:
--     @<witnessTypeName> -> inType -> outWrap (Some C)@;
--   * @getSome<C>@ itself. It looks the key up in @mapOf<C>@ and runs the
--     matching decoder on the input, calling @fail@ if no entry exists.
--
-- @fExp@ is annotated at @inType -> outWrap t@ for every discovered concrete
-- type @t@, and its result is mapped with @fmap Some<C>@. @outWrap@ is
-- therefore presumably a 'Functor' that also supports @fail@.
-- NOTE(review): confirm this against callers.
--
-- All generated names are built with 'mkName', so they are visible to (and
-- may clash with) code at the splice site.
mkMapWithOpts :: Options -> ClassName -> InputTypeName -> OutputWrapperName -> ExpQ -> Q [Dec]
mkMapWithOpts opts className inType outWrap fExp = do
  -- Every concrete type found to have an instance of the class.
  typs <- knownConcreteInstances opts className
  -- One decoder expression per concrete type, in the same order as typs.
  lst <- mapM mkExp typs
  -- The matching key expressions, also in typs order.
  let witnesses = traverse mkWitness typs
  typMap <- [|M.fromList $ zip ($(ListE <$> witnesses)) $(return $ ListE lst)|]
  -- trep and v are not bound inside this quote; they refer to the clause
  -- patterns of decoderD below, which are built with mkName.
  decoderE <- [| case M.lookup trep $(varE mapName) of
                   Nothing -> fail $ "No instance found for " ++ show trep
                   Just f -> f v |]
  a <- newName "a"
  -- data instance Some C = forall a. (Typeable a, C a) => Some<C> a
  let dataInstD = DataInstD []
                            ''Some
                            [ConT className]
                            Nothing
                            [ForallC [PlainTV a]
                                     [AppT (ConT ''Typeable) (VarT a),AppT (ConT className) (VarT a)]
                                     (NormalC someName [(Bang NoSourceUnpackedness NoSourceStrictness,
                                                         (VarT a))])]
                            []
      -- mapOf<C> = M.fromList (zip keys decoders)
      mapDefinitionD = ValD (VarP mapName) (NormalB typMap) []
      -- getSome<C> :: <witnessTypeName> -> inType -> outWrap (Some C)
      decoderSigD = SigD getterName (AppT (AppT ArrowT (ConT (witnessTypeName opts)))
                                          (AppT (AppT ArrowT (ConT inType))
                                                (AppT (ConT outWrap)
                                                      (AppT (ConT ''Some) (ConT className)))))
      -- getSome<C> trep v = <decoderE>
      decoderD = FunD getterName [Clause [VarP (mkName "trep"), VarP (mkName "v")] (NormalB decoderE) []]
  return $ [dataInstD, mapDefinitionD, decoderSigD, decoderD]
  where
    -- Decoder for concrete type t: fExp annotated at inType -> outWrap t,
    -- with its result wrapped in Some<C>.
    mkExp :: Type -> Q Exp
    mkExp t = [| fmap $(conE someName) . ($(fExp) :: $(conT inType) -> $(appT (conT outWrap) (return t))) |]
    someName = mkName ("Some" ++ nameBase className)
    mapName = mkName ("mapOf" ++ nameBase className)
    getterName = mkName ("getSome" ++ nameBase className)
    -- Key for concrete type t: witnessGenerator applied to (Proxy :: Proxy t).
    mkWitness :: ConcreteType -> ExpQ
    mkWitness t = [|$(witnessGenerator opts) (Proxy :: Proxy $(return t))|]
-- | Concrete types that have an instance of the given class. Instance
-- contexts are resolved recursively, starting from a fresh search state.
knownConcreteInstances :: Options -> Name -> Q [ConcreteType]
knownConcreteInstances opts =
  flip evalStateT initialTraversalState . knownConcreteInstances' opts
-- | Worker for 'knownConcreteInstances': collect the concrete types that
-- have an instance of @className@, memoising the result per class.
--
-- Each instance declaration is expanded with 'deepReplaceVars', which calls
-- back into this function for the classes in the instance context. Once
-- more than 'maxDepth' classes are in progress, the class is cut off and
-- contributes no types.
--
-- The expected shape is @instance ctx => C ... t@. Declarations of any other
-- shape (e.g. instances of a nullary class, whose head is a bare 'ConT') are
-- skipped with a warning instead of crashing the splice.
knownConcreteInstances' :: Options -> Name -> StateQ [ConcreteType]
knownConcreteInstances' opts className = do
  InstanceTraversalState {..} <- get
  case M.lookup className classesDone of
    Just types -> return types
    Nothing
      | length classesInProgress > maxDepth opts -> do
          warn opts ("Cutting off recursion at " ++ show className)
          return []
      | otherwise -> do
          info opts $ "Looking for instances of " ++ show className
          trInsts <- lift $ getInstances className
          concreteInsts <- for trInsts $ \inst -> case inst of
            InstanceD _ ctx (AppT _ instHead) _ -> do
              s <- get
              -- Expand in a copy of the state that marks this class as in
              -- progress. The copy, including anything memoised while
              -- expanding, is discarded afterwards.
              lift $ evalStateT (deepReplaceVars opts instHead ctx)
                                (s { classesInProgress = className : classesInProgress })
            _ -> do
              warn opts $ "Skipping unsupported instance of " ++ show className ++ ": " ++ show inst
              return []
          let retVal = join concreteInsts
          info opts $ "Returning instances for " ++ show className ++ ": " ++ show retVal
          modify $ \s -> s { classesDone = M.insert className retVal classesDone }
          return retVal
-- | Expand one instance head into the concrete types it covers.
--
-- * A monomorphic head is returned unchanged.
-- * Otherwise every constraint in the instance context must have the shape
--   @C a@. For each one, the concrete instances of @C@ are looked up
--   recursively. A variable's candidates are intersected across all
--   constraints that mention it and then substituted into the head with
--   'allReplacements'.
-- * Heads whose context contains any other kind of constraint are skipped
--   with a warning naming the offending constraints.
-- * Instantiations that still contain type variables (e.g. a head variable
--   with no constraint) are also dropped with a warning, since they are not
--   concrete types.
deepReplaceVars :: Options -> InstanceHead -> Cxt -> StateQ [ConcreteType]
deepReplaceVars opts t constraints
  | monomorphic t = return [t]
  | not (null unsupported) = do
      warn opts $ "Only simple univariate constraints (like 'C a') are supported (skipping "
        ++ show t ++ " due to " ++ show unsupported ++ ")"
      return []
  | otherwise = do
      constraintCandidates <- traverse getVarsAndCandidates constraints
      -- Candidate lists grouped by variable. Every group is non-empty by
      -- construction, so foldl1 is safe; a variable may only take types
      -- that satisfy all of its constraints.
      let mc :: M.Map Name [[Type]]
          mc = M.fromListWith (flip mappend) [(n, [typs]) | (n, typs) <- constraintCandidates]
          possibleVals = foldl1 intersect <$> mc
          (concrete, stillPolymorphic) = partition monomorphic (allReplacements possibleVals t)
      when (not (null stillPolymorphic)) $
        warn opts $ "Skipping instantiations of " ++ show t
          ++ " that still contain type variables: " ++ show stillPolymorphic
      return concrete
  where
    -- Context constraints that are not of the form @C a@.
    unsupported = filter (not . univariate) constraints
    univariate (AppT (ConT _) (VarT _)) = True
    univariate _ = False
    getVarsAndCandidates :: Type -> StateQ (Name, [ConcreteType])
    getVarsAndCandidates (AppT (ConT cls) (VarT v)) = (v,) <$> knownConcreteInstances' opts cls
    -- Not reached after the 'unsupported' guard. Kept total by pairing the
    -- constraint with a fresh, unused name and no candidates.
    getVarsAndCandidates _ = lift $ (,[]) <$> newName ""
-- | Report a compile-time warning, but only when 'verbose' is enabled;
-- otherwise do nothing.
warn opts s = when (verbose opts) (lift (reportWarning s))
-- | Print a progress message to stdout at compile time, but only when
-- 'verbose' is enabled; otherwise do nothing.
info opts message = when (verbose opts) $ lift (runIO (putStrLn message))
-- | Instantiate an instance head in every way allowed by the candidate map.
--
-- Variables are processed one at a time, in key order. Each variable is
-- substituted consistently: all of its occurrences receive the same
-- candidate. So @(a, a)@ with candidates @[Int, Bool]@ yields only
-- @(Int, Int)@ and @(Bool, Bool)@. The previous per-occurrence substitution
-- also produced @(Int, Bool)@ and @(Bool, Int)@, which the instance does not
-- cover.
--
-- A variable that does not occur in the head leaves it unchanged rather than
-- multiplying or emptying the result. Variables missing from the map are left
-- in place. Substitution descends through 'AppT', 'SigT' (not into the
-- kind), 'InfixT', 'UInfixT' and 'ParensT'.
allReplacements :: M.Map Name [ConcreteType] -> InstanceHead -> [ConcreteType]
allReplacements var2candidates instHead =
  foldM substituteVar instHead (M.toList var2candidates)
  where
    -- One result per candidate when the variable occurs; otherwise pass through.
    substituteVar :: Type -> (Name, [Type]) -> [Type]
    substituteVar ty (v, candidates)
      | occursIn v ty = [replaceVar v c ty | c <- candidates]
      | otherwise = [ty]

    -- Whether the variable occurs anywhere the substitution reaches.
    occursIn :: Name -> Type -> Bool
    occursIn v (VarT v') = v == v'
    occursIn v (AppT t1 t2) = occursIn v t1 || occursIn v t2
    occursIn v (SigT t1 _) = occursIn v t1
    occursIn v (InfixT t1 _ t2) = occursIn v t1 || occursIn v t2
    occursIn v (UInfixT t1 _ t2) = occursIn v t1 || occursIn v t2
    occursIn v (ParensT t1) = occursIn v t1
    occursIn _ _ = False

    -- Replace every occurrence of the variable with the given type.
    replaceVar :: Name -> Type -> Type -> Type
    replaceVar v c = go
      where
        go (VarT v') | v == v' = c
        go (AppT t1 t2) = AppT (go t1) (go t2)
        go (SigT t1 k) = SigT (go t1) k
        go (InfixT t1 n t2) = InfixT (go t1) n (go t2)
        go (UInfixT t1 n t2) = UInfixT (go t1) n (go t2)
        go (ParensT t1) = ParensT (go t1)
        go other = other
-- | 'True' when the type contains no type variables. The traversal goes
-- through applications, infix forms and parentheses; kind annotations
-- ('SigT' kinds) are not inspected.
monomorphic :: Type -> Bool
monomorphic ty = case ty of
  VarT _ -> False
  AppT fun arg -> all monomorphic [fun, arg]
  SigT inner _ -> monomorphic inner
  InfixT lhs _ rhs -> all monomorphic [lhs, rhs]
  UInfixT lhs _ rhs -> all monomorphic [lhs, rhs]
  ParensT inner -> monomorphic inner
  _ -> True
-- | Instance declarations present in the instance lists of every given class.
--
-- Returns @[]@ for an empty list of class names instead of crashing, which
-- 'foldl1' did previously.
--
-- NOTE(review): an instance declaration mentions its own class in its head,
-- so for two or more distinct classes this is most likely always empty. The
-- function is also unused in the visible part of this module; confirm
-- whether it is still needed.
intersectInstances :: [Name] -> Q [InstanceDec]
intersectInstances classNames = do
  instanceLists <- traverse getInstances classNames
  return $ case instanceLists of
    [] -> []
    firstList : rest -> foldl' intersect firstList rest