{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Tensor.Static.TH (
genTensorInstance
) where
import Data.List (foldl')
import Data.Tensor.Static (tensor, Tensor, unsafeFromList, toList, IsTensor)
import Language.Haskell.TH
import qualified Data.List.NonEmpty as N
-- | Generate an 'IsTensor' instance for a tensor of the given shape whose
-- elements have the type named by @elemTypeName@.
--
-- The generated instance declares:
--
-- * a @Tensor@ data instance with a single constructor holding
--   @product dimensions@ strict, unpacked fields of the element type;
-- * 'tensor', bound to that constructor;
-- * 'unsafeFromList', which takes the first @product dimensions@ elements of
--   its argument (any extra elements are ignored) and calls 'error' if the
--   list is too short;
-- * 'toList', which returns the fields in constructor order;
-- * @INLINE@ pragmas for the three methods above.
--
-- The splice is aborted (via 'fail' in 'Q') if any dimension is negative,
-- because such a shape cannot be expressed as a type-level natural list.
genTensorInstance
  :: N.NonEmpty Int -- ^ Tensor shape (list of dimensions)
  -> Name           -- ^ Name of the element type, e.g. @''Float@
  -> Q [Dec]
genTensorInstance (N.toList -> dimensions) elemTypeName
  | any (< 0) dimensions =
      fail ( "Data.Tensor.Static.TH.genTensorInstance: negative dimension in shape "
          ++ show dimensions )
  | otherwise = do
      -- Fresh constructor name encoding the shape and element type,
      -- e.g. Tensor'2'3'Float for shape [2,3] of Float.
      conName <- newName ("Tensor'" ++ concatMap (\x -> show x ++ "'") dimensions ++ nameBase elemTypeName)
      let fieldCount = product dimensions
      fieldNames <- mapM (newName . ('x' :) . show) [0 .. fieldCount - 1]
      let elemType = ConT elemTypeName
          dims     = natListT dimensions
          fields   = replicate fieldCount (Bang SourceUnpack SourceStrict, elemType)
          dataInstDec = DataInstD
            []
            ''Tensor
            [dims, elemType]
            (Just StarT)
            [NormalC conName fields]
            []
          -- x0 : x1 : ... : _  matches any list with at least fieldCount elements.
          fromListPat     = foldr (\name pat -> InfixP (VarP name) '(:) pat) WildP fieldNames
          constructTensor = foldl' (\acc name -> acc `AppE` VarE name) (ConE conName) fieldNames
          tensorPat       = ConP conName (map VarP fieldNames)
          toListBody      = ListE (map VarE fieldNames)
          failBody        = VarE 'error `AppE` LitE (StringL ("Not enough elements to build a Tensor of shape "
                                                             ++ show dimensions))
          fromListClause = Clause [fromListPat] (NormalB constructTensor) []
          failClause     = Clause [WildP] (NormalB failBody) []
          -- With zero fields, fromListPat is a bare wildcard that already matches
          -- every list, so a fallback clause would be redundant (and warn).
          unsafeFromListClauses
            | null fieldNames = [fromListClause]
            | otherwise       = [fromListClause, failClause]
          tensorDec         = ValD (VarP 'tensor) (NormalB (ConE conName)) []
          unsafeFromListDec = FunD 'unsafeFromList unsafeFromListClauses
          toListDec         = FunD 'toList [Clause [tensorPat] (NormalB toListBody) []]
          tensorCInstPragmas =
            [ PragmaD (InlineP 'tensor Inline FunLike AllPhases)
            , PragmaD (InlineP 'unsafeFromList Inline FunLike AllPhases)
            , PragmaD (InlineP 'toList Inline FunLike AllPhases)
            ]
          tensorCInstDec = InstanceD
            Nothing
            []
            (ConT ''IsTensor `AppT` dims `AppT` elemType)
            ([dataInstDec, tensorDec, unsafeFromListDec, toListDec] ++ tensorCInstPragmas)
      pure [tensorCInstDec]
-- | Promote a list of dimensions to a type-level list of naturals built from
-- 'PromotedConsT' and 'PromotedNilT', e.g. @[2, 3]@ becomes @'[2, 3]@.
natListT :: [Int] -> Type
natListT []       = PromotedNilT
natListT (d : ds) = AppT (AppT PromotedConsT natLit) (natListT ds)
  where
    -- The dimension as a type-level numeric literal.
    natLit = LitT (NumTyLit (toInteger d))