{-# LANGUAGE GADTs, TypeFamilies, TypeOperators, ScopedTypeVariables, CPP #-}
{-# LANGUAGE StandaloneDeriving, FlexibleInstances #-}
{-# LANGUAGE DefaultSignatures, FlexibleContexts, LambdaCase #-}
{-# OPTIONS_GHC -Wall -fenable-rewrite-rules #-}
module Data.MemoTrie
( HasTrie(..), (:->:)(..)
, domain, idTrie, (@.@)
, memo, memo2, memo3, mup
, inTrie, inTrie2, inTrie3
, trieGeneric, untrieGeneric, enumerateGeneric, Reg
, memoFix
) where
import Data.Function (fix)
import Data.Bits
import Data.Word
import Data.Int
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
import Control.Arrow (first,(&&&))
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid
#endif
import Data.Function (on)
import GHC.Generics
import Control.Newtype.Generics
import Data.Void (Void)
-- | Types usable as memo-trie keys.
--
-- For a key type @a@, the associated data family @a :->: b@ is a table
-- conceptually holding one @b@ per key. 'trie' builds such a table from a
-- function, 'untrie' reads it back as a function, and 'enumerate' lists
-- its entries as key/value pairs.
class HasTrie a where
  -- | Trie (memo table) indexed by keys of type @a@.
  data (:->:) a :: * -> *
  -- | Tabulate a function as a trie.
  trie :: (a -> b) -> (a :->: b)
  -- | Read a trie back as a function.
  untrie :: (a :->: b) -> (a -> b)
  -- | List the key/value entries of a trie.
  enumerate :: (a :->: b) -> [(a,b)]

infixr 0 :->:
-- | All keys of a trie-able type, in the order 'enumerate' produces them.
--
-- The keys come from a trie whose stored values are placeholders that
-- raise an error if forced, so only the keys may be inspected.
domain :: HasTrie a => [a]
domain = fmap fst (enumerate (trie (const oops)))
  where
    oops = error "Data.MemoTrie.domain: range element evaluated."
-- | Tries are equal when their enumerated values are equal position by
-- position. Keys are not compared, because both sides have the same key
-- type and are enumerated by the same function. For tries with an
-- infinite enumeration, a comparison of equal tries does not terminate.
instance (HasTrie a, Eq b) => Eq (a :->: b) where
  s == t = on (==) (map snd . enumerate) s t
-- | Render a trie as the text @Trie: @ followed by its enumerated
-- key/value list.
instance (HasTrie a, Show a, Show b) => Show (a :->: b) where
  show = ("Trie: " ++) . show . enumerate
-- | Memoize a function. The function is tabulated into a trie and then
-- read back out, so calls with the same argument can reuse the value
-- stored in that trie.
memo :: HasTrie t => (t -> a) -> (t -> a)
memo f = untrie (trie f)
-- | Lift a memoizer for results to a memoizer one argument wider.
-- @mup mem f@ memoizes @f@ on its first argument with 'memo', and
-- applies @mem@ to each partial result.
mup :: HasTrie t => (b -> c) -> (t -> b) -> (t -> c)
mup mem f = memo (\x -> mem (f x))

-- | Memoize a curried two-argument function.
memo2 :: (HasTrie s,HasTrie t) => (s -> t -> a) -> (s -> t -> a)
memo2 = mup memo

-- | Memoize a curried three-argument function.
memo3 :: (HasTrie r,HasTrie s,HasTrie t) => (r -> s -> t -> a) -> (r -> s -> t -> a)
memo3 = mup memo2
-- | Memoizing fixed point. The step function @h@ receives the memoized
-- function itself, so recursive calls made through that argument go
-- through the memo table.
memoFix :: HasTrie a => ((a -> b) -> (a -> b)) -> (a -> b)
memoFix h = fix (\self -> memo (h self))
-- Alternative formulations of memoFix, equivalent to the definition
-- above. They are disabled by the preprocessor and kept only for
-- reference.
#if 0
memoFix h = fix (\ f' -> memo (h f'))
memoFix h = f'
where f' = memo (h f')
memoFix h = f'
where
f' = memo f
f = h f'
#endif
-- Usage example (also disabled): fibF is an open-recursive Fibonacci
-- step, fib ties the knot with plain fix, and the primed fib ties it
-- with memoFix so recursive calls go through the memo table.
#if 0
fibF :: (Integer -> Integer) -> (Integer -> Integer)
fibF _ 0 = 1
fibF _ 1 = 1
fibF f n = f (n-1) + f (n-2)
fib :: Integer -> Integer
fib = fix fibF
fib' :: Integer -> Integer
fib' = memoFix fibF
#endif
-- | Lift a transformation of functions to a transformation of tries:
-- the argument trie is read as a function, transformed, and the result
-- is tabulated again.
inTrie :: (HasTrie a, HasTrie c)
       => ((a -> b) -> (c -> d))
       -> ((a :->: b) -> (c :->: d))
inTrie f = (untrie ~> trie) f

-- | Two-argument version of 'inTrie'.
inTrie2 :: (HasTrie a, HasTrie c, HasTrie e)
        => ((a -> b) -> (c -> d) -> (e -> f))
        -> ((a :->: b) -> (c :->: d) -> (e :->: f))
inTrie2 f = (untrie ~> inTrie) f

-- | Three-argument version of 'inTrie'.
inTrie3 :: (HasTrie a, HasTrie c, HasTrie e, HasTrie g)
        => ((a -> b) -> (c -> d) -> (e -> f) -> (g -> h))
        -> ((a :->: b) -> (c :->: d) -> (e :->: f) -> (g :->: h))
inTrie3 f = (untrie ~> inTrie2) f
instance HasTrie Void where
  -- Void has no values, so its trie stores nothing.
  data Void :->: a = VoidTrie
  trie = const VoidTrie
  untrie VoidTrie = const (error "untrie VoidTrie")
  enumerate VoidTrie = []
-- | An empty Void trie carries no data, so its representation is ().
instance Newtype (Void :->: a) where
  type O (Void :->: a) = ()
  unpack VoidTrie = ()
  pack () = VoidTrie
instance HasTrie () where
  -- A single slot holding the result for ().
  newtype () :->: a = UnitTrie a
  trie f = UnitTrie $ f ()
  untrie (UnitTrie x) () = x
  enumerate (UnitTrie x) = [((), x)]
-- | A unit trie is just its single stored value.
instance Newtype (() :->: a) where
  type O (() :->: a) = a
  unpack (UnitTrie x) = x
  pack = UnitTrie
instance HasTrie Bool where
  -- One slot per constructor: the False result first, then the True one.
  data Bool :->: x = BoolTrie x x
  trie f = BoolTrie (f False) (f True)
  untrie (BoolTrie onFalse onTrue) = if' onFalse onTrue
  enumerate (BoolTrie onFalse onTrue) = zip [False, True] [onFalse, onTrue]
-- | A Bool trie is the pair (result for False, result for True).
instance Newtype (Bool :->: a) where
  type O (Bool :->: a) = (a,a)
  unpack (BoolTrie onFalse onTrue) = (onFalse, onTrue)
  pack (onFalse, onTrue) = BoolTrie onFalse onTrue
-- | Choose by a Bool: the first value for False, the second for True.
-- The argument order matches the field order of BoolTrie.
if' :: x -> x -> Bool -> x
if' whenFalse whenTrue b = if b then whenTrue else whenFalse
instance HasTrie a => HasTrie (Maybe a) where
  -- The result for Nothing, plus a sub-trie holding the results for Just.
  data (:->:) (Maybe a) b = MaybeTrie b (a :->: b)
  trie f = MaybeTrie (f Nothing) (trie (\x -> f (Just x)))
  untrie (MaybeTrie forNothing forJust) =
    let lookupJust = untrie forJust
    in \case
         Nothing -> forNothing
         Just x -> lookupJust x
  enumerate (MaybeTrie forNothing forJust) =
    (Nothing, forNothing) : enum' Just forJust
-- | A Maybe trie is the Nothing result paired with the Just sub-trie.
instance Newtype (Maybe a :->: x) where
  type O (Maybe a :->: x) = (x, a :->: x)
  unpack (MaybeTrie forNothing forJust) = (forNothing, forJust)
  pack (forNothing, forJust) = MaybeTrie forNothing forJust
instance (HasTrie a, HasTrie b) => HasTrie (Either a b) where
  -- One sub-trie for Left arguments and one for Right arguments.
  data (Either a b) :->: x = EitherTrie (a :->: x) (b :->: x)
  trie f = EitherTrie (trie (\l -> f (Left l))) (trie (\r -> f (Right r)))
  untrie (EitherTrie lefts rights) = either (untrie lefts) (untrie rights)
  -- Interleave both sides so neither is starved when the other is infinite.
  enumerate (EitherTrie lefts rights) =
    weave (enum' Left lefts) (enum' Right rights)
-- | An Either trie is the pair (Left sub-trie, Right sub-trie).
instance Newtype (Either a b :->: x) where
  type O (Either a b :->: x) = (a :->: x, b :->: x)
  unpack (EitherTrie lefts rights) = (lefts, rights)
  pack (lefts, rights) = EitherTrie lefts rights
-- | Enumerate a trie, relabelling every key with the given function.
-- Values are passed through untouched.
enum' :: (HasTrie a) => (a -> a') -> (a :->: b) -> [(a', b)]
enum' f t = map (first f) (enumerate t)
-- | Interleave two lists, alternating elements and starting with the first
-- list. When either list is exhausted, the remainder of the other follows.
-- Alternating keeps elements of both lists reachable even when one of
-- them is infinite.
weave :: [a] -> [a] -> [a]
weave xs ys = case (xs, ys) of
  ([], _) -> ys
  (_, []) -> xs
  (x:rest, _) -> x : weave ys rest
instance (HasTrie a, HasTrie b) => HasTrie (a,b) where
  -- Curried layout: a trie over the first component whose entries are
  -- tries over the second component.
  newtype (a,b) :->: x = PairTrie (a :->: (b :->: x))
  trie f = PairTrie (trie (\a -> trie (\b -> f (a, b))))
  untrie (PairTrie t) = uncurry (untrie . lookupOuter)
    where
      lookupOuter = untrie t
  enumerate (PairTrie outer) = do
    (a, inner) <- enumerate outer
    (b, x) <- enumerate inner
    return ((a, b), x)
-- | A pair trie is the nested (curried) trie it wraps.
instance Newtype ((a,b) :->: x) where
  type O ((a,b) :->: x) = a :->: b :->: x
  unpack (PairTrie nested) = nested
  pack = PairTrie
instance (HasTrie a, HasTrie b, HasTrie c) => HasTrie (a,b,c) where
  -- A triple is stored as the nested pair ((a,b),c), reusing the pair trie.
  newtype (a,b,c) :->: x = TripleTrie (((a,b),c) :->: x)
  trie f = TripleTrie (trie (\nested -> f (trip nested)))
  untrie (TripleTrie t) = lookupNested . detrip
    where
      lookupNested = untrie t
  enumerate (TripleTrie t) = enum' trip t
-- | Flatten a left-nested pair into a triple.
trip :: ((a,b),c) -> (a,b,c)
trip nested = case nested of
  ((x, y), z) -> (x, y, z)
-- | Nest a triple as a left-nested pair (inverse of 'trip').
detrip :: (a,b,c) -> ((a,b),c)
detrip flat = case flat of
  (x, y, z) -> ((x, y), z)
instance HasTrie x => HasTrie [x] where
  -- A list is stored through its one-step unfolding: Left () for the
  -- empty list, Right (head, tail) otherwise (see list and delist).
  newtype [x] :->: a = ListTrie (Either () (x,[x]) :->: a)
  trie f = ListTrie (trie (\cell -> f (list cell)))
  untrie (ListTrie t) = lookupCell . delist
    where
      lookupCell = untrie t
  enumerate (ListTrie t) = enum' list t
-- | Rebuild a list from its one-step unfolding: Left means empty, and
-- Right carries the head and tail.
list :: Either () (x,[x]) -> [x]
list (Left _) = []
list (Right cell) = uncurry (:) cell
-- | Unfold a list one step: Left () if empty, otherwise Right (head, tail).
delist :: [x] -> Either () (x,[x])
delist xs = case xs of
  [] -> Left ()
  (y:ys) -> Right (y, ys)
-- HasTrie instances for the unsigned word types, generated by a CPP macro.
-- Each word is stored under its bit list as produced by bits (least
-- significant bit first), so the trie for a word type is a newtype around
-- the trie for [Bool]; unbits converts keys back when enumerating.
--
-- Line comments cannot go inside the macro body: the backslash
-- continuations join it into one line, and a line comment would swallow
-- the rest of the expansion.
#define WordInstance(Type,TrieType)\
instance HasTrie Type where \
newtype Type :->: a = TrieType ([Bool] :->: a);\
trie f = TrieType (trie (f . unbits));\
untrie (TrieType t) = untrie t . bits;\
enumerate (TrieType t) = enum' unbits t
WordInstance(Word,WordTrie)
WordInstance(Word8,Word8Trie)
WordInstance(Word16,Word16Trie)
WordInstance(Word32,Word32Trie)
WordInstance(Word64,Word64Trie)
-- | Bit list of a value, least significant bit first, stopping once the
-- remaining value is zero (so zero maps to the empty list).
--
-- For negative values of signed types the result is infinite, because an
-- arithmetic right shift never reaches zero. Within this module, bits is
-- applied only to unsigned words and to absolute values.
bits :: (Num t, Bits t) => t -> [Bool]
bits n
  | n == 0 = []
  | otherwise = testBit n 0 : bits (n `shiftR` 1)
-- | Numeric value of a single bit: 1 for True, 0 for False.
unbit :: Num t => Bool -> t
unbit b = if b then 1 else 0
-- | Rebuild a value from its bit list, least significant bit first
-- (inverse of 'bits' on its image). Bits beyond the width of a
-- fixed-size type are shifted out.
unbits :: (Num t, Bits t) => [Bool] -> t
unbits = foldr (\b acc -> unbit b .|. shiftL acc 1) 0
instance HasTrie Char where
  -- Characters are keyed by their fromEnum code, reusing the Int trie.
  newtype Char :->: a = CharTrie (Int :->: a)
  trie f = CharTrie (trie (\n -> f (toEnum n)))
  untrie (CharTrie t) = \c -> untrie t (fromEnum c)
  enumerate (CharTrie t) = enum' toEnum t
-- HasTrie instances for the signed integer types, generated by a CPP macro.
-- Each value is converted with fromIntegral to the paired unsigned word
-- type (Int with Word, Int8 with Word8, and so on) and stored in that word
-- trie. Enumeration converts keys back with fromIntegral. Negative values
-- wrap around to the upper half of the word range under fromIntegral.
--
-- Line comments cannot go inside the macro body: the backslash
-- continuations join it into one line.
#define IntInstance(IntType,WordType,TrieType) \
instance HasTrie IntType where \
newtype IntType :->: a = TrieType (WordType :->: a); \
untrie (TrieType t) n = untrie t (fromIntegral n); \
trie f = TrieType (trie (f . fromIntegral)); \
enumerate (TrieType t) = enum' fromIntegral t
IntInstance(Int,Word,IntTrie)
IntInstance(Int8,Word8,Int8Trie)
IntInstance(Int16,Word16,Int16Trie)
IntInstance(Int32,Word32,Int32Trie)
IntInstance(Int64,Word64,Int64Trie)
instance HasTrie Integer where
  -- An Integer is keyed by its sign flag and the bit list of its
  -- magnitude (see bitsZ and unbitsZ).
  newtype Integer :->: a = IntegerTrie ((Bool,[Bool]) :->: a)
  trie f = IntegerTrie (trie (\signBits -> f (unbitsZ signBits)))
  untrie (IntegerTrie t) = lookupSigned . bitsZ
    where
      lookupSigned = untrie t
  enumerate (IntegerTrie t) = enum' unbitsZ t
-- | Rebuild a signed value from (non-negative flag, magnitude bits).
-- When the flag is False, the magnitude is negated.
unbitsZ :: (Num n, Bits n) => (Bool,[Bool]) -> n
unbitsZ (positive, bs)
  | positive = magnitude
  | otherwise = negate magnitude
  where
    magnitude = unbits bs
-- | Split a signed value into (is it non-negative, bits of its absolute
-- value).
bitsZ :: (Num n, Ord n, Bits n) => n -> (Bool,[Bool])
bitsZ = nonNegative &&& (bits . abs)
  where
    nonNegative x = x >= 0
#if MIN_VERSION_base(4,11,0)
-- | Pointwise Semigroup: values stored under the same key are combined.
instance (HasTrie a, Semigroup b) => Semigroup (a :->: b) where
  (<>) = inTrie2 (<>)
#endif

-- | Pointwise Monoid: the empty trie stores mempty under every key.
-- From base 4.11 on, mappend falls back to the Semigroup operation.
instance (HasTrie a, Monoid b) => Monoid (a :->: b) where
  mempty = trie mempty
#if !MIN_VERSION_base(4,11,0)
  mappend = inTrie2 mappend
#endif
-- | Map a function over every value stored in a trie.
instance HasTrie a => Functor ((:->:) a) where
  fmap f t = trie (f . untrie t)
-- | Pointwise Applicative: 'pure' stores the same value under every key,
-- and application combines the two values stored under each key.
instance HasTrie a => Applicative ((:->:) a) where
  pure x = trie (const x)
  tf <*> tx = trie (\k -> lookupF k (lookupX k))
    where
      lookupF = untrie tf
      lookupX = untrie tx
-- | Monad structure inherited from functions out of the key type:
-- @u >>= k@ reads the value stored under each key in @u@, takes the trie
-- that @k@ returns for it, and reads that trie at the same key.
--
-- 'return' is left to its default ('pure') on base 4.8 and later. The
-- previous explicit definition was non-canonical and triggered GHC
-- -Wnoncanonical-monad-instances warnings.
instance HasTrie a => Monad ((:->:) a) where
#if !MIN_VERSION_base(4,8,0)
  -- Before base 4.8, return had no default, so define it canonically.
  return = pure
#endif
  u >>= k = trie (\key -> untrie (k (untrie u key)) key)
-- | The identity function, tabulated as a trie.
idTrie :: HasTrie a => a :->: a
idTrie = trie (\x -> x)
-- | Compose two tries, like function composition: the result maps each
-- key through the right trie and then through the left trie.
(@.@) :: (HasTrie a, HasTrie b) =>
         (b :->: c) -> (a :->: b) -> (a :->: c)
tbc @.@ tab = trie (untrie tbc . untrie tab)

infixr 9 @.@
-- | Pre-compose with the first function and post-compose with the second:
-- @(g ~> f) h = f . h . g@.
(~>) :: (a' -> a) -> (b -> b') -> ((a -> b) -> (a' -> b'))
(g ~> f) h = f . h . g
instance HasTrie (V1 x) where
  -- V1 has no values, so its trie stores nothing.
  data (V1 x :->: b) = V1Trie
  trie = const V1Trie
  untrie V1Trie = const (error "untrie V1Trie")
  enumerate V1Trie = []
instance HasTrie (U1 x) where
  -- U1 has exactly one value, so its trie has a single slot.
  data (U1 x :->: b) = U1Trie b
  trie f = U1Trie $ f U1
  untrie (U1Trie b) = \case U1 -> b
  enumerate (U1Trie b) = [(U1, b)]
instance (HasTrie (f x), HasTrie (g x)) => HasTrie ((f :+: g) x) where
  -- A generic sum is stored through the equivalent Either.
  newtype ((f :+: g) x :->: b) = EitherTrie1 (Either (f x) (g x) :->: b)
  trie h = EitherTrie1 (trie (\e -> h (liftSum e)))
  untrie (EitherTrie1 t) = lookupEither . dropSum
    where
      lookupEither = untrie t
  enumerate (EitherTrie1 t) = enum' liftSum t
instance (HasTrie (f x), HasTrie (g x)) => HasTrie ((f :*: g) x) where
  -- A generic product is stored through the equivalent pair.
  newtype ((f :*: g) x :->: b) = PairTrie1 ((f x, g x) :->: b)
  trie h = PairTrie1 (trie (\factors -> h (liftProduct factors)))
  untrie (PairTrie1 t) = lookupPair . dropProduct
    where
      lookupPair = untrie t
  enumerate (PairTrie1 t) = enum' liftProduct t
instance (HasTrie a) => HasTrie (K1 i a x) where
  -- A constant field is keyed by the wrapped value itself.
  data (K1 i a x :->: b) = K1Trie (a :->: b)
  trie f = K1Trie (trie (\a -> f (K1 a)))
  untrie (K1Trie t) = untrie t . unK1
  enumerate (K1Trie t) = enum' K1 t
instance (HasTrie (f x)) => HasTrie (M1 i t f x) where
  -- Metadata wrappers are transparent: keyed by the wrapped value.
  data (M1 i t f x :->: b) = M1Trie (f x :->: b)
  trie h = M1Trie (trie (\inner -> h (M1 inner)))
  untrie (M1Trie t) = untrie t . unM1
  enumerate (M1Trie t) = enum' M1 t
-- | The generic representation of @a@, applied to () so that it is a
-- plain type usable as the key type of a trie (as in the
-- @HasTrie (Reg a)@ constraints of the generic helpers).
type Reg a = Rep a ()
-- | Generic implementation of 'trie': tabulate over the generic
-- representation of @a@, then wrap the result with the supplied
-- constructor. The constructor is presumably the one of a user-defined
-- @a :->: b@ instance.
trieGeneric :: (Generic a, HasTrie (Reg a))
            => ((Reg a :->: b) -> (a :->: b))
            -> (a -> b)
            -> (a :->: b)
trieGeneric theConstructor f = theConstructor $ trie (\rep -> f (to rep))
{-# INLINEABLE trieGeneric #-}
-- | Generic implementation of 'untrie': unwrap the trie with the supplied
-- destructor, and look up each argument by its generic representation.
untrieGeneric :: (Generic a, HasTrie (Reg a))
              => ((a :->: b) -> (Reg a :->: b))
              -> (a :->: b)
              -> (a -> b)
untrieGeneric theDestructor t = untrie (theDestructor t) . from
{-# INLINEABLE untrieGeneric #-}
-- | Generic implementation of 'enumerate': enumerate the representation
-- trie, converting each key back to @a@ with 'to'.
enumerateGeneric :: (Generic a, HasTrie (Reg a))
                 => ((a :->: b) -> (Reg a :->: b))
                 -> (a :->: b)
                 -> [(a, b)]
enumerateGeneric theDestructor t = enum' to $ theDestructor t
{-# INLINEABLE enumerateGeneric #-}
-- | Convert a generic product into an ordinary pair.
dropProduct :: (f :*: g) a -> (f a, g a)
dropProduct p = case p of
  l :*: r -> (l, r)
{-# INLINEABLE dropProduct #-}
-- | Convert an ordinary pair into a generic product.
liftProduct :: (f a, g a) -> (f :*: g) a
liftProduct pair = case pair of
  (l, r) -> l :*: r
{-# INLINEABLE liftProduct #-}
-- | Convert a generic sum into an ordinary Either.
dropSum :: (f :+: g) a -> Either (f a) (g a)
dropSum (L1 l) = Left l
dropSum (R1 r) = Right r
{-# INLINEABLE dropSum #-}
-- | Convert an ordinary Either into a generic sum.
liftSum :: Either (f a) (g a) -> (f :+: g) a
liftSum e = case e of
  Left l -> L1 l
  Right r -> R1 r
{-# INLINEABLE liftSum #-}