{-|
Copyright  :  (C) 2019-2024, QBayLogic B.V.
License    :  BSD2 (see the file LICENSE)
Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Class.BitPack.Internal.TH where

import           Clash.CPP             (maxTupleSize)
import           Language.Haskell.TH.Compat (mkTySynInstD,mkTupE)
import           Control.Monad         (replicateM)
#if !MIN_VERSION_base(4,20,0)
import           Data.List             (foldl')
#endif
import           GHC.TypeLits          (KnownNat)
import           Language.Haskell.TH

-- | Construct all the tuple (starting at size 3) instances for BitPack.
--
-- For every arity @n@ from 3 up to 'maxTupleSize' this generates one instance
-- of the following shape (written with the names passed as arguments):
--
-- @
-- instance ( BitPack a1, KnownNat (BitSize a1)
--          , BitPack (a2, ..., an), KnownNat (BitSize (a2, ..., an))
--          ) => BitPack (a1, a2, ..., an) where
--   type BitSize (a1, a2, ..., an) = BitSize a1 + BitSize (a2, ..., an)
--   pack tup = pack (retup tup)
--     where retup (a1, a2, ..., an) = (a1, (a2, ..., an))
--   unpack x =
--     let (a1, y)       = unpack x
--         (a2, ..., an) = unpack y
--     in (a1, a2, ..., an)
-- @
--
-- That is, an @n@-tuple is (un)packed as the pair of its first component and
-- the @(n-1)@-tuple of its remaining components, delegating to the instance
-- for the smaller tuple.
deriveBitPackTuples
  :: Name
  -- ^ BitPack
  -> Name
  -- ^ BitSize
  -> Name
  -- ^ pack
  -> Name
  -- ^ unpack
  -> DecsQ
deriveBitPackTuples bitPackName bitSizeName packName unpackName = do
  let bitPack  = ConT bitPackName
      bitSize  = ConT bitSizeName
      knownNat = ConT ''KnownNat
      plus     = ConT $ mkName "+"

  -- The same fresh names are reused by every generated instance; they are
  -- only ever bound locally (patterns, clauses, where/let) inside each one.
  allNames <- replicateM maxTupleSize (newName "a")
  retupName <- newName "retup"
  x <- newName "x"
  y <- newName "y"
  tup <- newName "tup"

  pure $ flip map [3..maxTupleSize] $ \tupleNum ->
    let names = take tupleNum allNames
        -- Split off the first component. 'names' has exactly 'tupleNum'
        -- (>= 3) elements, so the empty case is unreachable.
        (n, ns) = case names of
                    (z:zs) -> (z, zs)
                    []     -> error "maxTupleSize <= 3"
        v  = VarT n
        vs = map VarT ns

        -- Build the tuple type (t1, ..., tk) from its component types.
        tuple :: [Type] -> Type
        tuple ts = foldl' AppT (TupleT $ length ts) ts

        -- Instance declaration
        context =
          [ bitPack `AppT` v
          , knownNat `AppT` (bitSize `AppT` v)
          , bitPack `AppT` tuple vs
          , knownNat `AppT` (bitSize `AppT` tuple vs)
          ]
        instTy = AppT bitPack $ tuple (v:vs)

        -- Associated type:
        --   BitSize (a1, ..., an) = BitSize a1 + BitSize (a2, ..., an)
        bitSizeType =
          mkTySynInstD bitSizeName [tuple (v:vs)] $
            plus `AppT` (bitSize `AppT` v) `AppT` (bitSize `AppT` tuple vs)

        -- pack tup = pack (retup tup) where retup (a1, ..., an) = (a1, (a2, ..., an))
        pack =
          FunD packName
            [ Clause
                [VarP tup]
                (NormalB (AppE (VarE packName) (AppE (VarE retupName) (VarE tup))))
                [retup]
            ]
        retup =
          FunD retupName
            [ Clause
                [TupP (map VarP names)]
                (NormalB (mkTupE [VarE n, mkTupE (map VarE ns)]))
                []
            ]

        -- unpack x = let (a1, y) = unpack x; (a2, ..., an) = unpack y
        --            in (a1, a2, ..., an)
        unpack =
          FunD unpackName
            [ Clause
                [VarP x]
                (NormalB $
                  LetE
                    [ ValD (TupP [VarP n, VarP y])
                        (NormalB $ VarE unpackName `AppE` VarE x)
                        []
                    , ValD (TupP (map VarP ns))
                        (NormalB $ VarE unpackName `AppE` VarE y)
                        []
                    ]
                    (mkTupE $ map VarE names))
                []
            ]
    in InstanceD Nothing context instTy [bitSizeType, pack, unpack]