{-# OPTIONS_GHC -Wno-missing-export-lists #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE ViewPatterns #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE QuantifiedConstraints #-}
#endif
#include "MachDeps.h"
module Data.TypeRepMap.Internal where
import Prelude hiding (lookup)
import Control.DeepSeq
import Control.Monad.ST (ST, runST)
import Control.Monad.Zip (mzip)
import Data.Function (on)
import Data.Kind (Type)
import Data.Type.Equality ((:~:) (..), TestEquality (..))
import Data.List (intercalate, nubBy)
import Data.Maybe (fromMaybe)
import Data.Primitive.Array (Array, MutableArray, indexArray, mapArray', readArray, sizeofArray,
thawArray, unsafeFreezeArray, writeArray)
import Data.Primitive.PrimArray (PrimArray, indexPrimArray, sizeofPrimArray)
import Data.Semigroup (Semigroup (..), All(..))
import GHC.Base (Any, Int (..), Int#, (*#), (+#), (<#))
import GHC.Exts (IsList (..), inline, sortWith)
import GHC.Fingerprint (Fingerprint (..))
#if WORD_SIZE_IN_BITS >= 64
import GHC.Prim (eqWord#, ltWord#)
#else
import GHC.IntWord64 (eqWord64#, ltWord64#)
#define eqWord eqWord64
#define ltWord ltWord64
#endif
import GHC.Word (Word64 (..))
import Type.Reflection (SomeTypeRep (..), TypeRep, Typeable, typeRep, withTypeable)
import Type.Reflection.Unsafe (typeRepFingerprint)
import Unsafe.Coerce (unsafeCoerce)
import qualified Data.Map.Strict as Map
import qualified GHC.Exts as GHC (fromList, toList)
-- | A map whose keys are types, identified by their 'TypeRep', and whose
-- value for the key @a@ has type @f a@.
--
-- Entries live in four parallel arrays; index @i@ of each array describes
-- the same entry. The arrays are not in plain sorted order. 'fromTriples'
-- sorts the entries by fingerprint and permutes them with 'fromSortedList'
-- into an implicit binary search tree (root at index 0, children of index
-- @i@ at @2i+1@ and @2i+2@), which is what 'cachedBinarySearch' walks.
data TypeRepMap (f :: k -> Type) =
  TypeRepMap
    { fingerprintAs :: {-# UNPACK #-} !(PrimArray Word64)
      -- ^ First word of each key's 'Fingerprint'.
    , fingerprintBs :: {-# UNPACK #-} !(PrimArray Word64)
      -- ^ Second word of each key's 'Fingerprint'.
    , trAnys :: {-# UNPACK #-} !(Array Any)
      -- ^ Values: each an @f a@ coerced to 'Any'.
    , trKeys :: {-# UNPACK #-} !(Array Any)
      -- ^ Keys: each a @TypeRep a@ coerced to 'Any'.
    }
-- | Deep evaluation forces the key 'TypeRep's only. The stored values are
-- left untouched: the instance carries no @NFData@ constraint for them.
instance NFData (TypeRepMap f) where
  rnf = rnf . keys
-- | Renders only the keys, in the map's internal storage order, as
-- @TypeRepMap [k1, k2, ...]@. Values are not shown because @f a@ is not
-- required to have a 'Show' instance.
instance Show (TypeRepMap f) where
  show TypeRepMap{..} = concat ["TypeRepMap [", renderedKeys, "]"]
    where
      -- Each stored key shown as a 'TypeRep', joined with commas.
      renderedKeys :: String
      renderedKeys = intercalate ", " (toList (mapArray' (show . anyToTypeRep) trKeys))
-- | Combining two maps is 'union': on a shared key the left map's value wins.
instance Semigroup (TypeRepMap f) where
  (<>) :: TypeRepMap f -> TypeRepMap f -> TypeRepMap f
  l <> r = union l r
  {-# INLINE (<>) #-}
-- | The identity is the map with no entries (all four arrays empty).
instance Monoid (TypeRepMap f) where
  mempty = TypeRepMap
    { fingerprintAs = mempty
    , fingerprintBs = mempty
    , trAnys = mempty
    , trKeys = mempty
    }
  {-# INLINE mempty #-}

  mappend = (<>)
  {-# INLINE mappend #-}
#if __GLASGOW_HASKELL__ >= 806
-- | Two maps are equal when they have the same size and, slot by slot in
-- their internal arrays, hold the same key type and equal values. Values are
-- compared with the @Eq (f a)@ instance supplied by the quantified
-- constraint.
--
-- NOTE(review): the slot-by-slot walk assumes two maps with the same key set
-- store their entries in the same order. Maps built through 'fromTriples'
-- (sort by fingerprint, then a fixed permutation) appear to guarantee this.
-- Confirm that no other construction path breaks it.
instance (forall a. Typeable a => Eq (f a)) => Eq (TypeRepMap f) where
  tm1 == tm2 = size tm1 == size tm2 && go 0
    where
      -- Compare slot @i@ and then the remaining slots; stop at the first
      -- difference.
      go :: Int -> Bool
      go i
        | i == size tm1 = True
        | otherwise = case testEquality tr1i tr2i of
            -- The two maps hold different key types in this slot.
            Nothing -> False
            -- Same key type: compare the two stored values as @f x@.
            Just Refl -> repEq tr1i (fromAny tv1i) (fromAny tv2i) && go (i + 1)
        where
          -- Keys in slot @i@, coerced back from 'Any'. The signatures leave
          -- the type index open; what matters is the runtime 'TypeRep' value,
          -- which 'testEquality' compares and 'withTypeable' uses as evidence.
          tr1i :: TypeRep x
          tr1i = anyToTypeRep $ indexArray (trKeys tm1) i
          tr2i :: TypeRep y
          tr2i = anyToTypeRep $ indexArray (trKeys tm2) i
          -- Values in slot @i@, still type-erased.
          tv1i, tv2i :: Any
          tv1i = indexArray (trAnys tm1) i
          tv2i = indexArray (trAnys tm2) i
          -- Use the key's 'TypeRep' as 'Typeable' evidence so the quantified
          -- constraint yields an @Eq (f x)@ instance.
          repEq :: TypeRep x -> f x -> f x -> Bool
          repEq tr = withTypeable tr (==)
#endif
-- | Reassemble each entry's full 'Fingerprint' from the two word arrays,
-- in internal storage order.
toFingerprints :: TypeRepMap f -> [Fingerprint]
toFingerprints tm =
  [ Fingerprint hi lo
  | (hi, lo) <- zip (GHC.toList (fingerprintAs tm)) (GHC.toList (fingerprintBs tm))
  ]
-- | The map with no entries.
empty :: TypeRepMap f
empty = TypeRepMap mempty mempty mempty mempty
{-# INLINE empty #-}
-- | A map with exactly one entry: the given value, keyed by its type @a@.
one :: forall a f . Typeable a => f a -> TypeRepMap f
one x = insert @a x mempty
{-# INLINE one #-}
-- | Insert a value under its type @a@, replacing any value already stored
-- for @a@. The whole map is rebuilt from its triples via 'fromTriples'.
insert :: forall a f . Typeable a => f a -> TypeRepMap f -> TypeRepMap f
insert x tm = fromTriples (newTriple : deleteByFst newFp (toTriples tm))
  where
    newFp :: Fingerprint
    newFp = calcFp @a

    -- The value and its key 'TypeRep' are both stored type-erased.
    newTriple :: (Fingerprint, Any, Any)
    newTriple = (newFp, toAny x, unsafeCoerce (typeRep @a))
{-# INLINE insert #-}
-- | The kind of a type. 'delete' and 'member' use it in their signatures to
-- tie the kind of @f@'s argument to the kind of @a@, which appears only in
-- the 'Typeable' constraint.
type KindOf (a :: k) = k
-- | Remove the entry for type @a@, if there is one. The map is rebuilt from
-- the remaining triples.
delete :: forall a (f :: KindOf a -> Type) . Typeable a => TypeRepMap f -> TypeRepMap f
delete tm = fromTriples [ t | t <- toTriples tm, fst3 t /= target ]
  where
    target :: Fingerprint
    target = typeFp @a
{-# INLINE delete #-}
-- | Apply a function to the value stored for type @a@. If @a@ has no entry,
-- the map is returned unchanged. Otherwise the value array is copied and the
-- one slot replaced; the fingerprint and key arrays are reused.
adjust :: forall a f . Typeable a => (f a -> f a) -> TypeRepMap f -> TypeRepMap f
adjust fun tr = maybe tr updateAt slot
  where
    -- Position of @a@'s entry in the arrays, if present.
    slot :: Maybe Int
    slot = cachedBinarySearch (typeFp @a) (fingerprintAs tr) (fingerprintBs tr)

    -- The same map with only the value array swapped out.
    updateAt :: Int -> TypeRepMap f
    updateAt i = tr { trAnys = replaceAt i (trAnys tr) }

    -- Copy the value array and rewrite the element at index @i@ with @fun@.
    replaceAt :: Int -> Array Any -> Array Any
    replaceAt i vals = runST $ do
      copy <- thawArray vals 0 (sizeofArray vals)
      old <- readArray copy i
      writeArray copy i (toAny (fun (fromAny old)))
      unsafeFreezeArray copy
{-# INLINE adjust #-}
-- | Apply a natural transformation to every stored value. Fingerprints and
-- keys are reused as they are, so the layout of the map does not change.
hoist :: (forall x. f x -> g x) -> TypeRepMap f -> TypeRepMap g
hoist nat (TypeRepMap fpA fpB vals ks) =
  TypeRepMap fpA fpB (mapArray' (toAny . nat . fromAny) vals) ks
{-# INLINE hoist #-}
-- | Effectful 'hoist'. The transformation runs on each value in internal
-- storage order, and the map is rebuilt from the results with the original
-- fingerprints and keys.
hoistA :: (Applicative t) => (forall x. f x -> t (g x)) -> TypeRepMap f -> t (TypeRepMap g)
hoistA nat (TypeRepMap fpA fpB vals ks) =
  rebuild <$> traverse (nat . fromAny) (toList vals)
  where
    -- Re-erase the transformed values and put the map back together.
    rebuild converted = TypeRepMap fpA fpB (fromList (map toAny converted)) ks
{-# INLINE hoistA #-}
-- | Like 'hoist', but the transformation may use a 'Typeable' constraint on
-- each entry's key type. The evidence for each entry comes from the key
-- 'TypeRep' stored at the same index, via 'withTypeable'. Fingerprints and
-- keys are reused unchanged.
hoistWithKey :: forall f g. (forall x. Typeable x => f x -> g x) -> TypeRepMap f -> TypeRepMap g
hoistWithKey f (TypeRepMap as bs ans ks) = TypeRepMap as bs newAns ks
  where
    -- Pair every value with the key at the same index, then transform it.
    newAns = mapArray' mapAns (mzip ans ks)
    -- The key is stored as 'Any'; coerce it back to a 'TypeRep' for the
    -- value's (hidden) type before applying the transformation.
    mapAns (a, k) = toAny $ withTr (unsafeCoerce k) $ fromAny a
    -- Bring 'Typeable' @x@ into scope from the 'TypeRep' and call @f@.
    withTr :: forall x. TypeRep x -> f x -> g x
    withTr t = withTypeable t f
{-# INLINE hoistWithKey #-}
-- | Union of two maps. For a key present in both, the two values are merged
-- with the supplied function, left map's value first. The 'Typeable'
-- evidence for that call comes from the left map's stored key 'TypeRep'.
-- The result is rebuilt with 'fromTriples'.
unionWith :: forall f. (forall x. Typeable x => f x -> f x -> f x) -> TypeRepMap f -> TypeRepMap f -> TypeRepMap f
unionWith f m1 m2 = fromTriples
  $ toTripleList
  $ Map.unionWith combine
      (fromTripleList $ toTriples m1)
      (fromTripleList $ toTriples m2)
  where
    -- Run @f@ with the 'Typeable' evidence recovered from a 'TypeRep'.
    f' :: forall x. TypeRep x -> f x -> f x -> f x
    f' tr = withTypeable tr f
    -- Merge two (value, key) pairs that share a fingerprint. 'Map.unionWith'
    -- passes the left map's pair first; the left key is kept.
    combine :: (Any, Any) -> (Any, Any) -> (Any, Any)
    combine (av, ak) (bv, _) = (toAny $ f' (fromAny ak) (fromAny av) (fromAny bv), ak)
    -- Index the triples by their first component (the fingerprint).
    fromTripleList :: Ord a => [(a, b, c)] -> Map.Map a (b, c)
    fromTripleList = Map.fromList . map (\(a, b, c) -> (a, (b, c)))
    -- Flatten back into triples, in ascending key order.
    toTripleList :: Map.Map a (b, c) -> [(a, b, c)]
    toTripleList = map (\(a, (b, c)) -> (a, b, c)) . Map.toList
{-# INLINE unionWith #-}
-- | Left-biased union: when both maps have an entry for the same key, the
-- value from the first map is kept.
union :: TypeRepMap f -> TypeRepMap f -> TypeRepMap f
union l r = unionWith (\kept _ -> kept) l r
{-# INLINE union #-}
-- | Whether the map has an entry for type @a@.
member :: forall a (f :: KindOf a -> Type) . Typeable a => TypeRepMap f -> Bool
member tm = maybe False (const True) (lookup @a tm)
{-# INLINE member #-}
-- | The value stored for type @a@, if any. Finds the slot with
-- 'cachedBinarySearch' on @a@'s fingerprint and coerces the value back.
lookup :: forall a f . Typeable a => TypeRepMap f -> Maybe (f a)
lookup tVect =
  case cachedBinarySearch (typeFp @a) (fingerprintAs tVect) (fingerprintBs tVect) of
    Nothing -> Nothing
    Just i -> Just (fromAny (indexArray (trAnys tVect) i))
{-# INLINE lookup #-}
-- | Number of entries, read from the length of the first fingerprint array.
size :: TypeRepMap f -> Int
size tm = sizeofPrimArray (fingerprintAs tm)
{-# INLINE size #-}
-- | Every key of the map as a 'SomeTypeRep', in internal storage order.
keys :: TypeRepMap f -> [SomeTypeRep]
keys tm = [ SomeTypeRep (anyToTypeRep k) | k <- toList (trKeys tm) ]
{-# INLINE keys #-}
-- | Find the index of a fingerprint in the two fingerprint arrays, or
-- 'Nothing' if it is absent.
--
-- The arrays form an implicit binary search tree: the root is at index 0
-- and the children of index @i@ are at @2i+1@ (smaller entries) and @2i+2@
-- (larger entries). Fingerprints are compared lexicographically: first the
-- word from @fpAs@, then, on a tie, the word from @fpBs@. Comparisons use
-- the unsigned word primops, and the loop index is an unboxed 'Int#'.
cachedBinarySearch :: Fingerprint -> PrimArray Word64 -> PrimArray Word64 -> Maybe Int
cachedBinarySearch (Fingerprint (W64# a) (W64# b)) fpAs fpBs = inline (go 0#)
  where
    go :: Int# -> Maybe Int
    go i = case i <# len of
      -- Walked past the last slot: not present.
      0# -> Nothing
      _ -> let !(W64# valA) = indexPrimArray fpAs (I# i) in case a `ltWord#` valA of
        -- a >= valA
        0# -> case a `eqWord#` valA of
          -- a > valA: descend to the right child.
          0# -> go (2# *# i +# 2#)
          -- a == valA: decide on the second word.
          _ -> let !(W64# valB) = indexPrimArray fpBs (I# i) in case b `eqWord#` valB of
            0# -> case b `ltWord#` valB of
              -- b > valB: descend to the right child.
              0# -> go (2# *# i +# 2#)
              -- b < valB: descend to the left child.
              _ -> go (2# *# i +# 1#)
            -- Both words match: found.
            _ -> Just (I# i)
        -- a < valA: descend to the left child.
        _ -> go (2# *# i +# 1#)
    -- Number of slots, unboxed.
    len :: Int#
    len = let !(I# l) = sizeofPrimArray fpAs in l
{-# INLINE cachedBinarySearch #-}
-- | Erase a value's type so it can be stored in an @'Array' 'Any'@.
--
-- Safety: this is only sound when undone by 'fromAny' at the original type.
-- The module keeps every value at the same index as the 'TypeRep' of its
-- type, and values are coerced back only after that key has been matched.
toAny :: f a -> Any
toAny value = unsafeCoerce value
-- | Recover a value previously erased with 'toAny'.
--
-- Safety: the caller must request exactly the type the value had before
-- erasure. Within this module that type is fixed by the key 'TypeRep' stored
-- at the same index.
fromAny :: Any -> f a
fromAny erased = unsafeCoerce erased
-- | Recover a key 'TypeRep' stored as 'Any'.
--
-- Safety: the 'Any' must have been produced by coercing a 'TypeRep'. The
-- type index of the result is chosen by the caller and is not checked here.
anyToTypeRep :: Any -> TypeRep f
anyToTypeRep erased = unsafeCoerce erased
-- | The 'Fingerprint' of type @a@'s 'TypeRep' (the same value as 'calcFp').
typeFp :: forall a . Typeable a => Fingerprint
typeFp = typeRepFingerprint (typeRep :: TypeRep a)
{-# INLINE typeFp #-}
-- | Flatten a map into @(fingerprint, value, key)@ triples in internal
-- storage order. Value and key stay in their 'Any'-coerced form.
toTriples :: TypeRepMap f -> [(Fingerprint, Any, Any)]
toTriples tm = zip3 fps vals ks
  where
    fps = toFingerprints tm
    vals = GHC.toList (trAnys tm)
    ks = GHC.toList (trKeys tm)
-- | Remove every triple whose first component equals the given value. The
-- surviving triples keep their order.
deleteByFst :: Eq a => a -> [(a, b, c)] -> [(a, b, c)]
deleteByFst x = filter (\ ~(k, _, _) -> k /= x)
-- | Keep only the first triple for each distinct first component, in the
-- original order. Like 'nubBy', this is quadratic in the list length.
nubByFst :: (Eq a) => [(a, b, c)] -> [(a, b, c)]
nubByFst = nubBy (\ ~(x, _, _) ~(y, _, _) -> x == y)
-- | First component of a triple.
fst3 :: (a, b, c) -> a
fst3 triple = case triple of
  (first, _, _) -> first
-- | An @f a@ packed together with 'Typeable' evidence for @a@, with @a@
-- hidden. This is the element type ('Item') of the 'IsList' instance for
-- 'TypeRepMap'.
data WrapTypeable f where
  WrapTypeable :: Typeable a => f a -> WrapTypeable f
-- | Shows the 'Fingerprint' of the hidden type, not the wrapped value.
instance Show (WrapTypeable f) where
  show wrapped = case wrapped of
    WrapTypeable (_ :: f a) -> show (calcFp @a)
-- | Pack a value into 'WrapTypeable', taking the 'Typeable' evidence from an
-- explicit 'TypeRep' of its type.
wrapTypeable :: TypeRep a -> f a -> WrapTypeable f
wrapTypeable tr x = withTypeable tr (WrapTypeable x)
-- | Build a map from, or list a map as, wrapped values.
--
-- In 'fromList', when several elements share a type, the first one in the
-- list is kept (that is how 'fromTriples' resolves duplicates). 'toList'
-- yields the entries in internal storage order.
instance IsList (TypeRepMap f) where
  type Item (TypeRepMap f) = WrapTypeable f

  fromList :: [WrapTypeable f] -> TypeRepMap f
  fromList = fromTriples . map toTriple
    where
      -- Fingerprint of the hidden type, the erased value, and its erased key.
      toTriple :: WrapTypeable f -> (Fingerprint, Any, Any)
      toTriple (WrapTypeable (x :: f a)) = (calcFp @a, toAny x, unsafeCoerce (typeRep @a))

  toList :: TypeRepMap f -> [WrapTypeable f]
  toList tm = [ wrapTypeable (unsafeCoerce k) (fromAny v) | (_, v, k) <- toTriples tm ]
-- | The 'Fingerprint' of type @a@'s 'TypeRep' (the same value as 'typeFp').
calcFp :: forall a . Typeable a => Fingerprint
calcFp = let rep = typeRep @a in typeRepFingerprint rep
-- | Build a 'TypeRepMap' from @(fingerprint, value, key)@ triples, where the
-- value is an @f a@ and the key a @TypeRep a@, both coerced to 'Any'.
--
-- If several triples share a fingerprint, the one that appears first in the
-- input is kept. The survivors, in ascending fingerprint order, are permuted
-- by 'fromSortedList' into the layout searched by 'cachedBinarySearch'.
--
-- Duplicates are dropped after sorting, where equal fingerprints sit next to
-- each other, instead of with a 'nubBy' pass beforehand. This keeps
-- construction at O(n log n) rather than O(n^2). 'sortWith' is a stable
-- sort, so the first triple of each run of equal fingerprints is still the
-- earliest one in the input.
fromTriples :: [(Fingerprint, Any, Any)] -> TypeRepMap f
fromTriples kvs =
  TypeRepMap (GHC.fromList fpAs) (GHC.fromList fpBs) (GHC.fromList ans) (GHC.fromList ks)
  where
    (fpAs, fpBs) = unzip $ map (\(Fingerprint a b) -> (a, b)) fps
    (fps, ans, ks) = unzip3 $ fromSortedList $ firstOfRuns $ sortWith fst3 kvs

    -- Keep only the first triple of every run of equal fingerprints; the
    -- input is sorted by fingerprint, so all duplicates are adjacent.
    firstOfRuns :: [(Fingerprint, Any, Any)] -> [(Fingerprint, Any, Any)]
    firstOfRuns [] = []
    firstOfRuns (t : rest) = t : firstOfRuns (dropWhile ((== fst3 t) . fst3) rest)
-- | Permute an ascending list into the array layout that
-- 'cachedBinarySearch' expects: an implicit binary search tree with the root
-- at index 0 and the children of index @i@ at @2i+1@ and @2i+2@.
--
-- An in-order walk of that implicit tree visits its slots in ascending
-- order. The walk below therefore places the input elements one after
-- another in list order. For a sorted input, each slot's left subtree ends
-- up holding smaller elements and its right subtree larger ones.
fromSortedList :: forall a . [a] -> [a]
fromSortedList l = runST $ do
  let n = length l
  let arrOrigin = fromListN n l
  -- A mutable copy of the right size; 'go' overwrites every slot.
  arrResult <- thawArray arrOrigin 0 n
  go n arrResult arrOrigin
  toList <$> unsafeFreezeArray arrResult
  where
    go :: forall s . Int -> MutableArray s a -> Array a -> ST s ()
    go len result origin = () <$ loop 0 0
      where
        -- In-order walk of the subtree rooted at slot @i@. @first@ is the
        -- index in @origin@ of the next element to place. Returns the index
        -- just after the last element placed in this subtree.
        loop :: Int -> Int -> ST s Int
        loop i first =
          if i >= len
            then pure first
            else do
              -- Fill the left subtree first (smaller elements) ...
              newFirst <- loop (2 * i + 1) first
              -- ... then this slot ...
              writeArray result i (indexArray origin newFirst)
              -- ... then the right subtree (larger elements).
              loop (2 * i + 2) (newFirst + 1)
-- | Check the structural invariant that 'cachedBinarySearch' relies on.
--
-- The fingerprint arrays must form an implicit binary search tree (root at
-- index 0, children of index @i@ at @2i+1@ and @2i+2@), ordered by the full
-- fingerprint. The full fingerprint is the pair (word in 'fingerprintAs',
-- word in 'fingerprintBs'), compared lexicographically.
--
-- For every node this verifies that
--
-- * the left child is not greater than the node and the right child is not
--   smaller, and
-- * the largest entry of the left subtree is not greater than the node, and
--   the smallest entry of the right subtree is not smaller than the node.
--
-- Since every node is checked, every subtree is itself a search tree. A
-- subtree's largest entry is therefore its rightmost node and its smallest
-- entry its leftmost node, which is how those extremes are found below.
--
-- Checking only the children plus @leftMax <= rightMin@ is not enough. The
-- array @[5, 3, 8, 1, 6, 7]@ passes that weaker test, yet its in-order
-- sequence is @1, 3, 6, 5, 7, 8@: 6 sits in the left subtree of 5.
invariantCheck :: TypeRepMap f -> Bool
invariantCheck TypeRepMap{..} = getAll (foldMap checkNode [0 .. sz - 1])
  where
    sz :: Int
    sz = sizeofPrimArray fingerprintAs

    -- Full fingerprint at an index; tuple ordering is the lexicographic
    -- comparison used by 'cachedBinarySearch'.
    fpAt :: Int -> (Word64, Word64)
    fpAt j = (indexPrimArray fingerprintAs j, indexPrimArray fingerprintBs j)

    -- Last element of a list, if there is one.
    lastMay :: [a] -> Maybe a
    lastMay [] = Nothing
    lastMay [x] = Just x
    lastMay (_:xs) = lastMay xs

    -- Last in-bounds index reached from @start@ by repeatedly applying @step@.
    descend :: (Int -> Int) -> Int -> Maybe Int
    descend step start = lastMay (takeWhile (< sz) (iterate step start))

    checkNode :: Int -> All
    checkNode i =
      let cur = fpAt i
          left = 2 * i + 1
          right = 2 * i + 2
          -- Rightmost node of the left subtree, leftmost node of the right.
          leftMax = fpAt <$> descend (\j -> 2 * j + 2) left
          rightMin = fpAt <$> descend (\j -> 2 * j + 1) right
      in foldMap All
           [ left >= sz || fpAt left <= cur
           , right >= sz || cur <= fpAt right
           , maybe True (<= cur) leftMax
           , maybe True (cur <=) rightMin
           ]