{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}

-- | Implementation of a 'Storable' deriver for data types. This works for
-- any non-recursive datatype which has 'Storable' fields.
--
-- Most users won't need to import this module directly. Instead, use
-- 'derive' / 'Deriving' to create 'Storable' instances.
module TH.Derive.Storable
    ( makeStorableInst
    ) where

import           Control.Applicative
import           Control.Monad
import           Data.List (find)
import           Data.Maybe (fromMaybe)
import           Data.Word
import           Foreign.Ptr
import           Foreign.Storable
import           Language.Haskell.TH
import           Language.Haskell.TH.Syntax
import           Prelude
import           TH.Derive.Internal
import           TH.ReifySimple
import           TH.Utilities

-- | Deriving a 'Storable' instance is delegated to 'makeStorableInst'; the
-- proxy argument is only used to select this instance.
instance Deriver (Storable a) where
    runDeriver _proxy preds ty = makeStorableInst preds ty

-- | Implementation used for 'runDeriver'.
--
-- The instance head @ty@ is passed through 'expectTyCon1' (with the
-- 'Storable' class name) to obtain the type being instantiated, that type is
-- reified with 'reifyDataTypeSubstituted', and its constructors are handed to
-- 'makeStorableImpl' together with the instance context @preds@.
makeStorableInst :: Cxt -> Type -> Q [Dec]
makeStorableInst preds ty =
    expectTyCon1 ''Storable ty
        >>= reifyDataTypeSubstituted
        >>= makeStorableImpl preds ty . dtCons

-- TODO: Add a recursion check. At the very least, document that deriving
-- for a recursive type may in some cases succeed but produce a bogus instance.

-- | Build the 'Storable' instance declaration for the instance head
-- @headTy@ (with context @preds@) from the constructors @cons@ of the type
-- being instantiated. The result is a single instance declaration defining
-- @alignment@, @sizeOf@, @peek@ and @poke@.
makeStorableImpl :: Cxt -> Type -> [DataCon] -> Q [Dec]
makeStorableImpl preds headTy cons = do
    -- Since this instance doesn't pay attention to alignment, we
    -- just say alignment doesn't matter.
    alignmentBody <- [| 1 |]
    sizeOfBody <- sizeExpr
    peekBody <- peekExpr
    pokeBody <- pokeExpr
    -- A single-clause method definition with the given argument patterns.
    let method name pats body = FunD (mkName name) [Clause pats (NormalB body) []]
    return
        [ plainInstanceD preds headTy
            [ method "alignment" [WildP] alignmentBody
              -- 'sizeExpr' already yields a function, so no pattern here.
            , method "sizeOf" [] sizeOfBody
            , method "peek" [VarP ptrName] peekBody
            , method "poke" [VarP ptrName, VarP valName] pokeBody
            ]
        ]
  where
    -- NOTE: Much of the code here resembles code in store for deriving
    -- Store instances. Changes here may be relevant there as well.
    --
    -- Choose the first entry of 'tagTypes' whose capacity is at least the
    -- number of constructors: 'tagType' is the type used for the constructor
    -- tag and 'tagSize' is the number of bytes it occupies.
    (tagType, _, tagSize) = fromMaybe tooMany (find fits tagTypes)
      where
        fits (_, capacity, _) = capacity >= length cons
        tooMany = error "Too many constructors"
    -- Candidate tag representations, smallest first. Each entry is the tag
    -- type's name, the largest constructor count it is used for, and its
    -- size in bytes. The first entry covers types with at most one
    -- constructor, which store no tag at all (size 0).
    tagTypes :: [(Name, Int, Int)]
    tagTypes =
        [ ('(), 1, 0)
        , (''Word8, capacity (maxBound :: Word8), 1)
        , (''Word16, capacity (maxBound :: Word16), 2)
        , (''Word32, capacity (maxBound :: Word32), 4)
        , (''Word64, capacity (maxBound :: Word64), 8)
        ]
      where
        -- Convert a word bound to 'Int', saturating at @maxBound :: Int@.
        -- A plain 'fromIntegral' wraps around (e.g. the 'Word64' bound
        -- becomes -1), which would make the entry impossible to select.
        capacity :: Integral w => w -> Int
        capacity = fromInteger . min (toInteger (maxBound :: Int)) . toInteger
    -- Names of the variables bound in the generated methods: the value
    -- passed to poke, the tag read by peek, and the pointer argument.
    valName, tagName, ptrName :: Name
    valName = mkName "val"
    tagName = mkName "tag"
    ptrName = mkName "ptr"
    -- Name of the variable bound to the field with the given index.
    fName :: Int -> Name
    fName = mkName . ('f' :) . show
    -- Expression referring to the pointer argument.
    ptrExpr :: Q Exp
    ptrExpr = varE ptrName
    -- Expression used for the definition of sizeOf: a function that ignores
    -- its argument and returns the tag size plus the largest constructor
    -- payload, where a payload is the sum of that constructor's field sizes.
    --
    -- The tag size must be included because peek and poke store the tag at
    -- the start of the pointer and place the first field at offset
    -- 'tagSize'; without it they would access bytes beyond 'sizeOf'. The
    -- leading 0 keeps 'maximum' total for a type with no constructors.
    sizeExpr :: Q Exp
    sizeExpr = [| const (tagSize + maximum (0 : $(listE payloadSizes))) |]
      where
        -- One expression per constructor: the sum of its field sizes.
        payloadSizes :: [Q Exp]
        payloadSizes =
            [ [| sum $(listE [sizeOfExpr ty | (_, ty) <- fields]) |]
            | DataCon _ _ _ fields <- cons
            ]
    -- Expression used for the definition of peek. A type with no
    -- constructors yields a call to 'error'; a single constructor is read
    -- directly; otherwise a tag of type 'tagType' is read from the start of
    -- the pointer and a case expression on it picks the constructor to read,
    -- with a final catch-all alternative for invalid tags.
    peekExpr :: Q Exp
    peekExpr = case cons of
        [] -> [| error ("Attempting to peek type with no constructors (" ++ $(lift (pprint headTy)) ++ ")") |]
        [single] -> peekCon single
        _ -> doE [readTag, dispatch]
      where
        readTag = bindS (varP tagName) [| peek (castPtr $(ptrExpr)) |]
        dispatch = noBindS (caseE scrutinee alternatives)
        scrutinee = sigE (varE tagName) (conT tagType)
        alternatives =
            [peekMatch (ix, con) | (ix, con) <- zip [0 ..] cons] ++ [peekErr]
    -- Case alternative matching the literal tag value @ix@ and reading the
    -- fields of the corresponding constructor.
    peekMatch :: (Integer, DataCon) -> Q Match
    peekMatch (ix, con) = match tagPat body []
      where
        tagPat = litP (integerL ix)
        body = normalB (peekCon con)
    -- Catch-all case alternative, reached when the tag read by peek does not
    -- correspond to any constructor.
    peekErr :: Q Match
    peekErr = match wildP (normalB invalidTag) []
      where
        invalidTag = [| error ("Found invalid tag while peeking (" ++ $(lift (pprint headTy)) ++ ")") |]
    -- Expression which reads all fields of one constructor and applies the
    -- constructor to them. The field offsets are bound by an enclosing let;
    -- the fields are then read in order and combined applicatively. A
    -- constructor without fields is simply returned with 'pure'.
    peekCon :: DataCon -> Q Exp
    peekCon (DataCon cname _ _ fields) = letE (offsetDecls fields) build
      where
        build
            | null fields = [| pure $(conE cname) |]
            | otherwise =
                foldl applyField
                      [| $(conE cname) <$> $(peekOffset 0) |]
                      [1 .. length fields - 1]
        applyField acc ix = [| $(acc) <*> $(peekOffset ix) |]
    -- Expression which reads the field with index @ix@ from the pointer,
    -- advanced by that field's offset variable.
    peekOffset :: Int -> Q Exp
    peekOffset ix = [| peek (castPtr $fieldPtr) |]
      where
        fieldPtr = [| plusPtr $(ptrExpr) $(varE (offset ix)) |]
    -- Expression used for the definition of poke: a case expression on the
    -- value with one alternative per constructor, numbered from 0.
    --
    -- A type with no constructors is handled separately, mirroring
    -- 'peekExpr': a case expression without alternatives would require the
    -- EmptyCase extension wherever the instance is spliced. The value is
    -- still forced first, as the empty case expression would have done.
    pokeExpr :: Q Exp
    pokeExpr = case cons of
        [] -> [| $(varE valName) `seq` error ("Attempting to poke type with no constructors (" ++ $(lift (pprint headTy)) ++ ")") |]
        _ -> caseE (varE valName) (zipWith (curry pokeMatch) [0 ..] cons)
    -- Case alternative of poke for the constructor with tag @ixcon@. It
    -- binds every field to a variable and then, in order: writes the tag
    -- (only when the type has at least two constructors), binds the field
    -- offsets (only when there are fields), and writes each field. When
    -- none of these apply the body is just @return ()@.
    pokeMatch :: (Int, DataCon) -> MatchQ
    pokeMatch (ixcon, DataCon cname _ _ fields) =
        match (conP cname (map (varP . fName) ixs)) (normalB body) []
      where
        ixs :: [Int]
        ixs = [0 .. length fields - 1]
        body = case concat [tagPokes, offsetLet, fieldPokes] of
            [] -> [|return ()|]
            stmts -> doE stmts
        tagPokes = case drop 1 cons of
            [] -> []
            _ -> [noBindS [| poke (castPtr $(ptrExpr)) (ixcon :: $(conT tagType)) |]]
        offsetLet = [letS (offsetDecls fields) | not (null ixs)]
        fieldPokes = [noBindS (pokeField ix) | ix <- ixs]
    -- Expression which writes the variable holding field @ix@ to the
    -- pointer, advanced by that field's offset variable.
    pokeField :: Int -> Q Exp
    pokeField ix = [| poke (castPtr $fieldPtr) $(varE (fName ix)) |]
      where
        fieldPtr = [| plusPtr $(ptrExpr) $(varE (offset ix)) |]
    -- Generate declarations which compute the field offsets. The first
    -- offset is the tag size; every later offset is the previous offset plus
    -- the size of the previous field. One declaration is produced per field:
    -- the offset past the final field is left out, to avoid unused variable
    -- warnings.
    offsetDecls :: [(a, Type)] -> [Q Dec]
    offsetDecls fields =
        [ valD (varP (offset ix)) (normalB expr) []
        | (ix, expr) <- zip [0 ..] (take (length fields) offsetExprs)
        ]
      where
        offsetExprs :: [Q Exp]
        offsetExprs = [| tagSize |] : zipWith following [1 ..] (map snd fields)
        -- Offset @ix@, given the type of the field just before it.
        following :: Int -> Type -> Q Exp
        following ix ty = [| $(sizeOfExpr ty) + $(varE (offset (ix - 1))) |]
    -- Expression for the size of a value of type @ty@: 'sizeOf' applied to
    -- an 'error' call annotated with that type, so no real value is needed.
    sizeOfExpr :: Type -> Q Exp
    sizeOfExpr ty = [| sizeOf (error "sizeOf evaluated its argument" :: $(pure ty)) |]
    -- Name of the variable holding the offset of the field with this index.
    offset :: Int -> Name
    offset = mkName . ("offset" ++) . show