{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds          #-}
{-# LANGUAGE DeriveFunctor      #-}
{-# LANGUAGE GADTs              #-}
{-# LANGUAGE KindSignatures     #-}
{-# LANGUAGE RecordWildCards    #-}
{-# LANGUAGE StandaloneDeriving #-}
module Overloaded.Plugin.Categories where

import Data.Bifunctor       (Bifunctor (..))
import Data.Bifunctor.Assoc (Assoc (..))
import Data.Kind            (Type)
import Data.Map.Strict      (Map)
import Data.Void            (Void, absurd)

import qualified Data.Generics   as SYB
import qualified Data.Map.Strict as Map
import qualified GHC.Compat.All  as GHC
import           GHC.Compat.Expr
import qualified GhcPlugins      as Plugins

import Overloaded.Plugin.Diagnostics
import Overloaded.Plugin.Names
import Overloaded.Plugin.Rewrite

-------------------------------------------------------------------------------
-- Rewriter
-------------------------------------------------------------------------------

transformCategories
    :: Names
    -> LHsExpr GhcRn
    -> Rewrite (LHsExpr GhcRn)
transformCategories names (L _l (HsProc _ pat (L _ (HsCmdTop _ cmd)))) = do
    SomePattern pat' <- parsePat pat
    kont <- parseCmd names (patternMap pat') cmd
    let proc :: Proc (LHsExpr GhcRn) Void
        proc = Proc (nameToString <$> pat') kont

        morp :: Morphism (LHsExpr GhcRn)
        morp = desugar absurd proc

        expr :: LHsExpr GhcRn
        expr = generate names morp

    -- _ <- Error $ \dflags -> putError dflags _l $ GHC.text "DEBUG"
    --     GHC.$$ GHC.text (show $ first (GHC.showPpr dflags) proc)
    --     GHC.$$ GHC.text (show $ fmap  (GHC.showPpr dflags) morp)
    --     GHC.$$ GHC.ppr expr

    return expr

transformCategories _ _ = NoRewrite

-------------------------------------------------------------------------------
-- Parsing
-------------------------------------------------------------------------------

parsePat :: LPat GhcRn -> Rewrite (SomePattern GHC.Name)
#if MIN_VERSION_ghc(8,8,0) && !MIN_VERSION_ghc(8,10,1)
parsePat (XPat (L l pat)) = parsePat' l pat
parsePat pat              = parsePat' noSrcSpan pat
#else
parsePat (L l pat) = parsePat' l pat
#endif

parsePat' :: SrcSpan -> Pat GhcRn -> Rewrite (SomePattern GHC.Name)
parsePat' _ WildPat {} =
    return $ SomePattern PatternWild
parsePat' _ (VarPat _ (L _ name)) =
    return $ SomePattern $ PatternVar name
parsePat' _ (TuplePat _ [x, y] Plugins.Boxed) = do
    SomePattern x' <- parsePat x
    SomePattern y' <- parsePat y
    return $ SomePattern $ PatternTuple x' y'
parsePat' l TuplePat {} = Error $ \dflags ->
    putError dflags l $ GHC.text "Overloaded:Categories: only boxed tuples of arity 2 are supported"
parsePat' l pat = Error $ \dflags ->
    putError dflags l $ GHC.text "Cannot parse pattern for Overloaded:Categories"
        GHC.$$ GHC.ppr pat
        GHC.$$ GHC.text (SYB.gshow pat)

parseExpr
    :: Names
    -> Map GHC.Name b
    -> LHsExpr GhcRn
    -> Rewrite (Expression (Var b a))
parseExpr names ctx (L _ (HsPar _ expr)) =
    parseExpr names ctx expr
parseExpr _     ctx (L _ (HsVar _ (L l name)))
    | name == GHC.getName (GHC.tupleDataCon GHC.Boxed 0)
    = return ExpressionUnit
    | otherwise
    = case Map.lookup name ctx of
        Nothing -> Error $ \dflags ->
            putError dflags l $ GHC.text "Overloaded:Categories: Unbound variable" GHC.<+> GHC.ppr name
        Just b -> return $ ExpressionVar (B b)
parseExpr names ctx (L _ (ExplicitTuple _ [L _ (Present _ x), L _ (Present _ y)] Plugins.Boxed)) = do
    x' <- parseExpr names ctx x
    y' <- parseExpr names ctx y
    return (ExpressionTuple x' y')
parseExpr _     _ (L l ExplicitTuple {}) = Error $ \dflags ->
    putError dflags l $ GHC.text "Overloaded:Categories: only boxed tuples of arity 2 are supported"
parseExpr names ctx (L _ (HsApp _ (L _ (HsVar _ (L l fName))) x))
    | fName == conLeftName names = do
        x' <- parseExpr names ctx x
        return (ExpressionLeft x')
    | fName == conRightName names = do
        x' <- parseExpr names ctx x
        return (ExpressionRight x')
    | otherwise = Error $ \dflags ->
        putError dflags l $ GHC.text "Overloaded:Categories: only applications of Left and Right are supported"
parseExpr _     _   (L l expr) = Error $ \dflags ->
    putError dflags l $ GHC.text "Cannot parse -< right-hand-side for Overloaded:Categories"
        GHC.$$ GHC.ppr expr
        GHC.$$ GHC.text (SYB.gshow expr)

parseCmd
    :: Names
    -> Map GHC.Name b
    -> LHsCmd GhcRn
    -> Rewrite (Continuation (LHsExpr GhcRn) (Var b a))
parseCmd names ctx (L _ (HsCmdDo _ (L l stmts))) =
    parseStmts names ctx l stmts
parseCmd names ctx (L _ (HsCmdArrApp _ morp expr HsFirstOrderApp _)) = do
    morp' <- parseTerm names morp
    expr' <- parseExpr names ctx expr
    return $ Last (Right morp') expr'
parseCmd names ctx (L _ (HsCmdArrApp _ morp expr HsHigherOrderApp _)) = do
    morp' <- parseExpr names ctx morp
    expr' <- parseExpr names ctx expr
    return $ Last (Left morp') expr'
parseCmd names ctx (L _ (HsCmdCase _ expr matchGroup)) =
    case mg_alts matchGroup of
#if MIN_VERSION_ghc(8,8,0) && !MIN_VERSION_ghc(8,10,1)
        L _ [ L _ Match { m_pats = [XPat (L _ (ConPatIn (L _ acon) aargs))], m_grhss = abody' }
            , L _ Match { m_pats = [XPat (L _ (ConPatIn (L _ bcon) bargs))], m_grhss = bbody' }
            ]
#else
        L _ [ L _ Match { m_pats = [L _ (ConPatIn (L _ acon) aargs)], m_grhss = abody' }
            , L _ Match { m_pats = [L _ (ConPatIn (L _ bcon) bargs)], m_grhss = bbody' }
            ]
#endif
            -- Left and Right, or Right and Left
            |  [acon,bcon] == [conLeftName names,conRightName names]
            || [acon,bcon] == [conRightName names,conLeftName names]
            -- only one argument
            , [aarg] <- hsConPatArgs aargs
            , [barg] <- hsConPatArgs bargs
            -- and simple bodies
            , Just abody <- simpleGRHSs abody'
            , Just bbody <- simpleGRHSs bbody'

            -> do
                expr' <- parseExpr names ctx expr

                SomePattern apat <- parsePat aarg
                SomePattern bpat <- parsePat barg

                acont <- parseCmd names (combineMaps ctx apat) abody
                bcont <- parseCmd names (combineMaps ctx bpat) bbody

                -- Error $ \dflags -> putError dflags noSrcSpan $ GHC.text "TODO"
                --     GHC.$$ GHC.ppr acon
                --     GHC.$$ GHC.ppr bcon
                --     GHC.$$ GHC.ppr aarg
                --     GHC.$$ GHC.ppr barg
                --     GHC.$$ GHC.ppr abody
                --     GHC.$$ GHC.ppr bbody

                return $ caseCont expr' apat bpat (second assoc acont) (second assoc bcont)

        L l _ -> Error $ \dflags ->
            putError dflags l $ GHC.text "Overloaded:Categories only case of Left and Right are supported"
                GHC.$$ GHC.text (SYB.gshow (mg_alts matchGroup))
parseCmd _     _   (L l cmd) =
    Error $ \dflags ->
        putError dflags l $ GHC.text "Unsupported command in proc for Overloaded:Categories"
            GHC.$$ GHC.ppr cmd
            GHC.$$ GHC.text (SYB.gshow cmd)

simpleGRHSs :: GRHSs GhcRn body -> Maybe body
simpleGRHSs (GRHSs _ [L _ (GRHS _ [] body)] (L _ (EmptyLocalBinds _))) = Just body
simpleGRHSs _ = Nothing

parseTerm
    :: Names
    -> LHsExpr GhcRn
    -> Rewrite (Morphism (LHsExpr GhcRn))
parseTerm Names {catNames = CatNames {..}} (L _ (HsVar _ (L _ name)))
    | name == catIdentityName = return MId
parseTerm _ term = return (MTerm term)

parseStmts
    :: Names
    -> Map GHC.Name b
    -> SrcSpan
    -> [CmdLStmt GhcRn]
    -> Rewrite (Continuation (LHsExpr GhcRn) (Var b a))
parseStmts names ctx _ (L l (BindStmt _ pat body _ _) : next) = do
    SomePattern pat' <- parsePat pat
    cont1 <- parseCmd names ctx body
    cont2 <- parseStmts names (combineMaps ctx pat') l next
    return $ compCont (nameToString <$> pat') cont1 (second assoc cont2)
parseStmts names ctx _ [L _ (LastStmt _ body _ _)] =
    parseCmd names ctx body
parseStmts _     _   _ (L l stmt : _) =
    Error $ \dflags ->
        putError dflags l $ GHC.text "Unsupported statement in proc-do for Overloaded:Categories"
            GHC.$$ GHC.ppr stmt
            GHC.$$ GHC.text (SYB.gshow stmt)
parseStmts _     _   l [] =
    Error $ \dflags ->
        putError dflags l $ GHC.text "Empty do block in proc"

-------------------------------------------------------------------------------
-- Variables
-------------------------------------------------------------------------------

data Var b a
    = B b
    | F a
  deriving (Show, Functor)

instance Bifunctor Var where
    bimap f _ (B b) = B (f b)
    bimap _ g (F a) = F (g a)

instance Assoc Var where
    assoc (B (B x)) = B x
    assoc (B (F y)) = F (B y)
    assoc (F z)     = F (F z)

    unassoc (B x)     = B (B x)
    unassoc (F (B y)) = B (F y)
    unassoc (F (F z)) = F z

unvar :: (b -> c) -> (a -> c) -> Var b a -> c
unvar f _ (B b) = f b
unvar _ g (F a) = g a

-------------------------------------------------------------------------------
-- A subset of Arrow notation syntax we support.
-------------------------------------------------------------------------------

-- | Proc syntax
data Proc term a where
    Proc :: Pattern sh String -> Continuation term (Var (Index sh) a) -> Proc term a

deriving instance (Show a, Show term) => Show (Proc term a)

instance Bifunctor Proc where
    bimap f g (Proc p c) = Proc p (bimap f (fmap g) c)

data Continuation term a where
    Last :: Either (Expression a) (Morphism term) -> Expression a -> Continuation term a
      -- ^ term -< y
    Edge
        :: Pattern sh String
        -> Either (Expression a) (Morphism term)
        -> Expression a
        -> Continuation term (Var (Index sh) a)
        -> Continuation term a
      -- ^ x <- term -< y

    Split
        :: Expression a
        -> Pattern shA String
        -> Pattern shB String
        -> Continuation term (Var (Index shA) a)
        -> Continuation term (Var (Index shB) a)
        -> Continuation term a

deriving instance (Show a, Show term) => Show (Continuation term a)

instance Bifunctor Continuation where
    bimap f g (Last term e)         = Last (bimap (fmap g) (fmap f) term) (fmap g e)
    bimap f g (Edge p term e c)     = Edge p (bimap (fmap g) (fmap f) term) (fmap g e) (bimap f (fmap g) c)
    bimap f g (Split e pa pb ca cb) = Split (fmap g e) pa pb
        (bimap f (fmap g) ca)
        (bimap f (fmap g) cb)

instance Functor (Continuation term) where
    fmap = second

compCont
    :: Pattern sh String
    -> Continuation term a
    -> Continuation term (Var (Index sh) a)
    -> Continuation term a
compCont pat (Last term expr) c
    = Edge pat term expr c
compCont pat (Edge pat' term expr c') c
    = Edge pat' term expr
    $ compCont pat c' (weaken1 c)
compCont pat (Split expr patA patB contA contB) c
    = Split expr patA patB
        (compCont pat contA (weaken1 c))
        (compCont pat contB (weaken1 c))

weaken1 :: Functor f => f (Var a b) -> f (Var a (Var c b))
weaken1 = fmap (unvar B (F . F))

caseCont
    :: Expression a
    -> Pattern shA Plugins.Name
    -> Pattern shB Plugins.Name
    -> Continuation (LHsExpr GhcRn) (Var (Index shA) a)
    -> Continuation (LHsExpr GhcRn) (Var (Index shB) a)
    -> Continuation (LHsExpr GhcRn) a
caseCont e patA patB =
    Split e (fmap nameToString patA) (fmap nameToString patB)

-------------------------------------------------------------------------------
-- Patterns
-------------------------------------------------------------------------------

data Shape = One | Two Shape Shape

data Pattern :: Shape -> Type -> Type where
    PatternVar   :: a -> Pattern 'One a
    PatternWild  :: Pattern 'One a
    PatternTuple :: Pattern l a -> Pattern r a -> Pattern ('Two l r) a

deriving instance Show a => Show (Pattern sh a)
deriving instance Functor (Pattern sh)

data SomePattern :: Type -> Type where
    SomePattern :: Pattern sh a -> SomePattern a

data Index :: Shape -> Type where
    Here :: Index 'One
    InL  :: Index x -> Index ('Two x y)
    InR  :: Index y -> Index ('Two x y)

deriving instance Show (Index sh)

patternMap :: Ord a => Pattern sh a -> Map a (Index sh)
patternMap (PatternVar x)     = Map.singleton x Here
patternMap PatternWild        = Map.empty
patternMap (PatternTuple l r) = Map.union
    (Map.map InL (patternMap l))
    (Map.map InR (patternMap r))

combineMaps
    :: Map Plugins.Name b
    -> Pattern sh Plugins.Name
    -> Map Plugins.Name (Var (Index sh) b)
combineMaps m pat = Map.union (Map.map F m) (Map.map B (patternMap pat))

-------------------------------------------------------------------------------
-- Expressions
-------------------------------------------------------------------------------

data Expression a
    = ExpressionVar a
    | ExpressionUnit
    | ExpressionTuple (Expression a) (Expression a)
    | ExpressionLeft (Expression a)
    | ExpressionRight (Expression a)
  deriving (Show, Functor)

-------------------------------------------------------------------------------
-- Skeleton of syntax we desugar arrow notation to
-------------------------------------------------------------------------------

-- | Note: morpisms don't have variables!
data Morphism term
    = MId
    | MCompose (Morphism term) (Morphism term)
    | MProduct (Morphism term) (Morphism term)
    | MTerminal
    | MProj1
    | MProj2
    | MInL
    | MInR
    | MCase (Morphism term) (Morphism term)
    | MDistr
    | MEval
    | MTerm term
  deriving (Show, Functor)

instance Semigroup (Morphism term) where
    MTerminal <> _            = MTerminal
    MId       <> m            = m
    m         <> MId          = m
    MProj1    <> MProduct f _ = f
    MProj2    <> MProduct _ g = g
    MCase f _ <> MInL         = f
    MCase _ g <> MInR         = g
    f         <> g            = MCompose f g

instance Monoid (Morphism term) where
    mempty  = MId
    mappend = (<>)

-------------------------------------------------------------------------------
-- Desugaring
-------------------------------------------------------------------------------

desugar :: (a -> Morphism term) -> Proc term a -> Morphism term
desugar ctx (Proc p k) = desugarC (unvar (desugarP p) ctx) k

desugarC :: (a -> Morphism term) -> Continuation term a -> Morphism term
desugarC ctx (Last (Right term) e) = mconcat
    [ term
    , desugarE ctx e
    ]
desugarC ctx (Last (Left f) e) = mconcat
    [ MEval
    , MProduct (desugarE ctx f) (desugarE ctx e)
    ]
desugarC ctx (Edge p (Right term) e k) = mconcat
    [ desugarC (unvar (\x -> desugarP p x <> MProj1) (\y -> ctx y <> MProj2)) k
    , MProduct
        (term <> desugarE ctx e)
        MId
    ]
desugarC ctx (Edge p (Left f) e k) = mconcat
    [ desugarC (unvar (\x -> desugarP p x <> MEval <> MProj1) (\y -> ctx y <> MProj2)) k
    , MProduct
        (MProduct (desugarE ctx f) (desugarE ctx e))
        MId
    ]
desugarC ctx (Split e pa pb ka kb) = mconcat
    [ MCase
        (desugarC (unvar (\x -> desugarP pa x <> MProj1) (\y -> ctx y <> MProj2)) ka)
        (desugarC (unvar (\x -> desugarP pb x <> MProj1) (\y -> ctx y <> MProj2)) kb)
    , MDistr
    , MProduct
        (desugarE ctx e)
        MId
    ]

desugarP :: Pattern sh name -> Index sh -> Morphism term
desugarP (PatternVar _)     Here    = MId
desugarP PatternWild        Here    = MId
desugarP (PatternTuple l _) (InL i) = desugarP l i <> MProj1
desugarP (PatternTuple _ r) (InR i) = desugarP r i <> MProj2

desugarE :: (a -> Morphism term) -> Expression a -> Morphism term
desugarE ctx = go where
    go ExpressionUnit        = MTerminal
    go (ExpressionVar a)     = ctx a
    go (ExpressionTuple x y) = MProduct (go x) (go y)
    go (ExpressionLeft x)    = MInL <> go x
    go (ExpressionRight y)   = MInR <> go y

-------------------------------------------------------------------------------
-- Generating
-------------------------------------------------------------------------------

generate :: Names -> Morphism (LHsExpr GhcRn) -> LHsExpr GhcRn
generate Names {catNames = CatNames {..}} = go where
    go MId            = hsVar noSrcSpan catIdentityName
    go (MCompose f g) = hsPar noSrcSpan $ hsOpApp noSrcSpan (go f) (hsVar noSrcSpan catComposeName) (go g)
    go (MTerm term)   = term
    go MTerminal      = hsVar noSrcSpan catTerminalName
    go MProj1         = hsVar noSrcSpan catProj1Name
    go MProj2         = hsVar noSrcSpan catProj2Name
    go (MProduct f g) = hsPar noSrcSpan $ hsApps noSrcSpan (hsVar noSrcSpan catFanoutName) [go f, go g]
    go MInL           = hsVar noSrcSpan catInlName
    go MInR           = hsVar noSrcSpan catInrName
    go MDistr         = hsVar noSrcSpan catDistrName
    go MEval          = hsVar noSrcSpan catEvalName
    go (MCase f g)    = hsPar noSrcSpan $ hsApps noSrcSpan (hsVar noSrcSpan catFaninName) [go f, go g]