{-# LANGUAGE TemplateHaskell, GADTs #-}
{-# OPTIONS_GHC -fno-warn-type-defaults -fno-warn-missing-signatures #-}
module Data.Random.Source.Internal.TH (monadRandom, randomSource) where

import Data.Bits
import Data.Generics
import Data.List
import Data.Maybe
import Data.Monoid
import Data.Random.Internal.Source (Prim(..), MonadRandom(..), RandomSource(..))
import Data.Random.Internal.Words
import Language.Haskell.TH
import Language.Haskell.TH.Extras
import qualified Language.Haskell.TH.FlexibleDefaults as FD

import Control.Monad.Reader

data Method
    = GetPrim
    | GetWord8
    | GetWord16
    | GetWord32
    | GetWord64
    | GetDouble
    | GetNByteInteger
    deriving (Eq, Ord, Enum, Bounded, Read, Show)

allMethods :: [Method]
allMethods = [minBound .. maxBound]

data Context
    = Generic
    | RandomSource
    | MonadRandom
    deriving (Eq, Ord, Enum, Bounded, Read, Show)

methodNameBase :: Context -> Method -> String
methodNameBase c n = nameBase (methodName c n)

methodName :: Context -> Method -> Name
methodName Generic      GetPrim          = mkName "getPrim"
methodName Generic      GetWord8         = mkName "getWord8"
methodName Generic      GetWord16        = mkName "getWord16"
methodName Generic      GetWord32        = mkName "getWord32"
methodName Generic      GetWord64        = mkName "getWord64"
methodName Generic      GetDouble        = mkName "getDouble"
methodName Generic      GetNByteInteger  = mkName "getNByteInteger"
methodName RandomSource GetPrim          = 'getRandomPrimFrom
methodName RandomSource GetWord8         = 'getRandomWord8From
methodName RandomSource GetWord16        = 'getRandomWord16From
methodName RandomSource GetWord32        = 'getRandomWord32From
methodName RandomSource GetWord64        = 'getRandomWord64From
methodName RandomSource GetDouble        = 'getRandomDoubleFrom
methodName RandomSource GetNByteInteger  = 'getRandomNByteIntegerFrom
methodName MonadRandom  GetPrim          = 'getRandomPrim
methodName MonadRandom  GetWord8         = 'getRandomWord8
methodName MonadRandom  GetWord16        = 'getRandomWord16
methodName MonadRandom  GetWord32        = 'getRandomWord32
methodName MonadRandom  GetWord64        = 'getRandomWord64
methodName MonadRandom  GetDouble        = 'getRandomDouble
methodName MonadRandom  GetNByteInteger  = 'getRandomNByteInteger

isMethodName :: Context -> Name -> Bool
isMethodName c n = isJust (nameToMethod c n)

nameToMethod :: Context -> Name -> Maybe Method
nameToMethod c name
    = lookup name
        [ (n, m)
        | m <- allMethods
        , let n = methodName c m
        ]


-- 'Context'-sensitive version of the FlexibleDefaults DSL
scoreBy :: (a -> b) -> ReaderT Context (FD.Defaults a) t -> ReaderT Context (FD.Defaults b) t
scoreBy f = mapReaderT (FD.scoreBy f)

method :: Method -> ReaderT Context (FD.Function s) t -> ReaderT Context (FD.Defaults s) t
method m f = do
    c <- ask
    mapReaderT (FD.function (methodNameBase c m)) f

implementation :: ReaderT Context (FD.Implementation s) (Q [Dec]) -> ReaderT Context (FD.Function s) ()
implementation = mapReaderT FD.implementation

cost :: Num s => s -> ReaderT Context (FD.Implementation s) ()
cost = lift . FD.cost

dependsOn :: Method -> ReaderT Context (FD.Implementation s) ()
dependsOn m = do
    c <- ask
    lift (FD.dependsOn (methodNameBase c m))

changeContext :: Context -> Context -> Name -> Name
changeContext c1 c2 = replace (fmap (methodName c2) . nameToMethod c1)

-- map all occurrences of generic method names to the proper local ones
-- and introduce a 'src' parameter where needed if the Context is RandomSource
specialize :: Monad m => Q [Dec] -> ReaderT Context m (Q [Dec])
specialize futzedDecsQ = do
    let decQ = fmap genericalizeDecs futzedDecsQ
    c <- ask
    let specializeDec = everywhere (mkT (changeContext Generic c))
    if c == RandomSource
        then return $ do
                src <- newName "_src"
                decs <- decQ
                return (map (addSrcParam src) . specializeDec $ decs)
        else return (fmap specializeDec decQ)

stripTypeSigs :: Q [Dec] -> Q [Dec]
stripTypeSigs = fmap (filter (not . isSig))
    where isSig SigD{} = True; isSig _ = False

addSrcParam :: Name -> Dec -> Dec
addSrcParam src
    = everywhere (mkT expandDecs)
    . everywhere (mkT expandExps)
    where
        srcP = VarP src
        srcE = VarE src

        expandDecs (ValD (VarP n) body decs)
            | isMethodName RandomSource n
            = FunD n [Clause [srcP] body decs]
        expandDecs (FunD n clauses)
            | isMethodName RandomSource n
            = FunD n [Clause (srcP : ps) body decs | Clause ps body decs <- clauses]

        expandDecs other = other

        expandExps e@(VarE n)
            | isMethodName RandomSource n   = AppE e srcE
        expandExps other = other

-- dummy expressions which will be remapped by 'specialize'
dummy :: Method -> ExpQ
dummy = return . VarE . methodName Generic

getPrim, getWord8, getWord16,
    getWord32, getWord64, getDouble,
    getNByteInteger :: ExpQ
getPrim             = dummy GetPrim
getWord8            = dummy GetWord8
getWord16           = dummy GetWord16
getWord32           = dummy GetWord32
getWord64           = dummy GetWord64
getDouble           = dummy GetDouble
getNByteInteger     = dummy GetNByteInteger

-- The defaulting rules for RandomSource and MonadRandom.  Costs are rates of
-- entropy waste (bits discarded per bit requested) plus the occasional ad-hoc
-- penalty where it seems appropriate.

-- TODO: figure out a clean way to break these up for individual testing.
-- Also analyze to see which of these can never be selected (I suspect that set is non-empty)
defaults :: Context -> FD.Defaults (Sum Double) ()
defaults = runReaderT $
    scoreBy Sum $ do
        method GetPrim $ do
            implementation $ do
                mapM_ dependsOn (allMethods \\ [GetPrim])

                -- GHC 6 requires type signatures for GADT matches, even
                -- inside [d||].  This code is evaluated at more than one type, though,
                -- and at its eventual splice site the signature actually isn't even allowed.
                -- So, there's a dummy signature here which is immediately stripped out.
                specialize . stripTypeSigs $
                    [d| getPrim :: Prim a -> m a
                        getPrim PrimWord8               = $getWord8
                        getPrim PrimWord16              = $getWord16
                        getPrim PrimWord32              = $getWord32
                        getPrim PrimWord64              = $getWord64
                        getPrim PrimDouble              = $getDouble
                        getPrim (PrimNByteInteger n)    = $getNByteInteger n
                     |]

        scoreBy (/8) $
            method GetWord8 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord8 = $getPrim PrimWord8 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord8 = liftM fromInteger ($getNByteInteger 1) |]

                implementation $ do
                    cost 8
                    dependsOn GetWord16
                    specialize [d| getWord8 = liftM fromIntegral $getWord16 |]

                implementation $ do
                    cost 24
                    dependsOn GetWord32
                    specialize [d| getWord8 = liftM fromIntegral $getWord32 |]

                implementation $ do
                    cost 56
                    dependsOn GetWord64
                    specialize [d| getWord8 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord8 = liftM (truncate . (256*)) $getDouble |]

        scoreBy (/16) $
            method GetWord16 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord16 = $getPrim PrimWord16 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord16 = liftM fromInteger ($getNByteInteger 2) |]

                implementation $ do
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord16 = do
                                a <- $getWord8
                                b <- $getWord8
                                return (buildWord16 a b)
                         |]

                implementation $ do
                    cost 16
                    dependsOn GetWord32
                    specialize [d| getWord16 = liftM fromIntegral $getWord32 |]

                implementation $ do
                    cost 48
                    dependsOn GetWord64
                    specialize [d| getWord16 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord16 = liftM (truncate . (65536*)) $getDouble |]

        scoreBy (/32) $
            method GetWord32 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord32 = $getPrim PrimWord32 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord32 = liftM fromInteger ($getNByteInteger 4) |]

                implementation $ do
                    cost 0.1
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord32 = do
                                a <- $getWord8
                                b <- $getWord8
                                c <- $getWord8
                                d <- $getWord8
                                return (buildWord32 a b c d)
                         |]

                implementation $ do
                    dependsOn GetWord16
                    specialize
                        [d|
                            getWord32 = do
                                a <- $getWord16
                                b <- $getWord16
                                return (buildWord32' a b)
                         |]

                implementation $ do
                    cost 32
                    dependsOn GetWord64
                    specialize [d| getWord32 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord32 = liftM (truncate . (4294967296*)) $getDouble |]

        scoreBy (/64) $
            method GetWord64 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord64 = $getPrim PrimWord64 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord64 = liftM fromInteger ($getNByteInteger 8) |]

                implementation $ do
                    cost 0.2
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord8
                                b <- $getWord8
                                c <- $getWord8
                                d <- $getWord8
                                e <- $getWord8
                                f <- $getWord8
                                g <- $getWord8
                                h <- $getWord8
                                return (buildWord64 a b c d e f g h)
                         |]

                implementation $ do
                    cost 0.1
                    dependsOn GetWord16
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord16
                                b <- $getWord16
                                c <- $getWord16
                                d <- $getWord16
                                return (buildWord64' a b c d)
                         |]

                implementation $ do
                    dependsOn GetWord32
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord32
                                b <- $getWord32
                                return (buildWord64'' a b)
                         |]

        scoreBy (/52) $
            method GetDouble $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getDouble = $getPrim PrimDouble |]

                implementation $ do
                    cost 12
                    dependsOn GetWord64
                    specialize
                        [d|
                            getDouble = do
                                w <- $getWord64
                                return (wordToDouble w)
                         |]

        method GetNByteInteger $ do
            implementation $ do
                dependsOn GetPrim
                specialize [d| getNByteInteger n = $getPrim (PrimNByteInteger n) |]

            implementation $ do
                when intIs64 (cost 1e-2)
                dependsOn GetWord8
                dependsOn GetWord16
                dependsOn GetWord32
                specialize
                    [d|
                        getNByteInteger 1 = do
                            x <- $getWord8
                            return $! toInteger x
                        getNByteInteger 2 = do
                            x <- $getWord16
                            return $! toInteger x
                        getNByteInteger 4 = do
                            x <- $getWord32
                            return $! toInteger x
                        getNByteInteger np4
                            | np4 > 4 = do
                                let n = np4 - 4
                                x <- $getWord32
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np2
                            | np2 > 2 = do
                                let n = np2 - 2
                                x <- $getWord16
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger _ = return 0
                      |]

            implementation $ do
                when (not intIs64) (cost 1e-2)
                dependsOn GetWord8
                dependsOn GetWord16
                dependsOn GetWord32
                dependsOn GetWord64
                specialize
                    [d|
                        getNByteInteger 1 = do
                            x <- $getWord8
                            return $! toInteger x
                        getNByteInteger 2 = do
                            x <- $getWord16
                            return $! toInteger x
                        getNByteInteger 4 = do
                            x <- $getWord32
                            return $! toInteger x
                        getNByteInteger 8 = do
                            x <- $getWord64
                            return $! toInteger x
                        getNByteInteger np8
                            | np8 > 8 = do
                                let n = np8 - 8
                                x <- $getWord64
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np4
                            | np4 > 4 = do
                                let n = np4 - 4
                                x <- $getWord32
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np2
                            | np2 > 2 = do
                                let n = np2 - 2
                                x <- $getWord16
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger _ = return 0
                      |]


-- |Complete a possibly-incomplete 'RandomSource' implementation.  It is 
-- recommended that this macro be used even if the implementation is currently
-- complete, as the 'RandomSource' class may be extended at any time.
--
-- To use 'randomSource', just wrap your instance declaration as follows (and
-- enable the TemplateHaskell, MultiParamTypeClasses and GADTs language
-- extensions, as well as any others required by your instances, such as
-- FlexibleInstances):
--
-- > $(randomSource [d|
-- >         instance RandomSource FooM Bar where
-- >             {- at least one RandomSource function... -}
-- >     |])
randomSource :: Q [Dec] -> Q [Dec]
randomSource = FD.withDefaults (defaults RandomSource)

-- |Complete a possibly-incomplete 'MonadRandom' implementation.  It is 
-- recommended that this macro be used even if the implementation is currently
-- complete, as the 'MonadRandom' class may be extended at any time.
--
-- To use 'monadRandom', just wrap your instance declaration as follows (and
-- enable the TemplateHaskell and GADTs language extensions):
--
-- > $(monadRandom [d|
-- >         instance MonadRandom FooM where
-- >             getRandomDouble = return pi
-- >             getRandomWord16 = return 4
-- >             {- etc... -}
-- >     |])
monadRandom :: Q [Dec] -> Q [Dec]
monadRandom = FD.withDefaults (defaults MonadRandom)

-- -- This is nice in theory, but under GHC 7 it never typechecks; without generalizing the let-bound
-- -- functions, it gets absurd errors like "cannot match 'm Int' with 'IO t'".  Probably need
-- -- to mechanically specialize the supplied signature to create a signature for every other
-- -- let-bound function.
-- primFunction :: Q Type -> Q [Dec] -> ExpQ
-- primFunction getPrimType decsQ = do
--     getPrimSig <- sigD (mkName (methodName Generic GetPrim)) getPrimType
--     decs <- decsQ >>= FD.implementDefaults (defaults Generic)
--     f <- getPrim
--     return (LetE (getPrimSig : decs) f)