{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Store.TH.Internal
(
deriveManyStoreFromStorable
, deriveTupleStoreInstance
, deriveGenericInstance
, deriveManyStorePrimVector
, deriveManyStoreUnboxVector
, deriveStore
, makeStore
, getAllInstanceTypes1
, isMonoType
) where
import Control.Applicative
import Data.Complex ()
import Data.Generics.Aliases (extT, mkQ, extQ)
import Data.Generics.Schemes (listify, everywhere, something)
import Data.List (find)
import qualified Data.Map as M
import Data.Maybe (fromMaybe, isJust)
import Data.Primitive.ByteArray
import Data.Primitive.Types
import Data.Store.Core
import Data.Store.Impl
import qualified Data.Text as T
import Data.Traversable (forM)
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector.Unboxed as UV
import Data.Word
import Foreign.Storable (Storable)
import GHC.Types (Int(..))
import Language.Haskell.TH
import Language.Haskell.TH.ReifyMany.Internal (TypeclassInstance(..), getInstances, unAppsT)
import Language.Haskell.TH.Syntax (lift)
import Prelude
import Safe (headMay)
import TH.Derive (Deriver(..))
import TH.ReifySimple
import TH.Utilities (expectTyCon1, dequalify, plainInstanceD, appsT)
instance Deriver (Store a) where
runDeriver _ preds ty = do
argTy <- expectTyCon1 ''Store ty
dt <- reifyDataTypeSubstituted argTy
(:[]) <$> deriveStore preds argTy (dtCons dt)
makeStore :: Name -> Q [Dec]
makeStore name = do
dt <- reifyDataType name
let preds = map (storePred . VarT) (dtTvs dt)
argTy = appsT (ConT name) (map VarT (dtTvs dt))
(:[]) <$> deriveStore preds argTy (dtCons dt)
deriveStore :: Cxt -> Type -> [DataCon] -> Q Dec
deriveStore preds headTy cons0 =
makeStoreInstance preds headTy
<$> sizeExpr
<*> peekExpr
<*> pokeExpr
where
cons :: [(Name, [(Name, Type)])]
cons =
[ ( dcName dc
, [ (mkName ("c" ++ show ixc ++ "f" ++ show ixf), ty)
| ixf <- ints
| (_, ty) <- dcFields dc
]
)
| ixc <- ints
| dc <- cons0
]
(tagType, _, tagSize) =
fromMaybe (error "Too many constructors") $
find (\(_, maxN, _) -> maxN >= length cons) tagTypes
tagTypes :: [(Name, Int, Int)]
tagTypes =
[ ('(), 1, 0)
, (''Word8, fromIntegral (maxBound :: Word8), 1)
, (''Word16, fromIntegral (maxBound :: Word16), 2)
, (''Word32, fromIntegral (maxBound :: Word32), 4)
, (''Word64, fromIntegral (maxBound :: Word64), 8)
]
fName ix = mkName ("f" ++ show ix)
ints = [0..] :: [Int]
fNames = map fName ints
sizeNames = zipWith (\_ -> mkName . ("sz" ++) . show) cons ints
tagName = mkName "tag"
valName = mkName "val"
sizeExpr
| length cons <= 62 =
caseE (tupE (concatMap (map sizeAtType . snd) cons))
(case cons of
[] -> [matchConstSize]
[c] | null (snd c) -> [matchConstSize]
_ -> [matchConstSize, matchVarSize])
| otherwise = varSizeExpr
where
sizeAtType :: (Name, Type) -> ExpQ
sizeAtType (_, ty) = [| size :: Size $(return ty) |]
matchConstSize :: MatchQ
matchConstSize = do
let sz0 = VarE (mkName "sz0")
sizeDecls =
if null sizeNames
then [valD (varP (mkName "sz0")) (normalB [| 0 |]) []]
else zipWith constSizeDec sizeNames cons
sameSizeExpr <-
case sizeNames of
(_ : tailSizeNames) ->
foldl (\l r -> [| $(l) && $(r) |]) [| True |] $
map (\szn -> [| $(return sz0) == $(varE szn) |]) tailSizeNames
[] -> [| True |]
result <- [| ConstSize (tagSize + $(return sz0)) |]
match (tupP (map (\(n, _) -> conP 'ConstSize [varP n])
(concatMap snd cons)))
(guardedB [return (NormalG sameSizeExpr, result)])
sizeDecls
constSizeDec :: Name -> (Name, [(Name, Type)]) -> DecQ
constSizeDec szn (_, []) =
valD (varP szn) (normalB [| 0 |]) []
constSizeDec szn (_, fields) =
valD (varP szn) body []
where
body = normalB $
foldl1 (\l r -> [| $(l) + $(r) |]) $
map (\(sizeName, _) -> varE sizeName) fields
matchVarSize :: MatchQ
matchVarSize = do
match (tupP (map (\(n, _) -> varP n) (concatMap snd cons)))
(normalB varSizeExpr)
[]
varSizeExpr :: ExpQ
varSizeExpr =
[| VarSize $ \x -> tagSize + $(caseE [| x |] (map matchVar cons)) |]
matchVar :: (Name, [(Name, Type)]) -> MatchQ
matchVar (cname, []) =
match (conP cname []) (normalB [| 0 |]) []
matchVar (cname, fields) =
match (conP cname (zipWith (\_ fn -> varP fn) fields fNames))
body
[]
where
body = normalB $
foldl1 (\l r -> [| $(l) + $(r) |])
(zipWith (\(sizeName, _) fn -> [| getSizeWith $(varE sizeName) $(varE fn) |])
fields
fNames)
peekExpr = case cons of
[] -> [| error ("Attempting to peek type with no constructors (" ++ $(lift (show headTy)) ++ ")") |]
[con] -> peekCon con
_ -> doE
[ bindS (varP tagName) [| peek |]
, noBindS (caseE (sigE (varE tagName) (conT tagType))
(map peekMatch (zip [0..] cons) ++ [peekErr]))
]
peekMatch (ix, con) = match (litP (IntegerL ix)) (normalB (peekCon con)) []
peekErr = match wildP (normalB
[| peekException $ T.pack $ "Found invalid tag while peeking (" ++ $(lift (show headTy)) ++ ")" |]) []
peekCon (cname, fields) =
case fields of
[] -> [| pure $(conE cname) |]
_ -> doE $
map (\(fn, _) -> bindS (varP fn) [| peek |]) fields ++
[noBindS $ appE (varE 'return) $ appsE $ conE cname : map (\(fn, _) -> varE fn) fields]
pokeExpr = lamE [varP valName] $ caseE (varE valName) $ zipWith pokeCon [0..] cons
pokeCon :: Int -> (Name, [(Name, Type)]) -> MatchQ
pokeCon ix (cname, fields) =
match (conP cname (map (\(fn, _) -> varP fn) fields)) body []
where
body = normalB $
case cons of
(_:_:_) -> doE (pokeTag ix : map pokeField fields)
_ -> doE (map pokeField fields)
pokeTag ix = noBindS [| poke (ix :: $(conT tagType)) |]
pokeField (fn, _) = noBindS [| poke $(varE fn) |]
deriveTupleStoreInstance :: Int -> Dec
deriveTupleStoreInstance n =
deriveGenericInstance (map storePred tvs)
(foldl1 AppT (TupleT n : tvs))
where
tvs = take n (map (VarT . mkName . (:[])) ['a'..'z'])
deriveGenericInstance :: Cxt -> Type -> Dec
deriveGenericInstance cs ty = plainInstanceD cs (AppT (ConT ''Store) ty) []
deriveManyStoreFromStorable :: (Type -> Bool) -> Q [Dec]
deriveManyStoreFromStorable p = do
storables <- postprocess . instancesMap <$> getInstances ''Storable
stores <- postprocess . instancesMap <$> getInstances ''Store
return $ M.elems $ flip M.mapMaybe (storables `M.difference` stores) $
\(TypeclassInstance cs ty _) ->
let argTy = head (tail (unAppsT ty))
tyNameLit = LitE (StringL (pprint ty)) in
if p argTy && not (superclassHasStorable cs)
then Just $ makeStoreInstance cs argTy
(AppE (VarE 'sizeStorableTy) tyNameLit)
(AppE (VarE 'peekStorableTy) tyNameLit)
(VarE 'pokeStorable)
else Nothing
superclassHasStorable :: Cxt -> Bool
superclassHasStorable = isJust . something (mkQ Nothing justStorable `extQ` ignoreStrings)
where
justStorable :: Type -> Maybe ()
justStorable (ConT n) | n == ''Storable = Just ()
justStorable _ = Nothing
ignoreStrings :: String -> Maybe ()
ignoreStrings _ = Nothing
deriveManyStorePrimVector :: Q [Dec]
deriveManyStorePrimVector = do
prims <- postprocess . instancesMap <$> getInstances ''PV.Prim
stores <- postprocess . instancesMap <$> getInstances ''Store
let primInsts =
M.mapKeys (map (AppT (ConT ''PV.Vector))) prims
`M.difference`
stores
forM (M.toList primInsts) $ \primInst -> case primInst of
([_], TypeclassInstance cs ty _) -> do
let argTy = head (tail (unAppsT ty))
sizeExpr <- [|
VarSize $ \x ->
I# $(primSizeOfExpr (ConT ''Int)) +
I# $(primSizeOfExpr argTy) * PV.length x
|]
peekExpr <- [| do
len <- peek
let sz = I# $(primSizeOfExpr argTy)
array <- peekToByteArray $(lift ("Primitive Vector (" ++ pprint argTy ++ ")"))
(len * sz)
return (PV.Vector 0 len array)
|]
pokeExpr <- [| \(PV.Vector offset len (ByteArray array)) -> do
let sz = I# $(primSizeOfExpr argTy)
poke len
pokeFromByteArray array (offset * sz) (len * sz)
|]
return $ makeStoreInstance cs (AppT (ConT ''PV.Vector) argTy) sizeExpr peekExpr pokeExpr
_ -> fail "Invariant violated in derivemanyStorePrimVector"
primSizeOfExpr :: Type -> ExpQ
primSizeOfExpr ty = [| $(varE 'sizeOf#) (error "sizeOf# evaluated its argument" :: $(return ty)) |]
deriveManyStoreUnboxVector :: Q [Dec]
deriveManyStoreUnboxVector = do
unboxes <- getUnboxInfo
stores <- postprocess . instancesMap <$> getInstances ''Store
unboxInstances <- postprocess . instancesMap <$> getInstances ''UV.Unbox
let dataFamilyDecls =
M.fromList (map (\(preds, ty, cons) -> ([AppT (ConT ''UV.Vector) ty], (preds, cons))) unboxes)
`M.difference`
stores
#if MIN_VERSION_template_haskell(2,10,0)
substituteConstraint (AppT (ConT n) arg)
| n == ''UV.Unbox = AppT (ConT ''Store) (AppT (ConT ''UV.Vector) arg)
#else
substituteConstraint (ClassP n [arg])
| n == ''UV.Unbox = ClassP ''Store [AppT (ConT ''UV.Vector) arg]
#endif
substituteConstraint x = x
forM (M.toList dataFamilyDecls) $ \case
([ty], (_, cons)) -> do
let headTy = getTyHead (unAppsT ty !! 1)
(preds, ty') <- case M.lookup [headTy] unboxInstances of
Nothing -> do
reportWarning $ "No Unbox instance found for " ++ pprint headTy
return ([], ty)
Just (TypeclassInstance cs (AppT _ ty') _) ->
return (map substituteConstraint cs, AppT (ConT ''UV.Vector) ty')
Just _ -> fail "Impossible case"
deriveStore preds ty' cons
_ -> fail "impossible case in deriveManyStoreUnboxVector"
getUnboxInfo :: Q [(Cxt, Type, [DataCon])]
getUnboxInfo = do
FamilyI _ insts <- reify ''UV.Vector
return (map (everywhere (id `extT` dequalVarT) . go) insts)
where
#if MIN_VERSION_template_haskell(2,15,0)
go (NewtypeInstD preds _ lhs _ con _)
| [_, ty] <- unAppsT lhs
= (preds, ty, conToDataCons con)
go (DataInstD preds _ lhs _ cons _)
| [_, ty] <- unAppsT lhs
= (preds, ty, concatMap conToDataCons cons)
#elif MIN_VERSION_template_haskell(2,11,0)
go (NewtypeInstD preds _ [ty] _ con _) = (preds, ty, conToDataCons con)
go (DataInstD preds _ [ty] _ cons _) = (preds, ty, concatMap conToDataCons cons)
#else
go (NewtypeInstD preds _ [ty] con _) = (preds, ty, conToDataCons con)
go (DataInstD preds _ [ty] cons _) = (preds, ty, concatMap conToDataCons cons)
#endif
go x = error ("Unexpected result from reifying Unboxed Vector instances: " ++ pprint x)
dequalVarT :: Type -> Type
dequalVarT (VarT n) = VarT (dequalify n)
dequalVarT ty = ty
postprocess :: M.Map [Type] [a] -> M.Map [Type] a
postprocess =
M.mapMaybeWithKey $ \tys xs ->
case (tys, xs) of
([_ty], [x]) -> Just x
_ -> Nothing
makeStoreInstance :: Cxt -> Type -> Exp -> Exp -> Exp -> Dec
makeStoreInstance cs ty sizeExpr peekExpr pokeExpr =
plainInstanceD
cs
(AppT (ConT ''Store) ty)
[ ValD (VarP 'size) (NormalB sizeExpr) []
, ValD (VarP 'peek) (NormalB peekExpr) []
, ValD (VarP 'poke) (NormalB pokeExpr) []
]
getAllInstanceTypes :: Name -> Q [[Type]]
getAllInstanceTypes n =
map (\(TypeclassInstance _ ty _) -> drop 1 (unAppsT ty)) <$>
getInstances n
getAllInstanceTypes1 :: Name -> Q [Type]
getAllInstanceTypes1 n =
fmap (fmap (fromMaybe (error "getAllMonoInstances1 expected only one type argument") . headMay))
(getAllInstanceTypes n)
isMonoType :: Type -> Bool
isMonoType = null . listify isVarT
isVarT :: Type -> Bool
isVarT VarT{} = True
isVarT _ = False
instancesMap :: [TypeclassInstance] -> M.Map [Type] [TypeclassInstance]
instancesMap =
M.fromListWith (++) .
map (\ti -> (map getTyHead (instanceArgTypes ti), [ti]))
instanceArgTypes :: TypeclassInstance -> [Type]
instanceArgTypes (TypeclassInstance _ ty _) = drop 1 (unAppsT ty)
getTyHead :: Type -> Type
getTyHead (SigT x _) = getTyHead x
getTyHead (ForallT _ _ x) = getTyHead x
getTyHead (AppT l _) = getTyHead l
getTyHead x = x
storePred :: Type -> Pred
storePred ty =
#if MIN_VERSION_template_haskell(2,10,0)
AppT (ConT ''Store) ty
#else
ClassP ''Store [ty]
#endif