{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
module Overloaded.Plugin.Categories where
import Data.Bifunctor (Bifunctor (..))
import Data.Bifunctor.Assoc (Assoc (..))
import Data.Kind (Type)
import Data.Map.Strict (Map)
import Data.Void (Void, absurd)
import qualified Data.Generics as SYB
import qualified Data.Map.Strict as Map
import qualified GHC.Compat.All as GHC
import GHC.Compat.Expr
import qualified GhcPlugins as Plugins
import Overloaded.Plugin.Diagnostics
import Overloaded.Plugin.Names
import Overloaded.Plugin.Rewrite
-- | Rewrite a @proc@ expression into an expression built from the
-- combinators named in 'Names'.
--
-- The binder pattern and the command body are parsed into a 'Proc', which is
-- desugared to a 'Morphism' and rendered with 'generate'.  Any expression
-- that is not of the @proc pat -> cmd@ form yields 'NoRewrite'.
transformCategories
    :: Names
    -> LHsExpr GhcRn
    -> Rewrite (LHsExpr GhcRn)
transformCategories names lexpr = case lexpr of
    L _ (HsProc _ binder (L _ (HsCmdTop _ body))) -> do
        SomePattern shape <- parsePat binder
        cont <- parseCmd names (patternMap shape) body
        -- 'absurd' serves as the context for free variables: at the top
        -- level there are none, so that type is 'Void'.
        return
            $ generate names
            $ desugar absurd
            $ Proc (nameToString <$> shape) cont
    _ -> NoRewrite
-- | Parse a located pattern into a 'SomePattern', delegating to 'parsePat''
-- together with the pattern's source span.
--
-- Under the first CPP branch (GHC >= 8.8.0 and < 8.10.1) the location is
-- taken from an 'XPat' wrapper; a pattern without that wrapper is parsed
-- with 'noSrcSpan'.
parsePat :: LPat GhcRn -> Rewrite (SomePattern GHC.Name)
#if MIN_VERSION_ghc(8,8,0) && !MIN_VERSION_ghc(8,10,1)
parsePat located = case located of
    XPat (L loc inner) -> parsePat' loc inner
    bare               -> parsePat' noSrcSpan bare
#else
parsePat (L loc inner) = parsePat' loc inner
#endif
-- | Parse a pattern whose source span is already known.
--
-- Supported forms are wildcards, variables and boxed pairs of supported
-- patterns.  Any other tuple, and any other pattern, produces an 'Error'
-- reported at the given span.
parsePat' :: SrcSpan -> Pat GhcRn -> Rewrite (SomePattern GHC.Name)
parsePat' loc pat = case pat of
    WildPat {} ->
        return (SomePattern PatternWild)
    VarPat _ (L _ name) ->
        return (SomePattern (PatternVar name))
    TuplePat _ [lhs, rhs] Plugins.Boxed -> do
        SomePattern lhs' <- parsePat lhs
        SomePattern rhs' <- parsePat rhs
        return (SomePattern (PatternTuple lhs' rhs'))
    -- Remaining tuples: unboxed ones, or an arity other than two.
    TuplePat {} -> Error $ \dflags ->
        putError dflags loc $ GHC.text "Overloaded:Categories: only boxed tuples of arity 2 are supported"
    _ -> Error $ \dflags ->
        putError dflags loc $ GHC.text "Cannot parse pattern for Overloaded:Categories"
            GHC.$$ GHC.ppr pat
            GHC.$$ GHC.text (SYB.gshow pat)
-- | Parse the expression on the right-hand side of an arrow application.
--
-- Accepted forms: parenthesised expressions, the unit constructor, variables
-- present in the context (returned as 'B'-tagged 'ExpressionVar's), boxed
-- pairs, and applications of the constructors named by 'conLeftName' /
-- 'conRightName'.  Everything else is reported as an 'Error'.
parseExpr
    :: Names
    -> Map GHC.Name b
    -> LHsExpr GhcRn
    -> Rewrite (Expression (Var b a))
parseExpr names ctx lexpr = case lexpr of
    L _ (HsPar _ inner) ->
        recur inner

    L _ (HsVar _ (L loc name))
        | name == GHC.getName (GHC.tupleDataCon GHC.Boxed 0) ->
            return ExpressionUnit
        | otherwise ->
            maybe
                (Error $ \dflags ->
                    putError dflags loc $ GHC.text "Overloaded:Categories: Unbound variable" GHC.<+> GHC.ppr name)
                (return . ExpressionVar . B)
                (Map.lookup name ctx)

    L _ (ExplicitTuple _ [L _ (Present _ lhs), L _ (Present _ rhs)] Plugins.Boxed) -> do
        lhs' <- recur lhs
        rhs' <- recur rhs
        return (ExpressionTuple lhs' rhs')

    -- Any other tuple expression (sections, other arities, unboxed).
    L loc ExplicitTuple {} -> Error $ \dflags ->
        putError dflags loc $ GHC.text "Overloaded:Categories: only boxed tuples of arity 2 are supported"

    L _ (HsApp _ (L _ (HsVar _ (L loc fName))) arg)
        | fName == conLeftName names  -> ExpressionLeft  <$> recur arg
        | fName == conRightName names -> ExpressionRight <$> recur arg
        | otherwise -> Error $ \dflags ->
            putError dflags loc $ GHC.text "Overloaded:Categories: only applications of Left and Right are supported"

    L loc other -> Error $ \dflags ->
        putError dflags loc $ GHC.text "Cannot parse -< right-hand-side for Overloaded:Categories"
            GHC.$$ GHC.ppr other
            GHC.$$ GHC.text (SYB.gshow other)
  where
    -- Sub-expressions are parsed against the same names and context.
    recur = parseExpr names ctx
-- | Parse an arrow command into a 'Continuation'.
--
-- Supported commands:
--
-- * a @do@ block, handled by 'parseStmts';
--
-- * a first-order application, whose arrow is parsed with 'parseTerm' and
--   stored as @Right@;
--
-- * a higher-order application, whose arrow is itself an expression over the
--   bound variables and is stored as @Left@;
--
-- * a @case@ with exactly two alternatives, one on the constructor named by
--   'conLeftName' and one on 'conRightName', each with a single argument
--   pattern and a plain (unguarded, no local bindings) body.
--
-- The two @case@ alternatives may be written in either order.  They are
-- normalised so that the 'conLeftName' alternative is always the first
-- handler given to 'caseCont'; passing them in source order would silently
-- swap the branches of a @case@ written @Right@-first.
--
-- Anything else is reported as an 'Error'.
parseCmd
    :: Names
    -> Map GHC.Name b
    -> LHsCmd GhcRn
    -> Rewrite (Continuation (LHsExpr GhcRn) (Var b a))
parseCmd names ctx (L _ (HsCmdDo _ (L l stmts))) =
    parseStmts names ctx l stmts
parseCmd names ctx (L _ (HsCmdArrApp _ morp expr HsFirstOrderApp _)) = do
    morp' <- parseTerm names morp
    expr' <- parseExpr names ctx expr
    return $ Last (Right morp') expr'
parseCmd names ctx (L _ (HsCmdArrApp _ morp expr HsHigherOrderApp _)) = do
    morp' <- parseExpr names ctx morp
    expr' <- parseExpr names ctx expr
    return $ Last (Left morp') expr'
parseCmd names ctx (L _ (HsCmdCase _ expr matchGroup)) =
    case mg_alts matchGroup of
#if MIN_VERSION_ghc(8,8,0) && !MIN_VERSION_ghc(8,10,1)
        L _ [ L _ Match { m_pats = [XPat (L _ (ConPatIn (L _ acon) aargs))], m_grhss = abody' }
            , L _ Match { m_pats = [XPat (L _ (ConPatIn (L _ bcon) bargs))], m_grhss = bbody' }
            ]
#else
        L _ [ L _ Match { m_pats = [L _ (ConPatIn (L _ acon) aargs)], m_grhss = abody' }
            , L _ Match { m_pats = [L _ (ConPatIn (L _ bcon) bargs)], m_grhss = bbody' }
            ]
#endif
            | [aarg] <- hsConPatArgs aargs
            , [barg] <- hsConPatArgs bargs
            , Just abody <- simpleGRHSs abody'
            , Just bbody <- simpleGRHSs bbody'
            , Just ((larg, lbody), (rarg, rbody))
                <- leftThenRight acon bcon (aarg, abody) (barg, bbody)
            -> do
                expr' <- parseExpr names ctx expr
                SomePattern lpat <- parsePat larg
                SomePattern rpat <- parsePat rarg
                lcont <- parseCmd names (combineMaps ctx lpat) lbody
                rcont <- parseCmd names (combineMaps ctx rpat) rbody
                return $ caseCont expr' lpat rpat (second assoc lcont) (second assoc rcont)

        L l _ -> Error $ \dflags ->
            putError dflags l $ GHC.text "Overloaded:Categories only case of Left and Right are supported"
                GHC.$$ GHC.text (SYB.gshow (mg_alts matchGroup))
  where
    -- Order two alternatives so that the Left one comes first.  'Nothing'
    -- when the constructors are not exactly one Left and one Right.
    leftThenRight conA conB altA altB
        | conA == conLeftName names, conB == conRightName names = Just (altA, altB)
        | conA == conRightName names, conB == conLeftName names = Just (altB, altA)
        | otherwise                                             = Nothing
parseCmd _ _ (L l cmd) =
    Error $ \dflags ->
        putError dflags l $ GHC.text "Unsupported command in proc for Overloaded:Categories"
            GHC.$$ GHC.ppr cmd
            GHC.$$ GHC.text (SYB.gshow cmd)
-- | Extract the body of a right-hand side that consists of exactly one
-- unguarded alternative and has no local bindings; 'Nothing' otherwise.
simpleGRHSs :: GRHSs GhcRn body -> Maybe body
simpleGRHSs grhss = case grhss of
    GRHSs _ [L _ (GRHS _ [] rhs)] (L _ (EmptyLocalBinds _)) -> Just rhs
    _                                                       -> Nothing
-- | Turn the arrow of a first-order application into a 'Morphism'.
--
-- A bare reference to 'catIdentityName' becomes 'MId'; every other
-- expression is embedded unchanged as an 'MTerm'.
parseTerm
    :: Names
    -> LHsExpr GhcRn
    -> Rewrite (Morphism (LHsExpr GhcRn))
parseTerm Names {catNames = CatNames {..}} term = return $ case term of
    L _ (HsVar _ (L _ name)) | name == catIdentityName -> MId
    _                                                  -> MTerm term
-- | Parse the statements of a @do@ command into a 'Continuation'.
--
-- A bind statement parses its command in the current context, then parses
-- the remaining statements with the pattern's variables added, and joins the
-- two with 'compCont'.  A final statement is parsed as a command.  Other
-- statements, and an empty statement list, are reported as 'Error's; the
-- 'SrcSpan' argument is used only to locate the empty-block error.
parseStmts
    :: Names
    -> Map GHC.Name b
    -> SrcSpan
    -> [CmdLStmt GhcRn]
    -> Rewrite (Continuation (LHsExpr GhcRn) (Var b a))
parseStmts names ctx loc stmts = case stmts of
    L here (BindStmt _ pat body _ _) : rest -> do
        SomePattern bound <- parsePat pat
        headCont <- parseCmd names ctx body
        -- The statement's own span becomes the location reported if
        -- nothing follows it.
        tailCont <- parseStmts names (combineMaps ctx bound) here rest
        return $ compCont (nameToString <$> bound) headCont (second assoc tailCont)

    [L _ (LastStmt _ body _ _)] ->
        parseCmd names ctx body

    L here stmt : _ ->
        Error $ \dflags ->
            putError dflags here $ GHC.text "Unsupported statement in proc-do for Overloaded:Categories"
                GHC.$$ GHC.ppr stmt
                GHC.$$ GHC.text (SYB.gshow stmt)

    [] ->
        Error $ \dflags ->
            putError dflags loc $ GHC.text "Empty do block in proc"
-- | A variable that is one of two kinds.
--
-- In this module 'B' tags variables introduced by the most recently parsed
-- pattern (see 'combineMaps') and 'F' tags variables from the enclosing
-- context.  The derived 'Functor' maps over the 'F' payload.
data Var b a
    = B b  -- ^ variable of the first kind (innermost pattern, as used here)
    | F a  -- ^ variable of the second kind (enclosing context, as used here)
  deriving (Show, Functor)
-- | Map over both kinds of variable: the first function acts on 'B'
-- payloads, the second on 'F' payloads.
instance Bifunctor Var where
    bimap onB onF var = case var of
        B b -> B (onB b)
        F a -> F (onF a)
-- | Re-nest 'Var': @Var (Var a b) c@ and @Var a (Var b c)@ carry the same
-- three alternatives, and 'assoc' / 'unassoc' convert between them.
instance Assoc Var where
    assoc nested = case nested of
        B (B a) -> B a
        B (F b) -> F (B b)
        F c     -> F (F c)

    unassoc nested = case nested of
        B a     -> B (B a)
        F (B b) -> B (F b)
        F (F c) -> F c
-- | Eliminator for 'Var': apply the first function to a 'B' payload and the
-- second to an 'F' payload.
unvar :: (b -> c) -> (a -> c) -> Var b a -> c
unvar onB onF var = case var of
    B b -> onB b
    F a -> onF a
-- | A parsed @proc@: the binder pattern together with the body.
--
-- Variables in the body are either positions inside the binder pattern
-- ('B', an 'Index' into its shape) or outer variables of type @a@ ('F').
-- The pattern's shape is existentially hidden.
data Proc term a where
    Proc :: Pattern sh String -> Continuation term (Var (Index sh) a) -> Proc term a

deriving instance (Show a, Show term) => Show (Proc term a)
-- | Map over the embedded terms (first function) and over the outer
-- variables (second function); the binder pattern is left unchanged.
instance Bifunctor Proc where
    bimap onTerm onVar proc = case proc of
        Proc binder body -> Proc binder (bimap onTerm (fmap onVar) body)
-- | The body of a @proc@ in continuation form; @a@ is the type of variables
-- in scope.
--
-- Wherever an arrow appears it is @Either (Expression a) (Morphism term)@:
-- 'parseCmd' uses @Right@ for a first-order application (a term that does not
-- mention bound variables) and @Left@ for a higher-order application (an
-- expression over the bound variables).
data Continuation term a where
    -- | Final step: apply the arrow to the argument expression.
    Last :: Either (Expression a) (Morphism term) -> Expression a -> Continuation term a
    -- | Apply the arrow to the argument, bind the result with the pattern,
    -- and continue with the pattern's variables added to the scope.
    Edge
        :: Pattern sh String
        -> Either (Expression a) (Morphism term)
        -> Expression a
        -> Continuation term (Var (Index sh) a)
        -> Continuation term a
    -- | Branch on a scrutinee expression: two patterns, each with its own
    -- continuation (built by 'caseCont').
    Split
        :: Expression a
        -> Pattern shA String
        -> Pattern shB String
        -> Continuation term (Var (Index shA) a)
        -> Continuation term (Var (Index shB) a)
        -> Continuation term a

deriving instance (Show a, Show term) => Show (Continuation term a)
-- | Map over embedded terms (first function) and over variables (second
-- function) throughout a continuation.  Patterns are left unchanged; nested
-- continuations map the variable function under the extra 'Var' layer.
instance Bifunctor Continuation where
    bimap onTerm onVar cont = case cont of
        Last arrow arg ->
            Last (mapArrow arrow) (mapExpr arg)
        Edge pat arrow arg rest ->
            Edge pat (mapArrow arrow) (mapExpr arg)
                (bimap onTerm (fmap onVar) rest)
        Split scrut patA patB contA contB ->
            Split (mapExpr scrut) patA patB
                (bimap onTerm (fmap onVar) contA)
                (bimap onTerm (fmap onVar) contB)
      where
        -- Expressions only contain variables.
        mapExpr  = fmap onVar
        -- An arrow is either an expression or a morphism over terms.
        mapArrow = bimap mapExpr (fmap onTerm)
-- | Map over the variables of a continuation, leaving terms untouched.
instance Functor (Continuation term) where
    fmap onVar = bimap id onVar
-- | Sequence two continuations: bind the result of the first with the given
-- pattern and carry on with the second.
--
-- The final 'Last' of the first continuation becomes an 'Edge' leading into
-- the second.  While descending through the first continuation's own binders
-- the second is weakened ('weaken1') to skip each extra scope layer.
compCont
    :: Pattern sh String
    -> Continuation term a
    -> Continuation term (Var (Index sh) a)
    -> Continuation term a
compCont pat first next = case first of
    Last arrow arg ->
        Edge pat arrow arg next
    Edge pat' arrow arg rest ->
        Edge pat' arrow arg (compCont pat rest (weaken1 next))
    Split scrut patA patB contA contB ->
        Split scrut patA patB
            (compCont pat contA (weaken1 next))
            (compCont pat contB (weaken1 next))
-- | Insert one extra scope layer beneath the outermost one: 'B' variables
-- are kept as they are, 'F' variables are wrapped in one more 'F'.
weaken1 :: Functor f => f (Var a b) -> f (Var a (Var c b))
weaken1 = fmap $ \var -> case var of
    B b -> B b
    F a -> F (F a)
-- | Build a 'Split' from a scrutinee, the two branch patterns and the two
-- branch continuations, converting the names in the patterns with
-- 'nameToString'.
caseCont
    :: Expression a
    -> Pattern shA Plugins.Name
    -> Pattern shB Plugins.Name
    -> Continuation (LHsExpr GhcRn) (Var (Index shA) a)
    -> Continuation (LHsExpr GhcRn) (Var (Index shB) a)
    -> Continuation (LHsExpr GhcRn) a
caseCont scrut patA patB contA contB =
    Split scrut (nameToString <$> patA) (nameToString <$> patB) contA contB
-- | The shape of a pattern, used at the kind level to index 'Pattern' and
-- 'Index': a single position ('One') or a pair of sub-shapes ('Two').
data Shape = One | Two Shape Shape
-- | A pattern indexed by its 'Shape', with variable names of type @a@.
data Pattern :: Shape -> Type -> Type where
    -- | A variable binding.
    PatternVar :: a -> Pattern 'One a
    -- | A wildcard; occupies a position but binds no name.
    PatternWild :: Pattern 'One a
    -- | A pair of sub-patterns.
    PatternTuple :: Pattern l a -> Pattern r a -> Pattern ('Two l r) a

deriving instance Show a => Show (Pattern sh a)
deriving instance Functor (Pattern sh)
-- | A 'Pattern' whose shape is existentially hidden; this is the result type
-- of 'parsePat', where the shape is only known after parsing.
data SomePattern :: Type -> Type where
    SomePattern :: Pattern sh a -> SomePattern a
-- | A path to one position inside a pattern of the given 'Shape'.
data Index :: Shape -> Type where
    -- | The single position of a 'One' shape.
    Here :: Index 'One
    -- | A position inside the left component of a 'Two' shape.
    InL :: Index x -> Index ('Two x y)
    -- | A position inside the right component of a 'Two' shape.
    InR :: Index y -> Index ('Two x y)

deriving instance Show (Index sh)
-- | Map every name bound by a pattern to its position inside the pattern.
-- Wildcards contribute nothing.  If the same name occurs in both halves of a
-- tuple, the left occurrence wins ('Map.union' is left-biased).
patternMap :: Ord a => Pattern sh a -> Map a (Index sh)
patternMap pat = case pat of
    PatternVar name -> Map.singleton name Here
    PatternWild     -> Map.empty
    PatternTuple lhs rhs ->
        Map.map InL (patternMap lhs) `Map.union` Map.map InR (patternMap rhs)
-- | Extend a variable context with the names bound by a pattern.
--
-- Names bound by the pattern are tagged 'B' (with their position inside the
-- pattern); names already in the context are tagged 'F'.
--
-- 'Map.union' is left-biased, so the pattern's bindings are supplied first:
-- should a name occur in both maps, the newly bound (inner) occurrence takes
-- precedence over the outer one, as lexical scoping requires.
-- NOTE(review): renamed names are presumably unique, in which case the two
-- maps are disjoint and the order makes no difference — confirm.
combineMaps
    :: Map Plugins.Name b
    -> Pattern sh Plugins.Name
    -> Map Plugins.Name (Var (Index sh) b)
combineMaps m pat = Map.union (Map.map B (patternMap pat)) (Map.map F m)
-- | The expressions accepted by 'parseExpr' as arguments of arrow
-- applications, over variables of type @a@.
data Expression a
    = ExpressionVar a                                -- ^ a variable in scope
    | ExpressionUnit                                 -- ^ the unit value
    | ExpressionTuple (Expression a) (Expression a)  -- ^ a pair
    | ExpressionLeft (Expression a)                  -- ^ application of the Left constructor
    | ExpressionRight (Expression a)                 -- ^ application of the Right constructor
  deriving (Show, Functor)
-- | Point-free morphisms produced by desugaring.  Each constructor is
-- rendered by 'generate' using the corresponding field of 'CatNames'
-- (given in parentheses below).
data Morphism term
    = MId                                        -- ^ identity (@catIdentityName@)
    | MCompose (Morphism term) (Morphism term)   -- ^ composition (@catComposeName@, infix)
    | MProduct (Morphism term) (Morphism term)   -- ^ pairing of two morphisms (@catFanoutName@)
    | MTerminal                                  -- ^ morphism for the unit expression (@catTerminalName@)
    | MProj1                                     -- ^ first projection (@catProj1Name@)
    | MProj2                                     -- ^ second projection (@catProj2Name@)
    | MInL                                       -- ^ left injection (@catInlName@)
    | MInR                                       -- ^ right injection (@catInrName@)
    | MCase (Morphism term) (Morphism term)      -- ^ case analysis on two handlers (@catFaninName@)
    | MDistr                                     -- ^ used between 'MCase' and the scrutinee pairing (@catDistrName@)
    | MEval                                      -- ^ used for higher-order applications (@catEvalName@)
    | MTerm term                                 -- ^ an embedded term, rendered unchanged
  deriving (Show, Functor)
-- | Composition with on-the-fly simplification.  The rules are tried in
-- order: anything after 'MTerminal' is dropped, 'MId' is removed on either
-- side, a projection cancels against a product, a case cancels against an
-- injection; otherwise an 'MCompose' node is built.
instance Semigroup (Morphism term) where
    lhs <> rhs = case (lhs, rhs) of
        (MTerminal, _)          -> MTerminal
        (MId, m)                -> m
        (m, MId)                -> m
        (MProj1, MProduct f _)  -> f
        (MProj2, MProduct _ g)  -> g
        (MCase f _, MInL)       -> f
        (MCase _ g, MInR)       -> g
        _                       -> MCompose lhs rhs
-- | 'MId' is the unit of the simplifying composition.
instance Monoid (Morphism term) where
    mempty      = MId
    mappend f g = f <> g
-- | Desugar a whole 'Proc' into a 'Morphism'.
--
-- Variables bound by the binder pattern are resolved with 'desugarP';
-- outer variables are resolved with the supplied function.
desugar :: (a -> Morphism term) -> Proc term a -> Morphism term
desugar ctx proc = case proc of
    Proc binder body -> desugarC (unvar (desugarP binder) ctx) body
-- | Desugar a 'Continuation' into a 'Morphism', given how each variable in
-- scope is obtained from the input.
--
-- For 'Edge' and 'Split' the nested continuation receives a pair as input:
-- variables of the new pattern are reached through the first component and
-- variables of the enclosing scope through 'MProj2'.
desugarC :: (a -> Morphism term) -> Continuation term a -> Morphism term
desugarC ctx cont = case cont of
    Last (Right term) arg ->
        term <> desugarE ctx arg

    Last (Left fun) arg ->
        MEval <> MProduct (desugarE ctx fun) (desugarE ctx arg)

    Edge pat (Right term) arg rest ->
        desugarC (extend ctx (desugarP pat) MProj1) rest
            <> MProduct (term <> desugarE ctx arg) MId

    -- Higher-order application: the first component holds the unevaluated
    -- (function, argument) pair, so pattern variables go through 'MEval'.
    Edge pat (Left fun) arg rest ->
        desugarC (extend ctx (desugarP pat) (MEval <> MProj1)) rest
            <> MProduct (MProduct (desugarE ctx fun) (desugarE ctx arg)) MId

    Split scrut patA patB contA contB ->
        MCase
            (desugarC (extend ctx (desugarP patA) MProj1) contA)
            (desugarC (extend ctx (desugarP patB) MProj1) contB)
            <> MDistr
            <> MProduct (desugarE ctx scrut) MId
  where
    -- Variable lookup for a nested scope: pattern variables ('B') are their
    -- lookup followed by @path@; outer variables ('F') are the outer lookup
    -- followed by 'MProj2'.  Takes the outer lookup as an argument so that
    -- it can be used at several pattern shapes.
    extend
        :: (v -> Morphism t)
        -> (i -> Morphism t)
        -> Morphism t
        -> Var i v
        -> Morphism t
    extend outer bound path =
        unvar (\i -> bound i <> path) (\v -> outer v <> MProj2)
-- | The morphism that extracts the value at the given position from a value
-- matching the pattern: a chain of 'MProj1' / 'MProj2' following the path,
-- and 'MId' at a leaf.
desugarP :: Pattern sh name -> Index sh -> Morphism term
desugarP pat ix = case pat of
    PatternVar _ -> case ix of
        Here -> MId
    PatternWild -> case ix of
        Here -> MId
    PatternTuple lhs rhs -> case ix of
        InL i -> desugarP lhs i <> MProj1
        InR j -> desugarP rhs j <> MProj2
-- | Desugar an 'Expression' into a 'Morphism', resolving variables with the
-- supplied function.
desugarE :: (a -> Morphism term) -> Expression a -> Morphism term
desugarE ctx expr = case expr of
    ExpressionUnit      -> MTerminal
    ExpressionVar v     -> ctx v
    ExpressionTuple l r -> MProduct (desugarE ctx l) (desugarE ctx r)
    ExpressionLeft e    -> MInL <> desugarE ctx e
    ExpressionRight e   -> MInR <> desugarE ctx e
-- | Render a 'Morphism' as an expression, using the names in 'CatNames' for
-- the combinators.  Embedded terms are emitted unchanged; all generated
-- nodes carry 'noSrcSpan'.
generate :: Names -> Morphism (LHsExpr GhcRn) -> LHsExpr GhcRn
generate Names {catNames = CatNames {..}} = render
  where
    -- Reference to a combinator by name.
    var name = hsVar noSrcSpan name

    -- Parenthesised prefix application of a combinator to two morphisms.
    apply2 name f g =
        hsPar noSrcSpan $ hsApps noSrcSpan (var name) [render f, render g]

    render morphism = case morphism of
        MId          -> var catIdentityName
        MCompose f g -> hsPar noSrcSpan $
            hsOpApp noSrcSpan (render f) (var catComposeName) (render g)
        MTerm term   -> term
        MTerminal    -> var catTerminalName
        MProj1       -> var catProj1Name
        MProj2       -> var catProj2Name
        MProduct f g -> apply2 catFanoutName f g
        MInL         -> var catInlName
        MInR         -> var catInrName
        MDistr       -> var catDistrName
        MEval        -> var catEvalName
        MCase f g    -> apply2 catFaninName f g