module Top.Types.Substitution where
import Top.Types.Primitive
import Data.List (union, (\\), nub)
import qualified Data.Map as M
import qualified Data.Set as S
import Utils (internalError)
-- | Fixity of the substitution-application operator '|->': precedence 4,
-- non-associative, so two uses of '|->' cannot be chained without
-- explicit parentheses.
infix 4 |->
-- | Operations every substitution representation supports. A
-- substitution maps type-variable numbers ('Int') to types ('Tp').
class Substitution s where
  -- | The type a variable is mapped to.
  lookupInt :: Int -> s -> Tp
  -- | 'removeDom' drops the bindings for the given variables;
  -- 'restrictDom' keeps only the bindings for the given variables.
  removeDom, restrictDom :: [Int] -> s -> s
  -- | The variables that have a binding.
  dom :: s -> [Int]
  -- | The types that variables are bound to.
  cod :: s -> Tps
-- | Values that contain types: a substitution can be applied to them,
-- and their free type variables can be listed.
class Substitutable a where
  -- | The free type variables occurring in the value.
  ftv :: a -> [Int]
  -- | Apply a substitution to every type inside the value.
  (|->) :: Substitution s => s -> a -> a
-- | One more than the largest free type variable of the argument, or 0
-- when the argument has no free type variables. Useful as the first
-- fresh variable number.
nextFTV :: Substitutable a => a -> Int
nextFTV a
  | null vars = 0
  | otherwise = maximum vars + 1
  where
    vars = ftv a
-- | A substitution stored as a finite map from type-variable numbers to
-- types. A variable with no entry is mapped to itself ('TVar').
type MapSubstitution = M.Map Int Tp

instance Substitution MapSubstitution where
  lookupInt i m = maybe (TVar i) id (M.lookup i m)

  -- Delete each requested variable in turn.
  removeDom vars m = foldr M.delete m vars

  -- Keep only the keys listed in @vars@; build a set once for fast
  -- membership tests.
  restrictDom vars m = M.filterWithKey (\k _ -> k `S.member` wanted) m
    where
      wanted = S.fromList vars

  dom m = M.keys m
  cod m = M.elems m
-- | The identity substitution: it has no bindings.
emptySubst :: MapSubstitution
emptySubst = M.fromList []
-- | Composition of substitutions. The result:
--
--   * applies @fm1@ to every type bound in @fm2@, and
--   * adds the bindings of @fm1@.
--
-- 'M.union' is left-biased. For a variable bound by both arguments,
-- @fm1@'s binding is kept and @fm2@'s (substituted) binding is dropped.
-- When the two domains are disjoint, the result behaves like applying
-- @fm2@ first and then @fm1@. On overlapping domains it may not.
-- NOTE(review): confirm callers never compose substitutions with
-- overlapping domains, or that the left bias is intended.
(@@) :: MapSubstitution -> MapSubstitution -> MapSubstitution
fm1 @@ fm2 = M.union fm1 substituted
  where
    substituted = M.map (fm1 |->) fm2
-- | Plain left-biased union of two substitutions. Unlike '@@', neither
-- argument is applied to the other's types. For a variable bound by
-- both, the first argument's binding wins.
(@@@) :: MapSubstitution -> MapSubstitution -> MapSubstitution
fm1 @@@ fm2 = M.union fm1 fm2
-- | A substitution that binds exactly one type variable.
singleSubstitution :: Int -> Tp -> MapSubstitution
singleSubstitution var tp = M.singleton var tp
-- | Build a substitution from (variable, type) pairs. If a variable
-- occurs more than once, its last pair wins ('M.fromList' semantics).
listToSubstitution :: [(Int,Tp)] -> MapSubstitution
listToSubstitution pairs = M.fromList pairs
-- | A substitution whose bindings are chased transitively. Looking up a
-- variable re-applies the whole substitution to the type it is bound to.
--
-- Only a direct self-binding (@i@ bound to @TVar i@) is special-cased.
-- Any other cycle through the bindings makes 'lookupInt' diverge, or
-- produce an infinite type when the cycle passes through a 'TApp'.
newtype FixpointSubstitution = FixpointSubstitution (M.Map Int Tp)

instance Substitution FixpointSubstitution where
  lookupInt i original@(FixpointSubstitution fm) =
    case M.lookup i fm of
      -- A self-binding would otherwise recurse forever; treat it as
      -- "unbound".
      Just tp
        | tp == TVar i -> TVar i
        | otherwise    -> original |-> tp
      Nothing -> TVar i

  -- The domain operations only touch the underlying map, so they reuse
  -- the 'MapSubstitution' instance. That instance deletes keys directly
  -- or tests membership in a set, instead of a linear 'notElem' scan per
  -- key. The resulting bindings are the same as before.
  removeDom vars (FixpointSubstitution fm) =
    FixpointSubstitution (removeDom vars fm)

  restrictDom vars (FixpointSubstitution fm) =
    FixpointSubstitution (restrictDom vars fm)

  dom (FixpointSubstitution fm) = M.keys fm
  cod (FixpointSubstitution fm) = M.elems fm
-- | The fixpoint substitution with no bindings.
emptyFPS :: FixpointSubstitution
emptyFPS = FixpointSubstitution (M.fromList [])
-- | Union of two fixpoint substitutions whose domains are expected to be
-- disjoint.
--
-- A variable bound by both arguments is reported via 'internalError'.
-- Because 'Data.Map' is lazy in its values, the combining function runs
-- only when the clashing variable's binding is actually demanded.
-- NOTE(review): overlaps whose bindings are never inspected (e.g. only
-- 'dom' is used) go unreported.
disjointFPS :: FixpointSubstitution -> FixpointSubstitution -> FixpointSubstitution
disjointFPS (FixpointSubstitution fm1) (FixpointSubstitution fm2) =
  FixpointSubstitution (M.unionWith clash fm1 fm2)
  where
    clash = internalError "Substitution" "disjointFPS" "the two fixpoint substitutions are not disjoint"
-- | Hide a concrete substitution inside a 'WrappedSubstitution'. The
-- wrapper stores the value together with its 'Substitution' methods
-- (lookupInt, removeDom, restrictDom, dom, cod) taken from its instance.
wrapSubstitution :: Substitution substitution => substitution -> WrappedSubstitution
wrapSubstitution s = WrappedSubstitution s (lookupInt, removeDom, restrictDom, dom, cod)
-- | A substitution of any representation, with its type hidden
-- existentially. Alongside the value it carries a tuple of the five
-- 'Substitution' operations for that representation.
data WrappedSubstitution
  = forall a . Substitution a
  => WrappedSubstitution
       a
       ( Int -> a -> Tp    -- lookupInt
       , [Int] -> a -> a   -- removeDom
       , [Int] -> a -> a   -- restrictDom
       , a -> [Int]        -- dom
       , a -> Tps          -- cod
       )
-- | Each method calls the matching function stored in the wrapper's
-- tuple. 'removeDom' and 'restrictDom' re-wrap their result with
-- 'wrapSubstitution', so the new wrapper's tuple is rebuilt from the
-- hidden type's class instance rather than copied from the old tuple.
instance Substitution WrappedSubstitution where
  lookupInt i (WrappedSubstitution s (look, _, _, _, _)) =
    look i s

  removeDom vars (WrappedSubstitution s (_, without, _, _, _)) =
    wrapSubstitution (without vars s)

  restrictDom vars (WrappedSubstitution s (_, _, keepOnly, _, _)) =
    wrapSubstitution (keepOnly vars s)

  dom (WrappedSubstitution s (_, _, _, domainOf, _)) =
    domainOf s

  cod (WrappedSubstitution s (_, _, _, _, rangeOf)) =
    rangeOf s
-- | Substitution and free variables on types, by structural recursion.
instance Substitutable Tp where
  -- Variables are looked up; constants are returned unchanged;
  -- applications are rewritten on both sides.
  sub |-> (TVar i)   = lookupInt i sub
  _   |-> t@TCon{}   = t
  sub |-> (TApp l r) = TApp (sub |-> l) (sub |-> r)

  -- Variables in order of first occurrence, each listed once.
  ftv (TVar i)   = [i]
  ftv TCon{}     = []
  ftv (TApp l r) = ftv l `union` ftv r
-- | Lists: substitute element-wise. The free variables are the union
-- over all elements.
instance Substitutable a => Substitutable [a] where
  sub |-> xs = [ sub |-> x | x <- xs ]
  ftv xs = foldr (\x acc -> ftv x `union` acc) [] xs
-- | Pairs: substitute in both components. The free variables are the
-- union of both components' free variables.
instance (Substitutable a, Substitutable b) => Substitutable (a, b) where
  -- Match the pair strictly, as a pattern on the argument would.
  sub |-> p = case p of
    (x, y) -> (sub |-> x, sub |-> y)
  ftv (x, y) = ftv x `union` ftv y
-- | Optional values: substitute inside 'Just'. 'Nothing' has no free
-- variables.
instance Substitutable a => Substitutable (Maybe a) where
  _   |-> Nothing  = Nothing
  sub |-> (Just x) = Just (sub |-> x)
  ftv Nothing  = []
  ftv (Just x) = ftv x
-- | Sums: substitute in, and take the free variables of, whichever side
-- is present.
instance (Substitutable a, Substitutable b) => Substitutable (Either a b) where
  sub |-> (Left l)  = Left (sub |-> l)
  sub |-> (Right r) = Right (sub |-> r)
  ftv (Left l)  = ftv l
  ftv (Right r) = ftv r
-- | Replace every free type variable @i@ with the type constant
-- @'_' : show i@ (variable 3 becomes @TCon "_3"@). '|->' leaves
-- constants untouched, so the frozen variables are no longer affected
-- by later substitutions.
freezeFTV :: Substitutable a => a -> a
freezeFTV a = frozen |-> a
  where
    frozen = listToSubstitution (map (\i -> (i, TCon ('_' : show i))) (ftv a))
-- | The free type variables of all types returned by 'getTypes' for the
-- argument.
allTypeVariables :: HasTypes a => a -> [Int]
allTypeVariables x = ftv (getTypes x)
-- | The distinct type-constant names occurring in the types returned by
-- 'getTypes', in order of first occurrence (left to right).
allTypeConstants :: HasTypes a => a -> [String]
allTypeConstants x = nub (foldr collect [] (getTypes x))
  where
    -- Prepend a type's constants, leftmost first, to the accumulator.
    -- This avoids the left-nested '++' of a direct recursive
    -- concatenation and yields the same list.
    collect (TVar _)   acc = acc
    collect (TCon s)   acc = s : acc
    collect (TApp l r) acc = collect l (collect r acc)