{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Binding.Hobbits.NuMatching (
NuMatching(..), mkNuMatching,
MbTypeRepr(), isoMbTypeRepr, unsafeMbTypeRepr,
NuMatchingAny1(..)
) where
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import Language.Haskell.TH hiding (Name, Type(..))
import qualified Language.Haskell.TH as TH
import Control.Monad.State
import Numeric.Natural
import Data.Functor.Constant
import Data.Kind as DK
import Data.Word
import Data.Proxy
import Data.Type.Equality
import Data.Type.RList hiding (map)
import Data.Binding.Hobbits.Internal.Name
import Data.Binding.Hobbits.Internal.Mb
import Data.Binding.Hobbits.Internal.Closed
-- | Apply a 'NameRefresher' to a value by handing the 'MbTypeRepr' from the
-- value's 'NuMatching' instance to 'mapNamesPf'.
mapNames :: NuMatching a => NameRefresher -> a -> a
mapNames refresher value = mapNamesPf nuMatchingProof refresher value
-- | Split a @data@ or @newtype@ declaration into its context, type name,
-- type-variable binders and constructors. A @newtype@'s single constructor
-- is returned as a one-element list. Any other declaration gives 'Nothing'.
matchDataDecl :: Dec -> Maybe (Cxt, TH.Name, [TyVarBndr], [Con])
matchDataDecl dec =
  case dec of
    DataD context tyName binders _ cons _ ->
      Just (context, tyName, binders, cons)
    NewtypeD context tyName binders _ con _ ->
      Just (context, tyName, binders, [con])
    _ -> Nothing
-- | Build an instance declaration with no overlap pragma.
mkInstanceD :: Cxt -> TH.Type -> [Dec] -> Dec
mkInstanceD context headType decs = InstanceD Nothing context headType decs
-- | Types whose values can be passed through a 'NameRefresher'.
--
-- An instance supplies an 'MbTypeRepr' for the type. 'mapNames' hands it to
-- 'mapNamesPf' to refresh a value. Instances are written by hand in this
-- module or generated with 'mkNuMatching'.
class NuMatching a where
  nuMatchingProof :: MbTypeRepr a
-- | Build an 'MbTypeRepr' for @a@ from conversions to and from a type @b@
-- that already has a 'NuMatching' instance. Refreshing converts the value to
-- @b@, applies 'mapNames' there, and converts the result back to @a@.
isoMbTypeRepr :: NuMatching b => (a -> b) -> (b -> a) -> MbTypeRepr a
isoMbTypeRepr toB fromB = MbTypeReprData (MkMbTypeReprData refresh)
  where
    refresh refresher = fromB . mapNames refresher . toB
-- | An 'MbTypeRepr' that ignores the 'NameRefresher' and returns every value
-- unchanged. Any names inside a value are therefore left untouched, so this
-- is only sound for types that contain no names (presumably the reason for
-- the "unsafe" prefix).
unsafeMbTypeRepr :: MbTypeRepr a
unsafeMbTypeRepr = MbTypeReprData (MkMbTypeReprData (const id))
-- | Names use the dedicated 'MbTypeReprName' representation.
instance NuMatching (Name a) where
  nuMatchingProof = MbTypeReprName

-- | 'Closed' values ignore the 'NameRefresher' and are returned unchanged.
instance NuMatching (Closed a) where
  nuMatchingProof = MbTypeReprData (MkMbTypeReprData (const id))

-- | Functions: combine the representations of the argument and result types.
instance (NuMatching a, NuMatching b) => NuMatching (a -> b) where
  nuMatchingProof = MbTypeReprFun nuMatchingProof nuMatchingProof

-- | Name-bindings: wrap the representation of the bound body.
instance NuMatching a => NuMatching (Mb ctx a) where
  nuMatchingProof = MbTypeReprMb nuMatchingProof
-- Values of these base types contain no 'Name's, so refreshing them is the
-- identity. Each instance reuses 'unsafeMbTypeRepr', which ignores the
-- 'NameRefresher' and returns its argument unchanged.
instance NuMatching Bool where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Int where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Integer where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Char where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Natural where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Float where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Double where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Word where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Word8 where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Word16 where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Word32 where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching Word64 where
  nuMatchingProof = unsafeMbTypeRepr

instance NuMatching () where
  nuMatchingProof = unsafeMbTypeRepr
-- | Pairs: refresh each component independently with 'mapNames'.
instance (NuMatching a, NuMatching b) => NuMatching (a,b) where
  nuMatchingProof = MbTypeReprData (MkMbTypeReprData refreshPair)
    where
      refreshPair r (x, y) = (mapNames r x, mapNames r y)

-- | Triples: refresh each component independently with 'mapNames'.
instance (NuMatching a, NuMatching b, NuMatching c) => NuMatching (a,b,c) where
  nuMatchingProof = MbTypeReprData (MkMbTypeReprData refreshTriple)
    where
      refreshTriple r (x, y, z) = (mapNames r x, mapNames r y, mapNames r z)

-- | Quadruples: refresh each component independently with 'mapNames'.
instance (NuMatching a, NuMatching b,
          NuMatching c, NuMatching d) => NuMatching (a,b,c,d) where
  nuMatchingProof = MbTypeReprData (MkMbTypeReprData refreshQuad)
    where
      refreshQuad r (w, x, y, z) =
        (mapNames r w, mapNames r x, mapNames r y, mapNames r z)
-- | Lists: refresh every element with 'mapNames'.
instance NuMatching a => NuMatching [a] where
  nuMatchingProof = MbTypeReprData (MkMbTypeReprData (map . mapNames))

-- | Vectors: refresh every element with 'mapNames'.
instance NuMatching a => NuMatching (Vector a) where
  nuMatchingProof = MbTypeReprData (MkMbTypeReprData (Vector.map . mapNames))
-- | Type constructors @f@ for which @f a@ has an 'MbTypeRepr' for every
-- argument @a@. The class is kind-polymorphic in the argument of @f@.
class NuMatchingAny1 (f :: k -> Type) where
  nuMatchingAny1Proof :: MbTypeRepr (f a)

-- | Any @f a@ with @NuMatchingAny1 f@ is 'NuMatching'. The head @f a@ also
-- matches many other 'NuMatching' instances in this module (e.g. @Name a@,
-- @[a]@, @Mb ctx a@), and the instance is marked INCOHERENT.
instance {-# INCOHERENT #-} NuMatchingAny1 f => NuMatching (f a) where
  nuMatchingProof = nuMatchingAny1Proof

-- | Reuses the 'NuMatching' instance for @Name a@.
instance NuMatchingAny1 Name where
  nuMatchingAny1Proof = nuMatchingProof

-- | Reuses a 'NuMatching' instance for @a :~: b@.
-- NOTE(review): no such instance is visible in this chunk. If the only
-- candidate were the INCOHERENT @NuMatching (f a)@ instance above, this
-- would resolve back to itself and loop. Confirm where it is defined.
instance NuMatchingAny1 ((:~:) a) where
  nuMatchingAny1Proof = nuMatchingProof

-- | Reuses a 'NuMatching' instance for @Constant a b@.
-- NOTE(review): as with @(:~:)@, that instance is not visible here; confirm.
instance NuMatching a => NuMatchingAny1 (Constant a) where
  nuMatchingAny1Proof = nuMatchingProof
-- | Refresh every element of an 'RAssign' with 'mapNames'. Each element's
-- 'NuMatching' instance is resolved through the element constructor's
-- 'NuMatchingAny1' instance. Marked OVERLAPPABLE because the head
-- @RAssign f ctx@ is also matched by the INCOHERENT @NuMatching (f a)@
-- instance.
instance {-# OVERLAPPABLE #-} NuMatchingAny1 f => NuMatching (RAssign f ctx) where
  nuMatchingProof = MbTypeReprData (MkMbTypeReprData refreshAll)
    where
      -- Peel the last element off with ':>:', refresh the rest recursively,
      -- and stop at 'MNil'.
      refreshAll :: NuMatchingAny1 f => NameRefresher -> RAssign f args ->
                    RAssign f args
      refreshAll _ MNil = MNil
      refreshAll refresher (rest :>: x) =
        refreshAll refresher rest :>: mapNames refresher x
-- | The infinite list @[i, i+1, i+2, ...]@. The Template Haskell generators
-- below use it to number constructor fields (e.g. @x0@, @x1@, ...).
natsFrom :: Num a => a -> [a]
natsFrom = iterate (+ 1)
-- | First component of a triple.
fst3 :: (a,b,c) -> a
fst3 triple = case triple of (first, _, _) -> first

-- | Second component of a triple.
snd3 :: (a,b,c) -> b
snd3 triple = case triple of (_, second, _) -> second

-- | Third component of a triple.
thd3 :: (a,b,c) -> c
thd3 triple = case triple of (_, _, third) -> third
-- | Names threaded through the clause generators of 'mkNuMatching', in order:
-- the type constructor being processed, the generated local refreshing
-- function, and the variable bound to that function's 'NameRefresher'
-- argument.
type Names = (TH.Name, TH.Name, TH.Name)
-- | Quote the type @NameRefresher -> a -> a@ for a quoted type @a@. This is
-- the type of the local refreshing function generated by 'mkNuMatching'.
-- The @forall ctx.@ binds a variable that the body does not mention.
mapNamesType :: Q TH.Type -> Q TH.Type
mapNamesType a = [t| forall ctx. NameRefresher -> $a -> $a |]
-- | Template Haskell generator for 'NuMatching' instances of (G)ADTs.
--
-- The argument must quote a type constructor applied to zero or more
-- /distinct/ type variables, optionally under a @forall@ with a context:
--
-- > $(mkNuMatching [t| forall a. NuMatching a => T a |])
--
-- The quoted context becomes the instance context. 'nuMatchingProof' is
-- defined as @MbTypeReprData (MkMbTypeReprData f)@, where @f@ is a locally
-- generated @NameRefresher -> T tvs -> T tvs@ with one clause per data
-- constructor. Each clause rebuilds its constructor and refreshes each field:
--
-- * a field whose type is headed by @T@ is refreshed by @f@ itself;
-- * any other field is refreshed by 'mapNames', so its type needs a
--   'NuMatching' instance.
--
-- Fails in 'Q' if the quoted type does not have this shape, or if its head
-- does not reify to a @data@ or @newtype@ declaration.
mkNuMatching :: Q TH.Type -> Q [Dec]
mkNuMatching tQ =
  do t <- tQ
     (cxt, cType, tName, constrs, tyvars) <- getMbTypeReprInfoTop t
     fName <- newName "f"
     refrName <- newName "refresher"
     clauses <- getClauses (tName, fName, refrName) constrs
     mapNamesT <- mapNamesType (return cType)
     -- instance cxt => NuMatching cType where
     --   nuMatchingProof =
     --     MbTypeReprData (MkMbTypeReprData (let f :: ...; f ... = ... in f))
     return [mkInstanceD
               cxt (TH.AppT (TH.ConT ''NuMatching) cType)
               [ValD (VarP 'nuMatchingProof)
                  (NormalB
                     $ AppE (ConE 'MbTypeReprData)
                     $ AppE (ConE 'MkMbTypeReprData)
                     $ LetE [SigD fName
                               $ TH.ForallT (map PlainTV tyvars) cxt mapNamesT,
                             FunD fName clauses]
                       (VarE fName))
                  []]]
  where
    -- Name bound by a type-variable binder, ignoring any kind annotation.
    tyBndrToName (PlainTV n) = n
    tyBndrToName (KindedTV n _) = n

    -- Reject a quoted type that is not of the supported shape.
    getMbTypeReprInfoFail t extraMsg =
      fail ("mkNuMatching: " ++ show t
            ++ " is not a type constructor for a (G)ADT applied to zero or more distinct type variables" ++ extraMsg)

    -- Start the walk with an empty context and no type variables.
    getMbTypeReprInfoTop t = getMbTypeReprInfo [] [] t t

    -- Walk the quoted type, collecting the context (from foralls) and the
    -- applied type variables, then reify the head constructor. Returns
    -- (context, instance head type, constructor name, data constructors,
    -- type variables to quantify in the local function's signature).
    getMbTypeReprInfo ctx tyvars topT (TH.ConT tName) =
      do info <- reify tName
         case info of
           TyConI (matchDataDecl -> Just (_, _, tyvarsReq, constrs)) ->
             success tyvarsReq constrs
           _ -> getMbTypeReprInfoFail topT (": info for " ++ (show tName) ++ " = " ++ (show info))
      where
        success tyvarsReq constrs =
          -- With neither type variables nor a context in the quote, quantify
          -- over the declaration's own binders instead.
          -- NOTE(review): the instance head below still applies @tyvars@,
          -- so in that case it is the bare type constructor.
          let tyvarsRet = if null tyvars && null ctx
                            then map tyBndrToName tyvarsReq
                            else tyvars
          in return (ctx,
                     foldl TH.AppT (TH.ConT tName) (map TH.VarT tyvars),
                     tName, constrs, tyvarsRet)
    getMbTypeReprInfo ctx tyvars topT (TH.AppT f (TH.VarT argName)) =
      -- Arguments are peeled right to left and consed on, so tyvars ends up
      -- in application order. A repeated variable is rejected.
      if elem argName tyvars then
        getMbTypeReprInfoFail topT ""
      else
        getMbTypeReprInfo ctx (argName:tyvars) topT f
    getMbTypeReprInfo ctx tyvars topT (TH.ForallT _ ctx' t) =
      getMbTypeReprInfo (ctx ++ ctx') tyvars topT t
    getMbTypeReprInfo _ _ topT _ = getMbTypeReprInfoFail topT ""

    -- If a type is a constructor applied to type variables (kind signatures
    -- on the head are looked through), return the whole type, the
    -- constructor's name and the collected variables.
    getTCtor t = getTCtorHelper t t []
    getTCtorHelper (TH.ConT tName) topT tyvars = Just (topT, tName, tyvars)
    getTCtorHelper (TH.AppT t1 (TH.VarT var)) topT tyvars =
      getTCtorHelper t1 topT (tyvars ++ [var])
    getTCtorHelper (TH.SigT t1 _) topT tyvars = getTCtorHelper t1 topT tyvars
    getTCtorHelper _ _ _ = Nothing

    -- One clause of the local function per data constructor. GADT
    -- constructors declared together yield one clause each, and 'ForallC'
    -- wrappers are stripped.
    getClauses :: Names -> [Con] -> Q [Clause]
    getClauses _ [] = return []
    getClauses names (NormalC cName cTypes : constrs) =
      do clause <-
           getClauseHelper names (map snd cTypes) (natsFrom 0)
             (\l -> ConP cName (map (VarP . fst3) l))
             (\l -> foldl AppE (ConE cName) (map fst3 l))
         clauses <- getClauses names constrs
         return $ clause : clauses
    getClauses names (RecC cName cVarTypes : constrs) =
      do clause <-
           getClauseHelper names (map thd3 cVarTypes) (map fst3 cVarTypes)
             (\l -> RecP cName (map (\(var,_,field) -> (field, VarP var)) l))
             (\l -> RecConE cName (map (\(e,_,field) -> (field, e)) l))
         clauses <- getClauses names constrs
         return $ clause : clauses
    getClauses names (InfixC cType1 cName cType2 : constrs) =
      do clause <-
           getClauseHelper names (map snd [cType1, cType2]) (natsFrom 0)
             (\l -> ConP cName (map (VarP . fst3) l))
             (\l -> foldl AppE (ConE cName) (map fst3 l))
         clauses <- getClauses names constrs
         return $ clause : clauses
    getClauses names (GadtC cNames cTypes _ : constrs) =
      do clauses1 <-
           forM cNames $ \cName ->
             getClauseHelper names (map snd cTypes) (natsFrom 0)
               (\l -> ConP cName (map (VarP . fst3) l))
               (\l -> foldl AppE (ConE cName) (map fst3 l))
         clauses2 <- getClauses names constrs
         return (clauses1 ++ clauses2)
    getClauses names (RecGadtC cNames cVarTypes _ : constrs) =
      do clauses1 <-
           forM cNames $ \cName ->
             getClauseHelper names (map thd3 cVarTypes) (map fst3 cVarTypes)
               (\l -> RecP cName (map (\(var,_,field) -> (field, VarP var)) l))
               (\l -> RecConE cName (map (\(e,_,field) -> (field, e)) l))
         clauses2 <- getClauses names constrs
         return (clauses1 ++ clauses2)
    getClauses names (ForallC _ _ constr : constrs) =
      getClauses names (constr : constrs)

    -- Build one clause: bind the refresher and fresh variables x0, x1, ...
    -- for the fields via pFun, refresh each field, and rebuild via eFun.
    getClauseHelper :: Names -> [TH.Type] -> [a] ->
                       ([(TH.Name,TH.Type,a)] -> Pat) ->
                       ([(Exp,TH.Type,a)] -> Exp) ->
                       Q Clause
    getClauseHelper names@(_, _, refrName) cTypes cData pFun eFun =
      do varNames <- mapM (newName . ("x" ++) . show . fst)
                       $ zip (natsFrom 0) cTypes
         let varsTypesData = zip3 varNames cTypes cData
         let expsTypesData = map (mkExpTypeData names) varsTypesData
         return $ Clause [(VarP refrName), (pFun varsTypesData)]
           (NormalB $ eFun expsTypesData) []

    -- Refresh one field. A field whose type is headed by the constructor
    -- being processed calls the local function recursively; any other field
    -- calls 'mapNames'.
    mkExpTypeData :: Names -> (TH.Name,TH.Type,a) -> (Exp,TH.Type,a)
    mkExpTypeData (tName, fName, refrName)
                  (varName, getTCtor -> Just (t, tName', _), cData)
      | tName == tName' =
        (foldl AppE (VarE fName)
           [(VarE refrName), (VarE varName)],
         t, cData)
    mkExpTypeData (_, _, refrName) (varName, t, cData) =
      (foldl AppE (VarE 'mapNames)
         [(VarE refrName), (VarE varName)],
       t, cData)
-- | 'StateT' over 'Q' whose state is a 'Cxt'. 'mkMkMbTypeReprDataOld' runs
-- its clause generators in this monad, starting from @[]@, and uses the
-- final state as the context of the generated function's signature.
type CxtStateQ a = StateT Cxt Q a
-- | Legacy generator, superseded by 'mkNuMatching'. It is not exported from
-- this module and not referenced in the visible source.
--
-- It reifies the named type constructor, which must be a @data@ or
-- @newtype@, and returns
-- @let f :: forall tvs. reqCxt => T tvs -> T tvs; f ... = ... in f@.
-- Each clause rebuilds a constructor from its own fields.
--
-- @reqCxt@ would be accumulated by 'ensureCxt1', which is not implemented.
-- Any constructor with at least one field therefore makes generation fail
-- with an explicit message.
mkMkMbTypeReprDataOld :: Q TH.Name -> Q Exp
mkMkMbTypeReprDataOld conNameQ =
  do conName <- conNameQ
     (cxt, name, tyvars, constrs) <- getMbTypeReprInfo conName
     (clauses, reqCxt) <- runStateT (getClauses cxt name tyvars [] constrs) []
     fname <- newName "f"
     -- The type being processed, applied to its declared type variables.
     let selfType = foldl TH.AppT (TH.ConT conName) (map tyVarToType tyvars)
     -- f :: forall tvs. reqCxt => selfType -> selfType
     -- ArrowT must be applied to both the argument and the result type.
     return (LetE
               [SigD fname
                  (TH.ForallT tyvars reqCxt
                     $ foldl TH.AppT TH.ArrowT [selfType, selfType]),
                FunD fname clauses]
               (VarE fname))
  where
    -- A type-variable binder as a type, ignoring any kind annotation.
    tyVarToType (PlainTV n) = TH.VarT n
    tyVarToType (KindedTV n _) = TH.VarT n

    -- Reify the name and require a data or newtype declaration.
    getMbTypeReprInfo conName =
      reify conName >>= \info ->
        case info of
          TyConI (matchDataDecl -> Just (cxt, name, tyvars, constrs)) ->
            return (cxt, name, tyvars, constrs)
          _ -> fail ("mkMkMbTypeReprDataOld: " ++ show conName
                     ++ " is not a (G)ADT")

    -- One clause per data constructor. For 'ForallC', its context and
    -- binders are added before processing the wrapped constructor.
    getClauses :: Cxt -> TH.Name -> [TyVarBndr] -> [TyVarBndr] -> [Con] ->
                  CxtStateQ [Clause]
    getClauses _ _ _ _ [] = return []
    getClauses cxt name tyvars locTyvars (NormalC cName cTypes : constrs) =
      do clause <-
           getClauseHelper cxt name tyvars locTyvars (map snd cTypes)
             (natsFrom 0)
             (\l -> ConP cName (map (VarP . fst3) l))
             (\l -> foldl AppE (ConE cName) (map (VarE . fst3) l))
         clauses <- getClauses cxt name tyvars locTyvars constrs
         return (clause : clauses)
    getClauses cxt name tyvars locTyvars (RecC cName cVarTypes : constrs) =
      do clause <-
           getClauseHelper cxt name tyvars locTyvars (map thd3 cVarTypes)
             (map fst3 cVarTypes)
             (\l -> RecP cName (map (\(var,_,field) -> (field, VarP var)) l))
             (\l -> RecConE cName (map (\(var,_,field) -> (field, VarE var)) l))
         clauses <- getClauses cxt name tyvars locTyvars constrs
         return (clause : clauses)
    getClauses cxt name tyvars locTyvars (InfixC cType1 cName cType2 : constrs) =
      -- An infix constructor is handled like a two-field prefix one.
      do clause <-
           getClauseHelper cxt name tyvars locTyvars (map snd [cType1, cType2])
             (natsFrom 0)
             (\l -> ConP cName (map (VarP . fst3) l))
             (\l -> foldl AppE (ConE cName) (map (VarE . fst3) l))
         clauses <- getClauses cxt name tyvars locTyvars constrs
         return (clause : clauses)
    getClauses cxt name tyvars locTyvars (ForallC tyvars2 cxt2 constr
                                          : constrs) =
      do clauses1 <-
           getClauses (cxt ++ cxt2) name tyvars (locTyvars ++ tyvars2) [constr]
         clauses2 <- getClauses cxt name tyvars locTyvars constrs
         return (clauses1 ++ clauses2)
    getClauses cxt name tyvars locTyvars (GadtC cNames cTypes _ : constrs) =
      do clauses1 <-
           forM cNames $ \cName ->
             getClauseHelper cxt name tyvars locTyvars (map snd cTypes)
               (natsFrom 0) (\l -> ConP cName (map (VarP . fst3) l))
               (\l -> foldl AppE (ConE cName) (map (VarE . fst3) l))
         clauses2 <- getClauses cxt name tyvars locTyvars constrs
         return (clauses1 ++ clauses2)
    getClauses cxt name tyvars locTyvars (RecGadtC cNames cVarTypes _
                                          : constrs) =
      do clauses1 <-
           forM cNames $ \cName ->
             getClauseHelper cxt name tyvars locTyvars
               (map thd3 cVarTypes) (map fst3 cVarTypes)
               (\l -> RecP cName (map (\(var,_,field) -> (field, VarP var)) l))
               (\l -> RecConE cName (map (\(var,_,field) -> (field, VarE var)) l))
         clauses2 <- getClauses cxt name tyvars locTyvars constrs
         return (clauses1 ++ clauses2)

    -- Build one clause: fresh variables x0, x1, ... for the fields, record
    -- any required context, then match with pFun and rebuild with eFun.
    getClauseHelper :: Cxt -> TH.Name -> [TyVarBndr] -> [TyVarBndr] ->
                       [TH.Type] -> [a] ->
                       ([(TH.Name,TH.Type,a)] -> Pat) ->
                       ([(TH.Name,TH.Type,a)] -> Exp) ->
                       CxtStateQ Clause
    getClauseHelper cxt _ _ locTyvars cTypes cData pFun eFun =
      do varNames <- mapM (lift . newName . ("x" ++) . show . fst)
                       $ zip (natsFrom 0) cTypes
         ensureCxt cxt locTyvars cTypes
         let varsTypesData = zip3 varNames cTypes cData
         return $ Clause [(pFun varsTypesData)]
           (NormalB $ eFun varsTypesData) []

    -- Run 'ensureCxt1' on each field type, in order.
    ensureCxt :: Cxt -> [TyVarBndr] -> [TH.Type] -> CxtStateQ ()
    ensureCxt cxt locTyvars = mapM_ (ensureCxt1 cxt locTyvars)

    -- Not implemented. Fail in 'Q' with a clear message rather than
    -- crashing on 'undefined'.
    ensureCxt1 :: Cxt -> [TyVarBndr] -> TH.Type -> CxtStateQ ()
    ensureCxt1 _ _ _ =
      lift (fail "mkMkMbTypeReprDataOld: ensureCxt1 is not implemented")