{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE CPP #-}
module Shapes.Linear.Template where
import Test.QuickCheck.Arbitrary
import Control.Monad
import Language.Haskell.TH
-- | Template Haskell names describing the scalar element type stored in the
-- generated vectors, together with the operations on it that generated code
-- may refer to.
data ValueInfo = ValueInfo
  { _valueN :: Name
    -- ^ Type of each vector field (spliced as the constructor's field type and
    -- as the element type in the generated signatures). Presumably an unboxed
    -- type, given MagicHash and the separate '_valueWrap' boxing constructor
    -- — confirm at call sites.
  , _valueWrap :: Name
    -- ^ Data constructor applied to a raw field to obtain a '_valueBoxed'
    -- value; also matched as a pattern to unwrap list elements in @fromList@.
  , _valueBoxed :: Name
    -- ^ Element type of the lists taken by @fromList@ and returned by
    -- @toList@.
  , _valueAdd :: Name
    -- ^ Infix operator used to sum the products in the generated dot product.
  , _valueSub :: Name
    -- ^ Presumably subtraction; not used by the generators visible here.
  , _valueMul :: Name
    -- ^ Infix operator used to multiply paired components in the generated
    -- dot product.
  , _valueDiv :: Name
    -- ^ Presumably division; not used by the generators visible here.
  , _valueNeg :: Name
    -- ^ Presumably negation; not used by the generators visible here.
  , _valueEq :: Name
    -- ^ Presumably equality test; not used by the generators visible here.
  , _valueNeq :: Name
    -- ^ Presumably inequality test; not used by the generators visible here.
  , _valueLeq :: Name
    -- ^ Presumably @<=@; not used by the generators visible here.
  , _valueGeq :: Name
    -- ^ Presumably @>=@; not used by the generators visible here.
  , _valueGt :: Name
    -- ^ Presumably @>@; not used by the generators visible here.
  , _valueLt :: Name
    -- ^ Presumably @<@; not used by the generators visible here.
  }
-- | An INLINE pragma for the named function: function-like rule matching,
-- active in every simplifier phase.
makeInlineD :: Name -> DecQ
makeInlineD funName = pragInlD funName inlining FunLike phases
  where
    inlining = Inline
    phases = AllPhases
-- | Name used for both the vector type and its constructor at a given
-- dimension, e.g. @V3@ for 3.
makeVectorN :: Int -> Name
makeVectorN = mkName . ('V' :) . show
-- | All declarations for a @dim@-dimensional vector over '_valueN':
--
-- * @data V<dim> = V<dim> e1 ... e<dim>@ with non-strict, non-unpacked fields;
-- * the generated lift, lift2, dot, fromList and toList functions;
-- * 'Show' and 'Arbitrary' instances.
--
-- The data declaration comes first, followed by the output of each generator
-- in the order listed in @generators@.
makeVectorType :: ValueInfo -> Int -> DecsQ
makeVectorType vi@ValueInfo{..} dim = do
#if MIN_VERSION_template_haskell(2,11,0)
  fieldBang <- bang noSourceUnpackedness noSourceStrictness
#else
  fieldBang <- notStrict
#endif
  let vectorN = makeVectorN dim
      fields = replicate dim (fieldBang, ConT _valueN)
      constructor = NormalC vectorN fields
#if MIN_VERSION_template_haskell(2,11,0)
      dataDec = DataD [] vectorN [] Nothing [constructor] []
#else
      dataDec = DataD [] vectorN [] [constructor] []
#endif
  impls <- fmap concat . forM generators $ \generate -> generate vectorN vi dim
  return (dataDec : impls)
  where
    -- Every generator takes (vector name, element info, dimension).
    generators =
      [ defineLift
      , defineLift2
      , defineDot
      , defineFromList
      , defineToList
      , deriveShow
      , deriveArbitrary
      ]
-- | A 'Show' instance for the vector type @vectorN@ with @dim@ fields.
--
-- Each field is boxed with '_valueWrap' before being shown, giving output in
-- constructor-application form, e.g. @V2 1.0 2.0@. The instance defines
-- 'showsPrec' rather than only 'show', so that:
--
-- * the whole value is parenthesised when shown at application precedence
--   (@Just (V2 1.0 2.0)@ instead of @Just V2 1.0 2.0@);
-- * each field is shown at precedence 11, so negative components come out as
--   @(-1.0)@ — the same output a derived instance would produce.
--
-- Top-level 'show' of non-negative components is unchanged.
deriveShow :: Name -> ValueInfo -> Int -> DecsQ
deriveShow vectorN ValueInfo{..} dim = do
  (pat, vars) <- conPE vectorN "a" dim
  precN <- newName "d"
  let constructorShown = nameBase vectorN
      -- One " <field>" chunk per component, at argument precedence.
      showField v = [| showChar ' ' . showsPrec 11 $(appE (conE _valueWrap) v) |]
      fieldsS = foldr (\v rest -> [| $(showField v) . $rest |]) [| id |] vars
      shownS = [| showString constructorShown . $fieldsS |]
      -- A nullary constructor never needs parentheses, so it ignores the
      -- precedence argument (avoids an unused binding in generated code).
      (precP, bodyE)
        | null vars = (wildP, shownS)
        | otherwise = (varP precN, [| showParen ($(varE precN) > 10) $shownS |])
      showsPrecClause = clause [precP, pat] (normalB bodyE) []
  return <$> instanceD (cxt []) (appT (conT ''Show) (conT vectorN)) [funD 'showsPrec [showsPrecClause]]
-- | An integer-literal expression carrying the given 'Int'.
dimE :: Int -> ExpQ
dimE n = litE (integerL (toInteger n))
-- | An 'Arbitrary' instance for @vectorN@. It draws @dim@ arbitrary boxed
-- elements with 'replicateM' and passes the list to the generated @fromList@
-- function for this vector type.
deriveArbitrary :: Name -> ValueInfo -> Int -> DecsQ
deriveArbitrary vectorN ValueInfo{..} dim = do
  let genElems = [| replicateM $(dimE dim) arbitrary |]
      genBody = [| $(fromListE vectorN) <$> $genElems |]
      arbDec = funD 'arbitrary [clause [] (normalB genBody) []]
  fmap pure (instanceD (cxt []) (appT (conT ''Arbitrary) (conT vectorN)) [arbDec])
-- | Generates @lift<V> :: (e -> e) -> V -> V@, where @e@ is '_valueN'. It
-- applies the function to every component and rebuilds the vector. The
-- definition is emitted with an INLINE pragma.
defineLift :: Name -> ValueInfo -> Int -> DecsQ
defineLift vectorN ValueInfo{..} dim = do
  (fnP, fnE) <- newPE "f"
  (vecP, comps) <- conPE vectorN "a" dim
  let elemT = conT _valueN
      vecT = conT vectorN
      sig = arrowsT [arrowsT [elemT, elemT], vecT, vecT]
      body = appsE (conE vectorN : [appE fnE c | c <- comps])
  inlSigDef (mkName ("lift" ++ nameBase vectorN)) sig [clause [fnP, vecP] (normalB body) []]
-- | Generates @lift2<V> :: (e -> e -> e) -> V -> V -> V@, where @e@ is
-- '_valueN'. It combines the two vectors component-wise with the function.
-- The definition is emitted with an INLINE pragma.
defineLift2 :: Name -> ValueInfo -> Int -> DecsQ
defineLift2 vectorN ValueInfo{..} dim = do
  (fnP, fnE) <- newPE "f"
  (lhsP, lhsComps) <- conPE vectorN "a" dim
  (rhsP, rhsComps) <- conPE vectorN "b" dim
  let elemT = conT _valueN
      vecT = conT vectorN
      sig = arrowsT [arrowsT [elemT, elemT, elemT], vecT, vecT, vecT]
      combined = zipWith (\x y -> appsE [fnE, x, y]) lhsComps rhsComps
      body = appsE (conE vectorN : combined)
  inlSigDef (mkName ("lift2" ++ nameBase vectorN)) sig [clause [fnP, lhsP, rhsP] (normalB body) []]
-- | Sum of pairwise products of two component lists, built left-nested:
-- @((r1 * c1) + (r2 * c2)) + ...@, using '_valueMul' and '_valueAdd'.
-- Surplus elements of the longer list are ignored. If either list is empty,
-- 'foldl1' fails.
dotE :: ValueInfo -> [ExpQ] -> [ExpQ] -> ExpQ
dotE ValueInfo{..} row col = foldl1 plus (zipWith times row col)
  where
    plus acc p = infixApp acc (varE _valueAdd) p
    times r c = infixApp r (varE _valueMul) c
-- | Generates @dot<V> :: V -> V -> e@, where @e@ is '_valueN'. It is the sum
-- of component-wise products (see 'dotE') and is emitted with an INLINE
-- pragma.
defineDot :: Name -> ValueInfo -> Int -> DecsQ
defineDot vectorN vi@ValueInfo{..} dim = do
  (lhsP, lhsComps) <- conPE vectorN "a" dim
  (rhsP, rhsComps) <- conPE vectorN "b" dim
  let vecT = conT vectorN
      sig = arrowsT [vecT, vecT, conT _valueN]
      dotClause = simpleClause [lhsP, rhsP] (dotE vi lhsComps rhsComps)
  inlSigDef (mkName ("dot" ++ nameBase vectorN)) sig [dotClause]
-- | Generates conversions between @V<left>@, @V<right>@ and @V<left+right>@:
--
-- > join<left>v<right>  :: V<left> -> V<right> -> V<left+right>
-- > split<left>v<right> :: V<left+right> -> (V<left>, V<right>)
--
-- Joining concatenates the components, left vector first. Splitting gives
-- the first @left@ components to the first result and the rest to the
-- second. Both functions get INLINE pragmas.
--
-- The three vector types are only referred to by name here; they must be
-- declared elsewhere.
defineJoinSplit :: ValueInfo -> (Int, Int) -> DecsQ
defineJoinSplit ValueInfo{..} (left, right) = do
  let leftN = makeVectorN left
      rightN = makeVectorN right
      wholeN = makeVectorN (left + right)
      leftT = conT leftN
      rightT = conT rightN
      wholeT = conT wholeN
      suffix = show left ++ "v" ++ show right
  (leftP, leftComps) <- conPE leftN "a" left
  (rightP, rightComps) <- conPE rightN "b" right
  (wholeP, wholeComps) <- conPE wholeN "c" (left + right)
  let joinBody = appsE (conE wholeN : leftComps ++ rightComps)
      (firstPart, secondPart) = splitAt left wholeComps
      splitBody = tupE [ appsE (conE leftN : firstPart)
                       , appsE (conE rightN : secondPart)
                       ]
  joinDecs <- inlSigDef (mkName ("join" ++ suffix))
                        (arrowsT [leftT, rightT, wholeT])
                        [simpleClause [leftP, rightP] joinBody]
  splitDecs <- inlSigDef (mkName ("split" ++ suffix))
                         (arrowsT [wholeT, tupT [leftT, rightT]])
                         [simpleClause [wholeP] splitBody]
  pure (joinDecs ++ splitDecs)
-- | Name of the generated list-to-vector function for a vector type,
-- e.g. @fromListV3@ for @V3@.
fromListN :: Name -> Name
fromListN vecN = mkName ("fromList" ++ nameBase vecN)
-- | Variable expression that refers to the generated @fromList@ function of a
-- vector type.
fromListE :: Name -> ExpQ
fromListE vecN = varE (fromListN vecN)
-- | Generates @fromList<V> :: [boxed] -> V@, where @boxed@ is '_valueBoxed'.
--
-- The first clause matches a list of exactly @dim@ elements, each unwrapped
-- with the '_valueWrap' constructor pattern. Any other list falls through to
-- a clause that calls 'error'. The definition is emitted with an INLINE
-- pragma.
defineFromList :: Name -> ValueInfo -> Int -> DecsQ
defineFromList vectorN ValueInfo{..} dim = do
  (elemPats, elemExps) <- genPEWith "x" dim (\x -> conP _valueWrap [varP x]) varE
  let exactClause = clause [listP elemPats] (normalB (appsE (conE vectorN : elemExps))) []
      fallbackClause = clause [wildP] (normalB [| error "wrong number of elements" |]) []
      sig = arrowsT [appT listT (conT _valueBoxed), conT vectorN]
  inlSigDef (fromListN vectorN) sig [exactClause, fallbackClause]
-- | Generates @toList<V> :: V -> [boxed]@, where @boxed@ is '_valueBoxed'.
-- It returns every component, in order, wrapped with '_valueWrap'. The
-- definition is emitted with an INLINE pragma.
defineToList :: Name -> ValueInfo -> Int -> DecsQ
defineToList vectorN ValueInfo{..} dim = do
  (vecP, comps) <- conPE vectorN "a" dim
  let boxed = [conE _valueWrap `appE` c | c <- comps]
      sig = arrowsT [conT vectorN, listT `appT` conT _valueBoxed]
  inlSigDef (mkName ("toList" ++ nameBase vectorN)) sig [simpleClause [vecP] (listE boxed)]
-- | 'infixApp' with the operator taken first: @infixApp' op l r@ builds the
-- infix application of @op@ to @l@ and @r@. Putting the operator first makes
-- partial application convenient, e.g. as a fold step.
infixApp' :: ExpQ -> ExpQ -> ExpQ -> ExpQ
infixApp' op lhs rhs = infixApp lhs op rhs
-- | Type signature, definition and INLINE pragma for one function, in that
-- order.
inlSigDef :: Name -> TypeQ -> [ClauseQ] -> DecsQ
inlSigDef funN funT funCs =
  (\decs pragma -> decs ++ [pragma])
    <$> funSigDef funN funT funCs
    <*> makeInlineD funN
-- | A type signature followed by the matching function definition.
funSigDef :: Name -> TypeQ -> [ClauseQ] -> DecsQ
funSigDef funN funT funCs = sequence [sigD funN funT, funD funN funCs]
-- | Tuple type of the given components. The n-ary tuple type constructor is
-- applied to each component in turn, e.g. @(A, B)@ for two.
tupT :: [TypeQ] -> TypeQ
tupT ts = go (tupleT (length ts)) ts
  where
    go acc [] = acc
    go acc (t : rest) = go (appT acc t) rest
-- | Right-nested function type: @arrowsT [a, b, c]@ is @a -> b -> c@.
-- A single type is returned as-is; an empty list is an 'error'.
arrowsT :: [TypeQ] -> TypeQ
arrowsT [] = error "can't have no type"
arrowsT ts = foldr1 (\t rest -> arrowT `appT` t `appT` rest) ts
-- | A fresh variable based on the given name. It is returned both as a
-- pattern that binds it and as an expression that refers to it.
newPE :: String -> Q (PatQ, ExpQ)
newPE x = fmap (\n -> (varP n, varE n)) (newName x)
-- | A constructor pattern @C x1 ... xn@ over @dim@ fresh variables, plus
-- expressions referring to those variables in the same order.
conPE :: Name -> String -> Int -> Q (PatQ, [ExpQ])
conPE conN x dim = (\(pats, vars) -> (conP conN pats, vars)) <$> genPE x dim
-- | @n@ fresh names based on @x@. Each name becomes a pattern via @mkP@ and
-- an expression via @mkE@; the two result lists are index-aligned.
genPEWith :: String -> Int -> (Name -> PatQ) -> (Name -> ExpQ) -> Q ([PatQ], [ExpQ])
genPEWith x n mkP mkE = withNames <$> replicateM n (newName x)
  where
    withNames names = (map mkP names, map mkE names)
-- | 'genPEWith' specialised to plain variable patterns and variable
-- expressions.
genPE :: String -> Int -> Q ([PatQ], [ExpQ])
genPE prefix count = genPEWith prefix count varP varE
-- | A clause with the given patterns, an unguarded body and no @where@
-- bindings.
simpleClause :: [PatQ] -> ExpQ -> ClauseQ
simpleClause pats body = clause pats guardless noDecs
  where
    guardless = normalB body
    noDecs = []