{-# LANGUAGE TemplateHaskell      #-}

{-# LANGUAGE ScopedTypeVariables  #-}

{-# LANGUAGE ViewPatterns         #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}



-----------------------------------------------------------------------------

-- |

-- Module      :  Data.Tensor.Static.TH

-- Copyright   :  (C) 2017 Alexey Vagarenko

-- License     :  BSD-style (see LICENSE)

-- Maintainer  :  Alexey Vagarenko (vagarenko@gmail.com)

-- Stability   :  experimental

-- Portability :  non-portable

--

----------------------------------------------------------------------------



module Data.Tensor.Static.TH (

      genTensorInstance

) where



import Data.List                  (foldl')

import Data.Tensor.Static         (tensor, Tensor, unsafeFromList, toList, IsTensor)

import Language.Haskell.TH

import qualified Data.List.NonEmpty as N





-- | Generate an 'IsTensor' instance for a tensor of the given shape and
-- element type.
--
-- The generated instance contains:
--
-- * an associated 'Tensor' data instance with a single constructor (named
--   @Tensor'D1'D2'...'Elem@ and made unique with 'newName') holding
--   @product dimensions@ strict, unpacked fields of the element type;
--
-- * 'tensor', bound to that constructor;
--
-- * 'unsafeFromList', which uses the first @product dimensions@ elements of
--   its argument (any extra elements are ignored) and calls 'error' when the
--   list is too short;
--
-- * 'toList', which returns the fields in constructor order;
--
-- * @INLINE@ pragmas for all three methods.
--
-- Fails in 'Q' if any dimension is negative.
genTensorInstance :: N.NonEmpty Int     -- ^ Dimensions of the tensor.
                  -> Name               -- ^ Type of elements.
                  -> Q [Dec]
genTensorInstance (N.toList -> dimensions) elemTypeName = do
    -- A negative dimension would otherwise surface later as an obscure
    -- "negative type literal" error at the splice site.
    case filter (< 0) dimensions of
        []       -> pure ()
        negative -> fail ("genTensorInstance: negative dimension(s) "
                          ++ show negative ++ " in shape " ++ show dimensions)

    conName <- newName ("Tensor'" ++ concatMap (\x -> show x ++ "'") dimensions ++ nameBase elemTypeName)

    let fieldCount = product dimensions
    fieldNames <- mapM (newName . ('x' :) . show) [0 .. fieldCount - 1]

    let elemType = ConT elemTypeName
        dims     = natListT dimensions
        fields   = replicate fieldCount (Bang SourceUnpack SourceStrict, elemType)

    let dataInstDec = DataInstD
            []                       -- context
            ''Tensor                 -- family name
            [dims, elemType]         -- family params
            (Just StarT)             -- kind
            [NormalC conName fields] -- data constructor with `fieldCount` unpacked fields of type `elemType`
            []                       -- deriving clauses

    -- `fromListPat` is `x0 : x1 : ... : _`: it binds the first `fieldCount`
    -- list elements and ignores whatever follows them.
    let fromListPat     = foldr (\name pat -> InfixP (VarP name) '(:) pat) WildP fieldNames
        constructTensor = foldl' (\acc name -> acc `AppE` VarE name) (ConE conName) fieldNames
        tensorPat       = ConP conName (map VarP fieldNames)
        toListBody      = ListE (map VarE fieldNames)
        failBody        = VarE 'error `AppE` LitE (StringL ("Not enough elements to build a Tensor of shape "
                                                            ++ show dimensions))
        -- For a zero-element shape `fromListPat` is just `_`, so a fallback
        -- clause would be unreachable and trigger an overlapping-patterns
        -- warning in the splicing module; only emit it when it can match.
        failClauses
            | fieldCount > 0 = [Clause [WildP] (NormalB failBody) []]
            | otherwise      = []

    let tensorDec          = ValD (VarP 'tensor) (NormalB $ ConE conName) []
        unsafeFromListDec  = FunD 'unsafeFromList (Clause [fromListPat] (NormalB constructTensor) [] : failClauses)
        toListDec          = FunD 'toList         [ Clause [tensorPat] (NormalB toListBody) [] ]
        tensorCInstPragmas =
            [ PragmaD (InlineP 'tensor         Inline FunLike AllPhases)
            , PragmaD (InlineP 'unsafeFromList Inline FunLike AllPhases)
            , PragmaD (InlineP 'toList         Inline FunLike AllPhases) ]

    let tensorCInstDec = InstanceD
            Nothing
            []
            (ConT ''IsTensor `AppT` dims `AppT` elemType)
            ([dataInstDec, tensorDec, unsafeFromListDec, toListDec] ++ tensorCInstPragmas)

    pure [tensorCInstDec]



-- | Build a promoted type-level list of 'Nat' literals from a list of
-- integers. For example, @[2, 3]@ becomes the type @2 ': 3 ': '[]@,
-- i.e. @'[2, 3]@.
natListT :: [Int] -> Type
natListT []       = PromotedNilT
natListT (n : ns) = AppT (AppT PromotedConsT natLit) (natListT ns)
  where
    -- The current dimension as a type-level numeric literal.
    natLit = LitT (NumTyLit (toInteger n))