{-# LANGUAGE TemplateHaskell #-}

-- |
-- Module      : Streamly.Internal.Data.Unbox.TH
-- Copyright   : (c) 2023 Composewell Technologies
-- License     : BSD3-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC
--
module Streamly.Internal.Data.Unbox.TH
    ( deriveUnbox

    -- th-helpers
    , DataCon(..)
    , DataType(..)
    , reifyDataType
    ) where

--------------------------------------------------------------------------------
-- Imports
--------------------------------------------------------------------------------

import Data.Word (Word16, Word32, Word64, Word8)
import Data.Proxy (Proxy(..))
import Data.List (elemIndex)

import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Streamly.Internal.Data.Unbox

--------------------------------------------------------------------------------
-- th-utilities
--------------------------------------------------------------------------------

-- Note: We don't support template-haskell < 2.14 (GHC < 8.6)

-- The following are copied to remove the dependency on th-utilities.
-- The code has been copied from th-abstraction and th-utilities.

-- Some CPP conditionals in the following code are not required (given the
-- minimum supported version above) but are kept anyway. They can be removed
-- if they become a maintenance burden.

-- | Version-independent spelling of 'TyVarBndr'. From template-haskell
-- 2.17 'TyVarBndr' takes a flag type argument, which is passed through here;
-- on older versions the @flag@ parameter of this synonym is ignored.
#if MIN_VERSION_template_haskell(2,17,0)
type TyVarBndr_ flag = TyVarBndr flag
#else
type TyVarBndr_ flag = TyVarBndr
#endif

-- | Eliminator for 'TyVarBndr'. A binder without a kind annotation
-- (@'PlainTV' n ...@) is passed to the first continuation as its name @n@;
-- a kinded binder (@'KindedTV' n ... k@) is passed to the second
-- continuation as its name @n@ and kind @k@. The binder flag, when present,
-- is ignored.
elimTV :: (Name -> r) -> (Name -> Kind -> r) -> TyVarBndr_ flag -> r
#if MIN_VERSION_template_haskell(2,17,0)
elimTV onPlain onKinded bndr =
    case bndr of
        PlainTV nm _flag -> onPlain nm
        KindedTV nm _flag kd -> onKinded nm kd
#else
elimTV onPlain onKinded bndr =
    case bndr of
        PlainTV nm -> onPlain nm
        KindedTV nm kd -> onKinded nm kd
#endif

-- | The name bound by a 'TyVarBndr'. A kind annotation on the binder, if
-- any, is discarded.
tvName :: TyVarBndr_ flag -> Name
tvName bndr = elimTV id const bndr

-- | Synonym for 'tvName': the 'Name' bound by a 'TyVarBndr'.
tyVarBndrName :: TyVarBndr_ flag -> Name
tyVarBndrName bndr = tvName bndr

-- | A simplified view of a @data@ or @newtype@ declaration, as produced by
-- 'reifyDataType'. Deriving clauses, strictness annotations and kind
-- information are dropped, and @data@ is not distinguished from @newtype@.
data DataType = DataType
    { dtName :: Name        -- ^ Name of the type constructor
    , dtTvs :: [Name]       -- ^ Names of the type parameters, in order
    , dtCxt :: Cxt          -- ^ Context of the declaration
    , dtCons :: [DataCon]   -- ^ Constructors, flattened by 'conToDataCons'
    } deriving (Eq, Ord, Show)
    -- Data, Typeable and Generic instances are not derived.

-- | A simplified, uniform view of a data constructor. Normal, record, infix,
-- @forall@-quantified and GADT constructors are all flattened into this one
-- shape by 'conToDataCons'; strictness and kind information are dropped.
data DataCon = DataCon
    { dcName :: Name                    -- ^ Constructor name
    , dcTvs :: [Name]                   -- ^ Type variables bound by @forall@
    , dcCxt :: Cxt                      -- ^ Constraints of a @forall@ context
    , dcFields :: [(Maybe Name, Type)]  -- ^ Fields with record selector names
                                        --   ('Nothing' for positional fields)
    } deriving (Eq, Ord, Show)
    -- Data, Typeable and Generic instances are not derived.


-- | Flatten a Template Haskell 'Con' into 'DataCon's. The result is a list
-- because a single 'GadtC' or 'RecGadtC' can declare several constructors
-- at once. For a 'ForallC', its bound type variables and its context are
-- appended to those of every constructor it wraps.
conToDataCons :: Con -> [DataCon]
conToDataCons con =
    case con of
        NormalC nm bangTys ->
            [DataCon nm [] [] (positional bangTys)]
        RecC nm varBangTys ->
            [DataCon nm [] [] (named varBangTys)]
        InfixC (_, lhsTy) nm (_, rhsTy) ->
            [DataCon nm [] [] [(Nothing, lhsTy), (Nothing, rhsTy)]]
        ForallC bndrs ctx inner ->
            [ dc { dcTvs = dcTvs dc ++ map tyVarBndrName bndrs
                 , dcCxt = dcCxt dc ++ ctx
                 }
            | dc <- conToDataCons inner
            ]
#if MIN_VERSION_template_haskell(2,11,0)
        GadtC nms bangTys _resTy ->
            [DataCon nm [] [] (positional bangTys) | nm <- nms]
        RecGadtC nms varBangTys _resTy ->
            [DataCon nm [] [] (named varBangTys) | nm <- nms]
#endif

    where

    -- Positional fields have no selector name.
    positional :: [BangType] -> [(Maybe Name, Type)]
    positional = map (\(_bang, ty) -> (Nothing, ty))

    -- Record fields keep their selector name.
    named :: [VarBangType] -> [(Maybe Name, Type)]
    named = map (\(fld, _bang, ty) -> (Just fld, ty))

-- | Reify the named @data@ or @newtype@ declaration and return its
-- simplified 'DataType' view. Calls 'fail' in 'Q' when the reified 'Info'
-- is not a plain data or newtype declaration.
reifyDataType :: Name -> Q DataType
reifyDataType name = reify name >>= \info ->
    maybe
        (fail $ "Expected to reify a datatype. Instead got:\n" ++ pprint info)
        pure
        (infoToDataType info)

-- | Extract a 'DataType' from reified 'Info'. Only 'DataD' and 'NewtypeD'
-- type constructors are recognised; any other 'Info' yields 'Nothing'.
infoToDataType :: Info -> Maybe DataType
#if MIN_VERSION_template_haskell(2,11,0)
infoToDataType (TyConI (DataD ctx nm bndrs _kind cons _deriv)) =
#else
infoToDataType (TyConI (DataD ctx nm bndrs cons _deriv)) =
#endif
    Just (DataType nm (map tyVarBndrName bndrs) ctx (concatMap conToDataCons cons))
#if MIN_VERSION_template_haskell(2,11,0)
infoToDataType (TyConI (NewtypeD ctx nm bndrs _kind con _deriv)) =
#else
infoToDataType (TyConI (NewtypeD ctx nm bndrs con _deriv)) =
#endif
    Just (DataType nm (map tyVarBndrName bndrs) ctx (conToDataCons con))
infoToDataType _ = Nothing

--------------------------------------------------------------------------------
-- Helpers
--------------------------------------------------------------------------------

-- | A constructor field: its record selector name (if any) and its type.
type Field = (Maybe Name, Type)

-- Names of the variables that the generated 'peekAt' / 'pokeAt' code binds
-- and refers to. They are made with 'mkName', so the generated code uses
-- them by their plain names.
_arr, _tag, _initialOffset, _val :: Name
_arr = mkName "arr"                      -- the array argument
_tag = mkName "tag"                      -- the constructor tag read by peekAt
_initialOffset = mkName "initialOffset"  -- the offset argument
_val = mkName "val"                      -- the value argument of pokeAt

-- | @offset\<i\>@: the variable holding the offset of field @i@.
mkOffsetName :: Int -> Name
mkOffsetName = mkIndexedName "offset"

-- | @field\<i\>@: the variable bound to field @i@ of a constructor.
mkFieldName :: Int -> Name
mkFieldName = mkIndexedName "field"

-- | Append the decimal rendering of an index to a prefix and make a 'Name'.
mkIndexedName :: String -> Int -> Name
mkIndexedName prefix i = mkName (prefix ++ show i)

--------------------------------------------------------------------------------
-- Domain specific helpers
--------------------------------------------------------------------------------

-- | The expression @sizeOf (Proxy :: Proxy ty)@, i.e. the 'Unbox' size of a
-- value of type @ty@.
exprGetSize :: Type -> Q Exp
exprGetSize ty = [|sizeOf (Proxy :: Proxy $(pure ty))|]

-- | Candidate representations for a constructor tag, smallest first: the
-- largest constructor count the candidate is used for, its size in bytes,
-- and its type. The bounds are kept as 'Integer' because converting
-- @maxBound :: Word64@ (or @maxBound :: Word32@ on 32-bit platforms) to
-- 'Int' wraps around to -1, which would make that candidate unreachable.
tagCandidates :: [(Integer, Int, Name)]
tagCandidates =
    [ (toInteger (maxBound :: Word8), 1, ''Word8)
    , (toInteger (maxBound :: Word16), 2, ''Word16)
    , (toInteger (maxBound :: Word32), 4, ''Word32)
    , (toInteger (maxBound :: Word64), 8, ''Word64)
    ]

-- | Size in bytes and type of the tag for a datatype with the given number
-- of constructors: the first candidate whose bound is at least that number.
-- Errors out if no candidate is large enough.
tagRepr :: Int -> (Int, Name)
tagRepr numConstructors =
    case [ (size, ty)
         | (bound, size, ty) <- tagCandidates
         , toInteger numConstructors <= bound
         ] of
        (repr:_) -> repr
        [] -> error "Too many constructors"

-- | Number of bytes used to store the constructor tag. A type with a single
-- constructor needs no tag, so its tag size is 0.
getTagSize :: Int -> Int
getTagSize numConstructors
    | numConstructors == 1 = 0
    | otherwise = fst (tagRepr numConstructors)

-- | Word type used to store the constructor tag. It is an error to ask for
-- the tag type of a single-constructor type, which has no tag.
getTagType :: Int -> Name
getTagType numConstructors
    | numConstructors == 1 = error "No tag for 1 constructor"
    | otherwise = snd (tagRepr numConstructors)

-- | Declarations @offset0@ .. @offset\<n-1\>@ for a constructor with @n@
-- fields. @offset0 = tagSize + initialOffset@, and each later offset is the
-- previous one plus the 'sizeOf' of the previous field, so the fields are
-- laid out one after another. No declarations are produced for no fields.
mkOffsetDecls :: Int -> [Field] -> [Q Dec]
mkOffsetDecls tagSize fields =
    take (length fields) (firstDecl : zipWith nextDecl [1 ..] fields)

    where

    bindOffset :: Int -> Q Exp -> Q Dec
    bindOffset i rhs = valD (varP (mkOffsetName i)) (normalB rhs) []

    firstDecl :: Q Dec
    firstDecl =
        bindOffset 0
            [|$(litE (IntegerL (fromIntegral tagSize))) +
              $(varE _initialOffset)|]

    -- offset<i> starts right after field i-1, whose type is ty.
    nextDecl :: Int -> Field -> Q Dec
    nextDecl i (_, ty) =
        bindOffset i [|$(varE (mkOffsetName (i - 1))) + $(exprGetSize ty)|]

--------------------------------------------------------------------------------
-- Size
--------------------------------------------------------------------------------

-- | 'True' exactly when there is a single constructor and it has no fields.
isUnitType :: [DataCon] -> Bool
isUnitType cons =
    case cons of
        [con] -> null (dcFields con)
        _ -> False

-- | Expression for the 'sizeOf' method:
--
-- * no constructors: a runtime 'error';
-- * one constructor without fields: @1@;
-- * one constructor with fields: the sum of its field sizes;
-- * several constructors: the tag size plus the largest per-constructor sum
--   of field sizes.
mkSizeOfExpr :: Type -> [DataCon] -> Q Exp
mkSizeOfExpr headTy constructors
    | [] <- constructors =
        [|error
              ("Attempting to get size with no constructors (" ++
               $(lift (pprint headTy)) ++ ")")|]
    | [DataCon _ _ _ []] <- constructors = litE (IntegerL 1)
    | [con] <- constructors = fieldsSize con
    | otherwise =
        [|$(litE (IntegerL (fromIntegral tagSize))) + $(largestCon)|]

    where

    tagSize :: Int
    tagSize = getTagSize (length constructors)

    -- sum [sizeOf (Proxy :: Proxy t0), sizeOf (Proxy :: Proxy t1), ...]
    fieldsSize :: DataCon -> Q Exp
    fieldsSize con =
        appE (varE 'sum) (listE [exprGetSize ty | (_, ty) <- dcFields con])

    -- maximum over all constructors of their field-size sums
    largestCon :: Q Exp
    largestCon = appE (varE 'maximum) (listE (map fieldsSize constructors))

--------------------------------------------------------------------------------
-- Peek
--------------------------------------------------------------------------------

-- | Expression that reads a value built with the given constructor. A
-- constructor without fields is simply returned with 'pure'. Otherwise the
-- field offsets are bound by 'mkOffsetDecls' (starting @tagSize@ after
-- @initialOffset@), field @i@ is read with @peekAt offset\<i\> arr@, and the
-- fields are combined applicatively into the constructor.
mkPeekExprOne :: Int -> DataCon -> Q Exp
mkPeekExprOne tagSize (DataCon cname _ _ fields)
    | null fields = [|pure $(conE cname)|]
    | otherwise = letE (mkOffsetDecls tagSize fields) applied

    where

    numFields :: Int
    numFields = length fields

    peekField :: Int -> Q Exp
    peekField i = [|peekAt $(varE (mkOffsetName i)) $(varE _arr)|]

    -- Con <$> peek 0 <*> peek 1 <*> ... <*> peek (numFields - 1)
    applied :: Q Exp
    applied = go [|$(conE cname) <$> $(peekField 0)|] 1

    go :: Q Exp -> Int -> Q Exp
    go acc i
        | i >= numFields = acc
        | otherwise = go [|$(acc) <*> $(peekField i)|] (i + 1)

-- | Expression for the 'peekAt' method.
--
-- A single constructor is read directly, without a tag. With several
-- constructors the tag is read at @initialOffset@ first, and a @case@ on it
-- (annotated with the tag type) selects the constructor whose fields follow
-- the tag. An unknown tag is a runtime 'error'.
mkPeekExpr :: Type -> [DataCon] -> Q Exp
mkPeekExpr headTy cons =
    case cons of
        [] ->
            [|error
                  ("Attempting to peek type with no constructors (" ++
                   $(lift (pprint headTy)) ++ ")")|]
        [con] -> mkPeekExprOne 0 con
        _ -> doE [bindTag, noBindS dispatch]

    where

    numCons :: Int
    numCons = length cons

    tagType :: Name
    tagType = getTagType numCons

    tagSize :: Int
    tagSize = getTagSize numCons

    -- tag <- peekAt initialOffset arr
    bindTag :: Q Stmt
    bindTag =
        bindS (varP _tag) [|peekAt $(varE _initialOffset) $(varE _arr)|]

    -- case (tag :: TagType) of 0 -> ...; 1 -> ...; _ -> error ...
    dispatch :: Q Exp
    dispatch =
        caseE
            (sigE (varE _tag) (conT tagType))
            (zipWith conMatch [0 ..] cons ++ [invalidTag])

    conMatch :: Integer -> DataCon -> Q Match
    conMatch i con =
        match (litP (IntegerL i)) (normalB (mkPeekExprOne tagSize con)) []

    invalidTag :: Q Match
    invalidTag =
        match
            wildP
            (normalB
                 [|error
                       ("Found invalid tag while peeking (" ++
                        $(lift (pprint headTy)) ++ ")")|])
            []

--------------------------------------------------------------------------------
-- Poke
--------------------------------------------------------------------------------

-- | Expression that writes the constructor tag @tagVal@, as a value of type
-- @tagType@, at @initialOffset@ in @arr@.
mkPokeExprTag :: Name -> Int -> Q Exp
mkPokeExprTag tagType tagVal =
    [|pokeAt $(varE _initialOffset) $(varE _arr) $(typedTag)|]

    where

    -- (tagVal :: tagType)
    typedTag :: Q Exp
    typedTag = sigE (litE (IntegerL (toInteger tagVal))) (conT tagType)

-- | Expression that writes all fields of a constructor, in order, at the
-- offsets bound by 'mkOffsetDecls' (starting @tagSize@ after
-- @initialOffset@). Field @i@ must be bound to @field\<i\>@, as done by
-- 'mkPokeMatch'. With no fields this is just @pure ()@.
mkPokeExprFields :: Int -> [Field] -> Q Exp
mkPokeExprFields tagSize fields
    | null fields = [|pure ()|]
    | otherwise =
        letE
            (mkOffsetDecls tagSize fields)
            (doE [noBindS (pokeField i) | i <- [0 .. length fields - 1]])

    where

    -- pokeAt offset<i> arr field<i>
    pokeField :: Int -> Q Exp
    pokeField i =
        [|pokeAt
              $(varE (mkOffsetName i))
              $(varE _arr)
              $(varE (mkFieldName i))|]

-- | The @case@ alternative @Con field0 ... field\<n-1\> -> body@: it binds
-- each of the @numFields@ fields of constructor @cname@ to the variable
-- names that 'mkPokeExprFields' refers to.
mkPokeMatch :: Name -> Int -> Q Exp -> Q Match
mkPokeMatch cname numFields body = match pat (normalB body) []

    where

    pat :: Q Pat
    pat = conP cname [varP (mkFieldName i) | i <- [0 .. numFields - 1]]

-- | Expression for the 'pokeAt' method. It pattern matches on @val@. A
-- single constructor writes only its fields; with several constructors
-- each alternative first writes the constructor's 0-based position as the
-- tag and then its fields after the tag. A single field-less constructor
-- writes nothing.
mkPokeExpr :: Type -> [DataCon] -> Q Exp
mkPokeExpr headTy cons =
    case cons of
        [] ->
            [|error
                  ("Attempting to poke type with no constructors (" ++
                   $(lift (pprint headTy)) ++ ")")|]
        [DataCon _ _ _ []] -> [|pure ()|]
        [DataCon cname _ _ fields] ->
            caseE (varE _val)
                [pokeAlt cname fields (mkPokeExprFields 0 fields)]
        _ -> caseE (varE _val) (zipWith pokeTagged [0 ..] cons)

    where

    numCons :: Int
    numCons = length cons

    tagType :: Name
    tagType = getTagType numCons

    tagSize :: Int
    tagSize = getTagSize numCons

    pokeAlt :: Name -> [Field] -> Q Exp -> Q Match
    pokeAlt cname fields = mkPokeMatch cname (length fields)

    -- Con field0 ... -> do { write tag; write fields }
    pokeTagged :: Int -> DataCon -> Q Match
    pokeTagged tagVal (DataCon cname _ _ fields) =
        pokeAlt cname fields $
            doE [ noBindS (mkPokeExprTag tagType tagVal)
                , noBindS (mkPokeExprFields tagSize fields)
                ]

--------------------------------------------------------------------------------
-- Main
--------------------------------------------------------------------------------

-- | Generate the 'Unbox' method declarations for a type and pass them to a
-- continuation that builds the final declarations.
--
-- * @headTy@ is the type the instance is for. It is only used in the error
--   messages of the generated methods.
-- * @cons@ lists the constructors to encode. When there is more than one,
--   the constructor at (0-based) position @i@ is stored with tag @i@.
-- * @mkDec@ receives the generated declarations: 'sizeOf', and 'peekAt'
--   and 'pokeAt' each preceded by an INLINE pragma. It is typically used to
--   wrap them in an 'InstanceD'.
--
-- Consider the datatype:
--
-- @
-- data CustomDataType a b
--     = CDTConstructor1
--     | CDTConstructor2 Bool
--     | CDTConstructor3 Bool b
--     deriving (Show, Eq)
-- @
--
-- Usage:
--
-- @
-- \$(do
--     let ty = AppT
--                  (AppT (ConT ''CustomDataType) (VarT (mkName "a")))
--                  (VarT (mkName "b"))
--         ctx = [AppT (ConT ''Unbox) (VarT (mkName "b"))]
--     deriveUnboxInternal
--         ty
--         [ DataCon 'CDTConstructor1 [] [] []
--         , DataCon 'CDTConstructor2 [] [] [(Nothing, ConT ''Bool)]
--         , DataCon
--               'CDTConstructor3
--               []
--               []
--               [(Nothing, ConT ''Bool), (Nothing, VarT (mkName "b"))]
--         ]
--         (\\methods ->
--              pure [InstanceD Nothing ctx (AppT (ConT ''Unbox) ty) methods]))
-- @
deriveUnboxInternal :: Type -> [DataCon] -> ([Dec] -> Q [Dec]) -> Q [Dec]
deriveUnboxInternal headTy cons mkDec = do
    sizeOfMethod <- mkSizeOfExpr headTy cons
    peekMethod <- mkPeekExpr headTy cons
    pokeMethod <- mkPokeExpr headTy cons
    mkDec
        -- INLINE on sizeOf actually worsens some benchmarks, and improves
        -- none
        [ -- PragmaD (InlineP 'sizeOf Inline FunLike AllPhases)
          FunD 'sizeOf [Clause [WildP] (NormalB sizeOfMethod) []]
        , inlinePragma 'peekAt
        , FunD
              'peekAt
              [ Clause
                    (methodArgs [_initialOffset, _arr])
                    (NormalB peekMethod)
                    []
              ]
        , inlinePragma 'pokeAt
        , FunD
              'pokeAt
              [ Clause
                    (methodArgs [_initialOffset, _arr, _val])
                    (NormalB pokeMethod)
                    []
              ]
        ]

    where

    inlinePragma :: Name -> Dec
    inlinePragma nm = PragmaD (InlineP nm Inline FunLike AllPhases)

    -- The generated bodies for a unit type (one constructor, no fields) do
    -- not mention the arguments, so wildcards are bound instead (presumably
    -- to avoid unused-binding warnings in the generated code).
    methodArgs :: [Name] -> [Pat]
    methodArgs names
        | isUnitType cons = map (const WildP) names
        | otherwise = map VarP names

-- | Given an 'Unbox' instance declaration splice without the methods (e.g.
-- @[d|instance Unbox a => Unbox (Maybe a)|]@), generate an instance
-- declaration including all the type class method implementations.
--
-- The type in the instance head must be a type constructor applied to zero
-- or more arguments. That type constructor's @data@ or @newtype@
-- declaration is reified, and in the field types of its constructors each
-- type parameter is replaced by the corresponding argument from the
-- instance head.
--
-- Usage:
--
-- @
-- \$(deriveUnbox [d|instance Unbox a => Unbox (Maybe a)|])
-- @
deriveUnbox :: Q [Dec] -> Q [Dec]
deriveUnbox mDecs = do
    dec <- mDecs
    case dec of
        [InstanceD mo preds headTyWC []] -> do
            let headTy = unwrap dec headTyWC
                (mainTyName, subs) = getMainTypeName dec headTy
            dt <- reifyDataType mainTyName
            let tyVars = dtTvs dt
                mapper = mapperWith dec (VarT <$> tyVars) subs
                cons = map (modifyConVariables mapper) (dtCons dt)
            deriveUnboxInternal headTy cons (mkInst mo preds headTyWC)
        _ -> errorMessage dec

    where

    -- Replace the element at index i of l1 with the element at index i of
    -- l2. Values not in l1 are returned unchanged. If l2 is too short to
    -- supply a replacement, report which parameter is missing instead of
    -- failing with an opaque index error.
    mapperWith :: [Dec] -> [Type] -> [Type] -> Type -> Type
    mapperWith dec l1 l2 a =
        case elemIndex a l1 of
            Nothing -> a
            Just i ->
                case drop i l2 of
                    (b:_) -> b
                    [] -> missingArgument dec a

    -- Apply f to the leaves of a type. Only type application, infix type
    -- application and parentheses are descended into; any other type is
    -- treated as a leaf.
    mapType :: (Type -> Type) -> Type -> Type
    mapType f (AppT t1 t2) = AppT (mapType f t1) (mapType f t2)
    mapType f (InfixT t1 n t2) = InfixT (mapType f t1) n (mapType f t2)
    mapType f (UInfixT t1 n t2) = UInfixT (mapType f t1) n (mapType f t2)
    mapType f (ParensT t) = ParensT (mapType f t)
    mapType f v = f v

    modifyConVariables :: (Type -> Type) -> DataCon -> DataCon
    modifyConVariables f con =
        con { dcFields =
                map (\(mName, ty) -> (mName, mapType f ty)) (dcFields con) }

    mkInst :: Maybe Overlap -> Cxt -> Type -> [Dec] -> Q [Dec]
    mkInst mo preds headTyWC methods =
        pure [InstanceD mo preds headTyWC methods]

    errorMessage :: [Dec] -> a
    errorMessage dec =
        error $ unlines
            [ "Error: deriveUnbox:"
            , ""
            , ">> " ++ pprint dec
            , ""
            , "The supplied declaration is not a valid instance declaration."
            , "Provide a valid Haskell instance declaration without a body."
            , ""
            , "Examples:"
            , "instance Unbox (Proxy a)"
            , "instance Unbox a => Unbox (Identity a)"
            , "instance Unbox (TableT Identity)"
            ]

    missingArgument :: [Dec] -> Type -> a
    missingArgument dec tv =
        error $ unlines
            [ "Error: deriveUnbox:"
            , ""
            , ">> " ++ pprint dec
            , ""
            , "The instance head does not supply an argument for the type"
                ++ " parameter " ++ pprint tv ++ " of the datatype."
            ]

    -- Strip the class from the instance head, e.g. Unbox (T a) ~> T a.
    unwrap :: [Dec] -> Type -> Type
    unwrap _ (AppT (ConT _) r) = r
    unwrap dec _ = errorMessage dec

    -- Split T x1 ... xn into (T, [x1, ..., xn]).
    getMainTypeName :: [Dec] -> Type -> (Name, [Type])
    getMainTypeName dec = go []

        where

        go :: [Type] -> Type -> (Name, [Type])
        go xs (ConT nm) = (nm, xs)
        go xs (AppT l r) = go (r : xs) l
        go _ _ = errorMessage dec