-- | Row types {-# LANGUAGE FlexibleInstances, UndecidableInstances, FlexibleContexts, TemplateHaskell #-} module AST.Term.Row ( RowConstraints(..), RowKey , RowExtend(..), eKey, eVal, eRest, KWitness(..) , FlatRowExtends(..), freExtends, freRest , flattenRow, flattenRowExtend, unflattenRow , verifyRowExtendConstraints, rowExtendStructureMismatch , rowElementInfer ) where import AST import AST.TH.Internal.Instances (makeCommonInstances) import AST.Unify import AST.Unify.Lookup (semiPruneLookup) import AST.Unify.New (newTerm, newUnbound) import AST.Unify.Term (UTerm(..), _UTerm, uBody) import Control.DeepSeq (NFData) import Control.Lens (Prism', Lens', makeLenses, contains) import qualified Control.Lens as Lens import Control.Lens.Operators import Control.Monad (foldM) import Data.Binary (Binary) import Data.Foldable (sequenceA_) import Data.Map (Map) import qualified Data.Map as Map import Data.Set (Set) import Generics.Constraints (Constraints, makeDerivings, makeInstances) import GHC.Generics (Generic) import Text.Show.Combinators ((@|), showCon) import Prelude.Compat class (Ord (RowConstraintsKey constraints), TypeConstraints constraints) => RowConstraints constraints where type RowConstraintsKey constraints forbidden :: Lens' constraints (Set (RowConstraintsKey constraints)) type RowKey typ = RowConstraintsKey (TypeConstraintsOf typ) -- | Row-extend primitive for use in both value-level and type-level data RowExtend key val rest k = RowExtend { _eKey :: key , _eVal :: k # val , _eRest :: k # rest } deriving Generic data FlatRowExtends key val rest k = FlatRowExtends { _freExtends :: Map key (k # val) , _freRest :: k # rest } deriving Generic makeLenses ''RowExtend makeLenses ''FlatRowExtends makeCommonInstances [''FlatRowExtends] makeZipMatch ''RowExtend makeKTraversableApplyAndBases ''RowExtend makeKTraversableApplyAndBases ''FlatRowExtends makeDerivings [''Eq, ''Ord] [''RowExtend] makeInstances [''Binary, ''NFData] [''RowExtend] instance Constraints (RowExtend key val rest k) Show => Show (RowExtend key val rest k) where showsPrec p (RowExtend k v r) = (showCon "RowExtend" @| k @| v @| r) p {-# INLINE flattenRowExtend #-} flattenRowExtend :: (Ord key, Monad m) => (Tree v rest -> m (Maybe (Tree (RowExtend key val rest) v))) -> Tree (RowExtend key val rest) v -> m (Tree (FlatRowExtends key val rest) v) flattenRowExtend nextExtend (RowExtend k v rest) = flattenRow nextExtend rest <&> freExtends %~ Map.unionWith (error "Colliding keys") (Map.singleton k v) {-# INLINE flattenRow #-} flattenRow :: (Ord key, Monad m) => (Tree v rest -> m (Maybe (Tree (RowExtend key val rest) v))) -> Tree v rest -> m (Tree (FlatRowExtends key val rest) v) flattenRow nextExtend x = nextExtend x >>= maybe (pure (FlatRowExtends mempty x)) (flattenRowExtend nextExtend) {-# INLINE unflattenRow #-} unflattenRow :: Monad m => (Tree (RowExtend key val rest) v -> m (Tree v rest)) -> Tree (FlatRowExtends key val rest) v -> m (Tree v rest) unflattenRow mkExtend (FlatRowExtends fields rest) = Map.toList fields & foldM f rest where f acc (key, val) = RowExtend key val acc & mkExtend -- Helpers for Unify instances of type-level RowExtends: {-# INLINE verifyRowExtendConstraints #-} verifyRowExtendConstraints :: RowConstraints (TypeConstraintsOf rowTyp) => (TypeConstraintsOf rowTyp -> TypeConstraintsOf valTyp) -> TypeConstraintsOf rowTyp -> Tree (RowExtend (RowKey rowTyp) valTyp rowTyp) k -> Maybe (Tree (RowExtend (RowKey rowTyp) valTyp rowTyp) (WithConstraint k)) verifyRowExtendConstraints toChildC c (RowExtend k v rest) | c ^. forbidden . contains k = Nothing | otherwise = RowExtend k (WithConstraint (c & forbidden .~ mempty & toChildC) v) (WithConstraint (c & forbidden . contains k .~ True) rest) & Just {-# INLINE rowExtendStructureMismatch #-} rowExtendStructureMismatch :: Ord key => ( Unify m rowTyp , Unify m valTyp ) => (forall c. Unify m c => Tree (UVarOf m) c -> Tree (UVarOf m) c -> m (Tree (UVarOf m) c)) -> Prism' (Tree rowTyp (UVarOf m)) (Tree (RowExtend key valTyp rowTyp) (UVarOf m)) -> (TypeConstraintsOf rowTyp, Tree (RowExtend key valTyp rowTyp) (UVarOf m)) -> (TypeConstraintsOf rowTyp, Tree (RowExtend key valTyp rowTyp) (UVarOf m)) -> m () rowExtendStructureMismatch match extend (c0, r0) (c1, r1) = do flat0 <- flattenRowExtend nextExtend r0 flat1 <- flattenRowExtend nextExtend r1 Map.intersectionWith match (flat0 ^. freExtends) (flat1 ^. freExtends) & sequenceA_ restVar <- c0 <> c1 & UUnbound & newVar binding let side x y = unflattenRow mkExtend FlatRowExtends { _freExtends = (x ^. freExtends) `Map.difference` (y ^. freExtends) , _freRest = restVar } >>= match (y ^. freRest) _ <- side flat0 flat1 _ <- side flat1 flat0 pure () where mkExtend ext = extend # ext & newTerm nextExtend v = semiPruneLookup v <&> (^? Lens._2 . _UTerm . uBody . extend) -- Helper for infering row usages of a row element, -- such as getting a field from a record or injecting into a sum type. -- Returns a unification variable for the element and for the whole row. {-# INLINE rowElementInfer #-} rowElementInfer :: ( Unify m valTyp , Unify m rowTyp , RowConstraints (TypeConstraintsOf rowTyp) ) => (Tree (RowExtend (RowKey rowTyp) valTyp rowTyp) (UVarOf m) -> Tree rowTyp (UVarOf m)) -> RowKey rowTyp -> m (Tree (UVarOf m) valTyp, Tree (UVarOf m) rowTyp) rowElementInfer extendToRow k = do restVar <- scopeConstraints >>= newVar binding . UUnbound . (forbidden . contains k .~ True) part <- newUnbound whole <- RowExtend k part restVar & extendToRow & newTerm pure (part, whole)