{-# LANGUAGE GeneralizedNewtypeDeriving, ExistentialQuantification, RankNTypes #-}
module Data.SNMap (
SNMap,
SNMapReaderT,
runSNMapReaderT,
newSNMap,
memoize,
memoizeM,
scopedM
)where
import System.Mem.StableName
import qualified Data.HashTable.IO as HT
import Data.Functor
import Control.Monad.IO.Class (liftIO, MonadIO)
import Control.Monad.Trans.Class
import System.Mem.Weak (addFinalizer)
import Control.Applicative (Applicative)
import Control.Monad.Exception (MonadException, MonadAsyncException)
import Control.Monad.Trans.State.Strict
-- | A mutable memoization table mapping the stable name of a monadic action
-- (of type @m a@) to the result that action produced.
newtype SNMap m a =
  SNMap (HT.BasicHashTable (StableName (m a)) a)
-- | Allocate a fresh, empty memoization table.
newSNMap :: IO (SNMap m a)
newSNMap = fmap SNMap HT.new
-- | Memoize a monadic action by the identity of the action value itself.
--
-- The action is forced to WHNF and its 'StableName' is looked up in the
-- table returned by @getter@. On a hit, the cached result is returned and
-- the action is not run. On a miss, the action is run and its result is
-- stored under that stable name.
--
-- Keys are stable names, so reuse depends on sharing the same action value.
-- NOTE(review): separately constructed but equal actions presumably get
-- distinct stable names and are cached separately; see the
-- "System.Mem.StableName" docs.
memoize :: MonadIO m => m (SNMap m a) -> m a -> m a
memoize getter m = do
  key <- liftIO (makeStableName $! m)
  SNMap table <- getter
  cached <- liftIO (HT.lookup table key)
  maybe (runAndRecord key) return cached
  where
    -- Run the action, then re-read the table before inserting: running the
    -- action may have replaced the table in the underlying state (e.g. a
    -- nested scope), so the table seen before running it can be stale.
    runAndRecord key = do
      result <- m
      SNMap table' <- getter
      liftIO (HT.insert table' key result)
      return result
-- | Monad transformer that carries a memoization table for results of type
-- @a@ on top of the base monad @m@.
--
-- Despite the name, this is a (strict) 'StateT' rather than a reader,
-- because nested scopes replace the table in the state. All instances are
-- inherited from the underlying 'StateT' via newtype deriving.
newtype SNMapReaderT a m b =
  SNMapReaderT (StateT (SNMap (SNMapReaderT a m) a) m b)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadException
    , MonadAsyncException
    )
-- | Run a memoizing computation in the base monad. It starts with a fresh,
-- empty table, and the final table is discarded when the computation
-- finishes.
runSNMapReaderT :: MonadIO m => SNMapReaderT a m b -> m b
runSNMapReaderT (SNMapReaderT action) = liftIO newSNMap >>= evalStateT action
-- | Lift a base-monad action; the memoization table is passed through
-- unchanged.
instance MonadTrans (SNMapReaderT a) where
  lift action = SNMapReaderT (lift action)
-- | 'memoize' specialised to 'SNMapReaderT'. The table is whichever one is
-- currently held in the transformer's state.
memoizeM :: MonadIO m => SNMapReaderT a m a -> SNMapReaderT a m a
memoizeM action = memoize (SNMapReaderT get) action
-- | Run an action in a nested memoization scope.
--
-- Inside the scope, the action can reuse every result memoized so far.
-- Anything it memoizes itself is forgotten when the scope ends: afterwards
-- the state holds the same table, with the same entries, as on entry.
--
-- The action runs against a private copy of the current table, and the
-- enclosing table is put back afterwards. The enclosing table is never
-- written to, so entries created by an action that aborts part-way cannot
-- remain in it.
-- NOTE(review): that last guarantee assumes a caught failure rolls the
-- state back to its value at the catch point, as 'StateT'-based @catch@
-- does; confirm for the 'MonadException' instance in use.
scopedM :: MonadIO m => SNMapReaderT a m x -> SNMapReaderT a m x
scopedM m = do
  outer@(SNMap h) <- SNMapReaderT get
  -- Snapshot the current entries into a fresh table for the inner scope.
  scratch <- liftIO $ SNMap <$> (HT.toList h >>= HT.fromList)
  SNMapReaderT $ put scratch
  x <- m
  -- Discard the scratch table; the enclosing one was never modified.
  SNMapReaderT $ put outer
  return x