-- |
-- Module      :  Cryptol.Transform.Specialize
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable

module Cryptol.Transform.Specialize
where

import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.TypeMap
import Cryptol.TypeCheck.Subst
import qualified Cryptol.ModuleSystem as M
import qualified Cryptol.ModuleSystem.Env as M
import qualified Cryptol.ModuleSystem.Monad as M
import Cryptol.ModuleSystem.Name

import           Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)

import MonadLib hiding (mapM)

-- Specializer Monad -----------------------------------------------------------

-- | Cache of specializable declarations.
--
-- A 'Name' should have an entry in the 'SpecCache' iff it is
-- specializable. Each entry pairs the original (possibly polymorphic)
-- 'Decl' with a 'TypesMap' keyed by the type arguments of every
-- instantiation seen so far. The value stored for an instantiation is the
-- fresh 'Name' chosen for it together with the specialized 'Decl', which
-- is 'Nothing' while that instance's body is still being specialized.
-- Each 'Name' starts out with an empty 'TypesMap'.
type SpecCache = Map Name (Decl, TypesMap (Name, Maybe Decl))

-- | The specializer monad: a 'SpecCache' state layered on top of the
-- module-system monad 'M.ModuleT'.
type SpecT m a = StateT SpecCache (M.ModuleT m) a

-- | The specializer monad over 'IO', as used by the passes in this module.
type SpecM a = SpecT IO a

-- | Run a specializer computation from an initial cache, returning its
-- result paired with the final cache.
runSpecT :: SpecCache -> SpecT m a -> M.ModuleT m (a, SpecCache)
runSpecT = runStateT

-- | Embed a module-system computation in the specializer monad, leaving
-- the cache untouched.
liftSpecT :: Monad m => M.ModuleT m a -> SpecT m a
liftSpecT = lift

-- | Read the current specialization cache.
getSpecCache :: Monad m => SpecT m SpecCache
getSpecCache = get

-- | Replace the specialization cache wholesale.
setSpecCache :: Monad m => SpecCache -> SpecT m ()
setSpecCache cache = set cache

-- | Update the specialization cache with a pure function.
modifySpecCache :: Monad m => (SpecCache -> SpecCache) -> SpecT m ()
modifySpecCache f = getSpecCache >>= setSpecCache . f

-- | Update the state of a MonadLib state monad: read the current state
-- and write back its image under @f@.
modify :: StateM m s => (s -> s) -> m ()
modify f = do
  s <- get
  set (f s)


-- Specializer -----------------------------------------------------------------

-- | Specialize an expression against the declarations in the module
-- environment.
--
-- The outer type abstractions of @expr@ are peeled off with 'destETAbs'.
-- The remaining body is wrapped in a @where@ clause that holds
-- type-specialized versions of all functions it calls (transitively),
-- and the peeled abstractions are put back with @foldr ETAbs@. The
-- specializer starts from an empty 'SpecCache', and the final cache is
-- discarded.
specialize :: Expr -> M.ModuleCmd Expr
specialize expr minp = runSpec $ do
  spec' <- specializeEWhere body extDgs
  return (foldr ETAbs spec' tparams)
  where
  (tparams, body) = destETAbs expr
  -- Declaration groups taken from the module environment.
  extDgs = allDeclGroups (M.minpModuleEnv minp)
  -- Run in the module monad from an empty cache, keeping only the result.
  runSpec m = M.runModuleT minp (fst <$> runSpecT Map.empty m)

-- | Specialize every use site inside an expression.
--
-- Variables and type/proof applications are handed to 'specializeConst'.
-- @where@ clauses go through 'specializeEWhere'. Every other form is
-- rebuilt after specializing its sub-expressions from left to right.
specializeExpr :: Expr -> SpecM Expr
specializeExpr (ELocated r e) = ELocated r <$> specializeExpr e
specializeExpr (EList es t) = do
  es' <- traverse specializeExpr es
  return (EList es' t)
specializeExpr (ETuple es) = ETuple <$> traverse specializeExpr es
specializeExpr (ERec fs) = ERec <$> traverse specializeExpr fs
specializeExpr (ESel e s) = do
  e' <- specializeExpr e
  return (ESel e' s)
specializeExpr (ESet ty e s v) = do
  e' <- specializeExpr e
  v' <- specializeExpr v
  return (ESet ty e' s v')
specializeExpr (EIf cond thn els) = do
  cond' <- specializeExpr cond
  thn'  <- specializeExpr thn
  els'  <- specializeExpr els
  return (EIf cond' thn' els')
-- Bindings within list comprehensions always have monomorphic types.
specializeExpr (EComp len t e mss) = do
  e'   <- specializeExpr e
  mss' <- traverse (traverse specializeMatch) mss
  return (EComp len t e' mss')
specializeExpr expr@EVar{} = specializeConst expr
specializeExpr (ETAbs tp e) = do
  -- No specialized decl mentioning type variable @tp@ may escape outside
  -- this 'ETAbs'. To ensure that, the body is processed against an empty
  -- 'SpecCache', and the saved cache is restored afterwards. This stops
  -- the specializer from registering instantiations involving @tp@ for
  -- any decl bound outside the scope of @tp@.
  saved <- getSpecCache
  setSpecCache Map.empty
  e' <- specializeExpr e
  setSpecCache saved
  return (ETAbs tp e')
specializeExpr expr@ETApp{} = specializeConst expr
specializeExpr (EApp fun arg) = do
  fun' <- specializeExpr fun
  arg' <- specializeExpr arg
  return (EApp fun' arg')
specializeExpr (EAbs qn t e) = EAbs qn t <$> specializeExpr e
specializeExpr (EProofAbs p e) = EProofAbs p <$> specializeExpr e
specializeExpr expr@EProofApp{} = specializeConst expr
specializeExpr (EWhere e dgs) = specializeEWhere e dgs

-- | Specialize one arm of a comprehension.
--
-- A 'From' generator has its source expression specialized. A 'Let'
-- binding is returned unchanged when its signature has no type
-- variables. Polymorphic 'Let' bindings are not supported and fail.
specializeMatch :: Match -> SpecM Match
specializeMatch (From qn l t e) = From qn l t <$> specializeExpr e
specializeMatch m@(Let decl)
  | isPoly    = fail "unimplemented: specializeMatch Let unimplemented"
  | otherwise = return m
  where
  isPoly = not (null (sVars (dSignature decl)))
  -- TODO: the polymorphic case should be treated like EWhere.


-- | Run an action with the given declaration groups made specializable,
-- then collect the specialized declarations it produced.
--
-- Before @action@ runs, each decl in @dgs@ is added to the 'SpecCache'
-- with an empty instance table. Afterwards every group is rebuilt from
-- the instances recorded for its decls:
--
-- * each instance of a 'NonRecursive' decl becomes its own group;
-- * the instances of a 'Recursive' group are gathered into one group;
-- * decls with no instances contribute nothing.
--
-- Finally the entries for @dgs@ are removed from the cache, and any
-- entries with the same names that existed before are restored.
--
-- Returns the action's result, the rebuilt groups, and, for each decl of
-- @dgs@, a table from instantiating types to the generated name.
withDeclGroups :: [DeclGroup] -> SpecM a
               -> SpecM (a, [DeclGroup], Map Name (TypesMap Name))
withDeclGroups dgs action = do
  before <- getSpecCache
  -- We assume that the names bound in dgs are disjoint from the other
  -- names in scope.
  setSpecCache (Map.union localEntries before)
  result <- action
  dgs' <- concat <$> traverse rebuildGroup dgs
  -- Only the entries we added, as updated by the action.
  after <- flip Map.intersection localEntries <$> getSpecCache
  let nameTable = fmap (fmap fst . snd) after
  -- Drop our entries and reinstate whichever of the old ones they shadowed.
  let shadowed = Map.intersection before localEntries
  modifySpecCache (Map.union shadowed . flip Map.difference localEntries)
  return (result, dgs', nameTable)
  where
  -- One cache entry per decl in dgs, each starting with no instances.
  localEntries = Map.fromList
    [ (dName d, (d, emptyTM)) | d <- concatMap groupDecls dgs ]

  -- The specialized copies of a decl currently recorded in the cache.
  instancesOf :: Decl -> SpecM [Decl]
  instancesOf d = do
    ~(Just (_, tm)) <- Map.lookup (dName d) <$> getSpecCache
    return (catMaybes [ md | (_, (_, md)) <- toListTM tm ])

  -- Rebuild one original group from the instances of its decls.
  rebuildGroup :: DeclGroup -> SpecM [DeclGroup]
  rebuildGroup (NonRecursive d) = map NonRecursive <$> instancesOf d
  rebuildGroup (Recursive ds) = do
    ds' <- concat <$> traverse instancesOf ds
    return [ Recursive ds' | not (null ds') ]

-- | Compute the specialization of @'EWhere' e dgs@.
--
-- The decls in @dgs@ are made specializable while @e@ is specialized. The
-- resulting @where@ clause holds one copy of a decl for each monomorphic
-- type instance at which it was used. Decls never mentioned (even
-- monomorphic ones) are dropped, and when nothing survives the @where@
-- clause is omitted altogether.
specializeEWhere :: Expr -> [DeclGroup] -> SpecM Expr
specializeEWhere e dgs = wrap <$> withDeclGroups dgs (specializeExpr e)
  where
  wrap (e', [], _)   = e'
  wrap (e', dgs', _) = EWhere e' dgs'

-- | Transform declaration groups into a set of monomorphic declarations.
--
-- Every binding whose schema has neither type variables nor constraints
-- is used as a root and specialized. Doing so creates instantiated copies
-- of the polymorphic decls it (transitively) references. All emitted
-- decls, the roots included, carry freshly generated names. The returned
-- map relates each original name to the names generated for it, keyed by
-- the instantiating types.
specializeDeclGroups :: [DeclGroup] -> SpecM ([DeclGroup], Map Name (TypesMap Name))
specializeDeclGroups dgs =
  dropResult <$> withDeclGroups dgs (traverse specializeExpr roots)
  where
  dropResult (_, dgs', names) = (dgs', names)

  -- References to every monomorphic binding in dgs.
  roots = [ EVar (dName d)
          | d <- concatMap groupDecls dgs
          , isMonoType (dSignature d) ]

  isMonoType s = null (sVars s) && null (sProps s)

-- | Specialize a variable occurrence, possibly under type and proof
-- applications.
--
-- The expression is split into a head, its type arguments @ts@ and its
-- number @n@ of proof applications. If the head is a variable with an
-- entry in the 'SpecCache', the occurrence is replaced by a reference to
-- that decl's instance at @ts@. The instance is created, and its body
-- specialized, on first use. Anything else is returned unchanged.
specializeConst :: Expr -> SpecM Expr
specializeConst e0 =
  case destETApps e1 of
    (EVar qname, ts) -> do
      cache <- getSpecCache
      case Map.lookup qname cache of
        -- Primitive/unspecializable variable; leave it alone
        Nothing -> return e0
        Just (decl, tm) ->
          case lookupTM ts tm of
            -- Already specialized
            Just (qname', _) -> return (EVar qname')
            -- A new type instance of this function
            Nothing -> EVar <$> newInstance qname decl ts
    -- type/proof application to non-variable; not specializable
    _ -> return e0
  where
  (e1, n) = destEProofApps e0

  -- Record (or overwrite) the instance of qname at ts.
  recordInstance :: Name -> [Type] -> (Name, Maybe Decl) -> SpecM ()
  recordInstance qname ts entry =
    modifySpecCache (Map.adjust (fmap (insertTM ts entry)) qname)

  -- Build the instance of decl at ts and return its fresh name.
  newInstance :: Name -> Decl -> [Type] -> SpecM Name
  newInstance qname decl ts = do
    qname' <- freshName qname ts
    sig'   <- instantiateSchema ts n (dSignature decl)
    -- Register the new name before specializing the body, so that
    -- references back to this same instance resolve to qname'.
    recordInstance qname ts (qname', Nothing)
    rhs' <- case dDefinition decl of
              DExpr e -> DExpr <$> (specializeExpr =<< instantiateExpr ts n e)
              DPrim   -> return DPrim
    let decl' = decl { dName = qname', dSignature = sig', dDefinition = rhs' }
    recordInstance qname ts (qname', Just decl')
    return qname'


-- Utility Functions -----------------------------------------------------------

-- | Strip the outermost run of 'EProofApp' nodes, returning the inner
-- expression and how many nodes were removed.
destEProofApps :: Expr -> (Expr, Int)
destEProofApps (EProofApp e) =
  let (inner, k) = destEProofApps e
  in  (inner, k + 1)
destEProofApps e = (e, 0)

-- | Strip the outermost run of 'ETApp' nodes, returning the head
-- expression and its type arguments in application order. For example,
-- @ETApp (ETApp f t1) t2@ yields @(f, [t1, t2])@.
destETApps :: Expr -> (Expr, [Type])
destETApps = loop []
  where
  loop acc expr =
    case expr of
      ETApp fun ty -> loop (ty : acc) fun
      _            -> (expr, acc)

-- | Strip the outermost run of 'EProofAbs' nodes. Returns the abstracted
-- props outermost first, together with the body. For example,
-- @EProofAbs p1 (EProofAbs p2 e)@ yields @([p1, p2], e)@, so
-- @foldr EProofAbs e ps@ rebuilds the original expression.
destEProofAbs :: Expr -> ([Prop], Expr)
destEProofAbs (EProofAbs p e) =
  -- Recurse instead of using a cons-accumulator: an accumulator would
  -- return the props innermost first, i.e. reversed.
  let (ps, body) = destEProofAbs e
  in  (p : ps, body)
destEProofAbs e = ([], e)

-- | Strip the outermost run of 'ETAbs' binders. Returns the type
-- parameters outermost first, together with the body. For example,
-- @ETAbs a (ETAbs b e)@ yields @([a, b], e)@, so @foldr ETAbs e ps@
-- (as used by 'specialize') rebuilds the binders in their original
-- order.
destETAbs :: Expr -> ([TParam], Expr)
destETAbs (ETAbs tp e) =
  -- Recurse instead of using a cons-accumulator: an accumulator would
  -- return the binders innermost first. 'specialize' would then swap
  -- the order of the type parameters of the expression it rebuilds.
  let (tps, body) = destETAbs e
  in  (tp : tps, body)
destETAbs e = ([], e)

-- Any top-level declarations in the current module can be found in the
-- ModuleEnv's LoadedModules, so we can count on freshName to avoid
-- collisions with them. Any generated name for a specialized function
-- will be qualified with the current 'ModName', so generated names will
-- not collide with local decls either.
-- freshName :: Name -> [Type] -> SpecM Name
-- freshName n [] = return n
-- freshName (QName m name) tys = do
--   let name' = reifyName name tys
--   bNames <- matchingBoundNames m
--   let loop i = let nm = name' ++ "_" ++ show i
--                  in if nm `elem` bNames
--                       then loop $ i + 1
--                       else nm
--   let go = if name' `elem` bNames
--                then loop (1 :: Integer)
--                else name'
--   return $ QName m (mkName go)

-- | Freshen a name by drawing a new unique from the name supply.
--
-- The identifier and source location are kept. Declared names also keep
-- their module name, name source and fixity. The type arguments are not
-- used.
freshName :: Name -> [Type] -> SpecM Name
freshName n _ = liftSupply (remake (nameInfo n))
  where
  ident = nameIdent n
  loc   = nameLoc n
  remake (Declared ns s) = mkDeclared ns s ident (nameFixity n) loc
  remake Parameter       = mkParameter ident loc

-- matchingBoundNames :: (Maybe ModName) -> SpecM [String]
-- matchingBoundNames m = do
--   qns <- allPublicNames <$> liftSpecT M.getModuleEnv
--   return [ unpack n | QName m' (Name n) <- qns , m == m' ]

-- reifyName :: Name -> [Type] -> String
-- reifyName name tys = intercalate "_" (showName name : concatMap showT tys)
--   where
--     tvInt (TVFree i _ _ _) = i
--     tvInt (TVBound i _) = i
--     showT typ =
--       case typ of
--         TCon tc ts  -> showTCon tc : concatMap showT ts
--         TUser _ _ t -> showT t
--         TVar tv     -> [ "a" ++ show (tvInt tv) ]
--         TRec tr     -> "rec" : concatMap showRecFld tr
--     showTCon tCon =
--       case tCon of
--         TC tc -> showTC tc
--         PC pc -> showPC pc
--         TF tf -> showTF tf
--     showPC pc =
--       case pc of
--         PEqual   -> "eq"
--         PNeq     -> "neq"
--         PGeq     -> "geq"
--         PFin     -> "fin"
--         PHas sel -> "sel_" ++ showSel sel
--         PArith   -> "arith"
--         PCmp     -> "cmp"
--     showTC tc =
--       case tc of
--         TCNum n     -> show n
--         TCInf       -> "inf"
--         TCBit       -> "bit"
--         TCSeq       -> "seq"
--         TCFun       -> "fun"
--         TCTuple n   -> "t" ++ show n
--         TCNewtype _ -> "user"
--     showSel sel = intercalate "_" $
--       case sel of
--         TupleSel  _ sig -> "tup"  : maybe [] ((:[]) . show) sig
--         RecordSel x sig -> "rec"  : showName x : map showName (maybe [] id sig)
--         ListSel   _ sig -> "list" : maybe [] ((:[]) . show) sig
--     showName nm =
--       case nm of
--         Name s       -> unpack s
--         NewName _ n -> "x" ++ show n
--     showTF tf =
--       case tf of
--         TCAdd           -> "add"
--         TCSub           -> "sub"
--         TCMul           -> "mul"
--         TCDiv           -> "div"
--         TCMod           -> "mod"
--         TCExp           -> "exp"
--         TCWidth         -> "width"
--         TCMin           -> "min"
--         TCMax           -> "max"
--         TCLenFromThenTo -> "len_from_then_to"
--     showRecFld (nm,t) = showName nm : showT t



-- | Instantiate a schema with concrete type arguments.
--
-- There must be exactly one type per quantified parameter, and the
-- expected number of props must equal the schema's prop count; otherwise
-- this fails.  On success the result is a monomorphic schema (no
-- parameters, no props) whose body has the arguments substituted in.
instantiateSchema :: [Type] -> Int -> Schema -> SpecM Schema
instantiateSchema tyArgs expectedProps (Forall params props ty) =
  if length params /= length tyArgs
    then fail "instantiateSchema: wrong number of type arguments"
    else if length props /= expectedProps
      then fail "instantiateSchema: wrong number of prop arguments"
      else return (Forall [] [] (apSubst paramSubst ty))
  where
    paramSubst = listParamSubst (zip params tyArgs)

-- | Reduce @length ts@ outermost type abstractions and @n@ proof abstractions.
--
-- Type abstractions are peeled off first, outermost-first, substituting
-- each type argument into the body.  Once the type arguments run out,
-- exactly @n@ proof abstractions are stripped.  Any other shape fails.
instantiateExpr :: [Type] -> Int -> Expr -> SpecM Expr
instantiateExpr tyArgs proofCount expr =
  case (tyArgs, expr) of
    ([], _) | proofCount == 0 ->
      return expr
    ([], EProofAbs _ body) ->
      instantiateExpr [] (proofCount - 1) body
    (ty : moreTys, ETAbs param body) ->
      instantiateExpr moreTys proofCount
        (apSubst (singleTParamSubst param ty) body)
    _ ->
      fail "instantiateExpr: wrong number of type/proof arguments"



-- | Every declaration group of every loaded module, in module order.
allDeclGroups :: M.ModuleEnv -> [DeclGroup]
allDeclGroups env =
  [ group | m <- M.loadedModules env, group <- mDecls m ]


-- | Apply an effectful function to the second component of a pair and
-- keep the first component unchanged inside the resulting functor.
traverseSnd :: Functor f => (b -> f c) -> (a, b) -> f (a, c)
traverseSnd f (x, y) = fmap (\z -> (x, z)) (f y)