{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Data.Diagram.ZeroSup
-- Copyright   : (c) Eddie Jones 2020
-- License     : BSD-3
-- Maintainer  : eddiejones2108@gmail.com
-- Stability   : experimental
--
-- Zero-suppressed Binary Decision Diagrams
module Data.Diagram.ZeroSup
  ( -- * Diagram
    Diagram,
    runDiagram,

    -- * Families of Sets
    Family,
    mkFamily,
    empty,
    base,
    change,
    subset,
    bindElem,

    -- * Combinations
    intersect,
    union,
    difference,

    -- * Summary
    fold,
    anySat,
  )
where

import Control.Monad
import Control.Monad.State
import qualified Data.Diagram as D
import Data.Functor.Classes
import Data.Functor.Identity
import Data.Hashable
import Data.Hashable.Lifted
import qualified Data.Map as M

-- | Cache keys for the memoised operations
data Memo l s
  = Intersect (Family l s) (Family l s)
  | Union (Family l s) (Family l s)
  | Difference (Family l s) (Family l s)
  | Subset l Bool (Family l s)
  | Change l (Family l s)
  deriving (Eq, Ord)

-- | Run an operation unless its result is already in the cache
memo :: Ord l => Memo l s -> Diagram l s (Family l s) -> Diagram l s (Family l s)
memo m d = do
  cache <- Diagram $ lift get
  case M.lookup m cache of
    Just r -> return r
    Nothing -> do
      r <- d
      -- Insert into the current cache rather than the snapshot read above,
      -- so that entries added by recursive calls inside d are kept.
      Diagram $ lift $ modify (M.insert m r)
      return r

-- | A computation over a shared zero-suppressed decision diagram
newtype Diagram l s a = Diagram
  { unDiag :: D.Diagram (Node l) Bool s (State (M.Map (Memo l s) (Family l s))) a
  }
  deriving (Functor, Applicative, Monad)

-- | Extract non-diagrammatic information
runDiagram :: (forall s.
  Diagram l s a) -> a
runDiagram d =
  runIdentity $ D.runDiagram ((`evalState` M.empty) <$> D.compress (unDiag d))

-- | A diagrammatic family of sets built on atomic elements of type @l@
newtype Family l s = Family {unFamily :: D.Free (Node l) Bool s}
  deriving (Eq, Ord)

-- | An internal node of the diagram
data Node l k = Node {label :: l, lo :: k, hi :: k}
  deriving (Functor, Foldable, Traversable)

instance Eq l => Eq1 (Node l) where
  liftEq eq n m =
    label n == label m && eq (lo n) (lo m) && eq (hi n) (hi m)

instance Hashable l => Hashable1 (Node l) where
  liftHashWithSalt l s n = hashWithSalt s (label n, l s (lo n), l s (hi n))

-- | Make a family (if not already present) from its lo and hi cases.
-- Following the zero-suppression rule, no node is created when the hi
-- case is 'empty'; the lo case is returned instead.
mkFamily ::
  (Eq l, Hashable l) =>
  l ->
  Family l s ->
  Family l s ->
  Diagram l s (Family l s)
mkFamily l lo hi
  | hi == empty = return lo
  | otherwise =
    Diagram (Family <$> D.free Node {label = l, lo = unFamily lo, hi = unFamily hi})

-- | Map elements in a family. The function is assumed to preserve the
-- ordering of elements; otherwise the result is not a well-ordered diagram.
mapAtom :: (Eq l, Hashable l) => (l -> l) -> Family l s -> Diagram l s (Family l s)
mapAtom f (Family p) =
  Diagram $
    D.fold
      -- The fold has already mapped the children, so they are reused
      -- directly instead of being mapped a second time.
      (\n -> unDiag $ mkFamily (f $ label n) (lo n) (hi n))
      (return . Family . D.Pure)
      p

-- | Replace an element with a family of sets
bindElem ::
  (Ord l, Hashable l) =>
  Family l s ->
  (l -> Diagram l s (Family l s)) ->
  Diagram l s (Family l s)
bindElem p f =
  Diagram $
    D.fold
      ( \n -> unDiag $ do
          a <- f (label n)
          b <- a `intersect` hi n
          b `union` lo n
      )
      (return . Family . D.Pure)
      (unFamily p)

-- | Simple families: 'empty' contains no sets, while 'base' contains only
-- the empty set
empty, base :: Family l s
empty = Family $ D.Pure False
base = Family $ D.Pure True

-- | Subsets that do (@True@) or do not (@False@) contain a particular
-- element. In the @True@ case the element itself is removed from the
-- selected sets.
subset :: (Ord l, Hashable l) => l -> Bool -> Family l s -> Diagram l s (Family l s)
subset var b p = memo (Subset var b p) $
  Diagram $
    D.fromFree
      (unFamily p)
      ( \n -> unDiag $ case compare (label n) var of
          -- Labels only decrease towards the terminals, so var is absent below n
          LT -> return $ if b then empty else p
          EQ -> return $ Family $ if b then hi n else lo n
          GT -> do
            lo' <- subset var b (Family $ lo n)
            hi' <- subset var b (Family $ hi n)
            mkFamily (label n) lo' hi'
      )
      -- A terminal contains no elements, so none of its sets contain var
      (\_ -> return $ if b then empty else p)

-- | Flip an element in every set of a family: add it to the sets that lack
-- it and remove it from the sets that contain it
change :: (Ord l, Hashable l) => l -> Family l s -> Diagram l s (Family l s)
change var p = memo (Change var p) $
  Diagram $
    D.fromFree
      (unFamily p)
      ( \n -> unDiag $ case compare (label n) var of
          -- var belongs above n, so no set contains it yet: add it to all
          LT -> mkFamily var empty p
          EQ -> mkFamily var (Family $ hi n) (Family $ lo n)
          GT -> do
            lo' <- change var (Family $ lo n)
            hi' <- change var (Family $ hi n)
            mkFamily (label n) lo' hi'
      )
      -- A terminal contains no elements, so var is added to every set
      (\_ -> unDiag $ mkFamily var empty p)

-- | Set the terminal reached by following only lo edges, i.e. decide
-- whether the empty set is a member (@True@) or not (@False@)
setLeftMost :: (Eq l, Hashable l) => Bool -> Family l s -> Diagram l s (Family l s)
setLeftMost b p =
  Diagram $
    D.fromFree
      (unFamily p)
      ( \n -> unDiag $ do
          lo' <- setLeftMost b (Family $ lo n)
          mkFamily (label n) lo' (Family $ hi n)
      )
      (\_ -> return $ Family $ D.Pure b)

-- | Flip the terminal reached by following only lo edges, toggling whether
-- the empty set is a member
flipLeftMost :: (Eq l, Hashable l) => Family l s -> Diagram l s (Family l s)
flipLeftMost p =
  Diagram $
    D.fromFree
      (unFamily p)
      ( \n -> unDiag $ do
          lo' <- flipLeftMost (Family $ lo n)
          mkFamily (label n) lo' (Family $ hi n)
      )
      (return . Family . D.Pure . not)

-- | The terminal reached by following only lo edges, i.e. 'base' if the
-- family contains the empty set and 'empty' otherwise
getLeftMost :: Family l s -> Diagram l s (Family l s)
getLeftMost p =
  Diagram $
    D.fromFree
      (unFamily p)
      (unDiag . getLeftMost . lo . fmap Family)
      (return . Family .
        D.Pure)

-- | The union of families
union :: (Ord l, Hashable l) => Family l s -> Family l s -> Diagram l s (Family l s)
union p q | p == q = return p
union p q | p == empty = return q
union p q | q == empty = return p
union p q | p == base = setLeftMost True q
union p q | q == base = setLeftMost True p
union p q = memo (Union p q) $
  Diagram $
    D.fromFree
      (unFamily p)
      ( \n ->
          D.fromFree
            (unFamily q)
            ( \m ->
                unDiag $ case compare (label n) (label m) of
                  LT -> do
                    lo' <- p `union` Family (lo m)
                    mkFamily (label m) lo' (Family $ hi m)
                  EQ -> do
                    lo' <- Family (lo n) `union` Family (lo m)
                    hi' <- Family (hi n) `union` Family (hi m)
                    mkFamily (label n) lo' hi'
                  GT -> do
                    lo' <- Family (lo n) `union` q
                    mkFamily (label n) lo' (Family $ hi n)
            )
            (\_ -> error "Unreachable!")
      )
      (\_ -> error "Unreachable!")

-- | The intersection of families
intersect :: (Ord l, Hashable l) => Family l s -> Family l s -> Diagram l s (Family l s)
intersect p q | p == q = return p
intersect p q | p == empty = return empty
intersect p q | q == empty = return empty
intersect p q | p == base = getLeftMost q
intersect p q | q == base = getLeftMost p
intersect p q = memo (Intersect p q) $
  Diagram $
    D.fromFree
      (unFamily p)
      ( \n ->
          D.fromFree
            (unFamily q)
            ( \m ->
                unDiag $ case compare (label n) (label m) of
                  LT -> p `intersect` Family (lo m)
                  EQ -> do
                    lo' <- Family (lo n) `intersect` Family (lo m)
                    hi' <- Family (hi n) `intersect` Family (hi m)
                    mkFamily (label n) lo' hi'
                  GT -> Family (lo n) `intersect` q
            )
            (\_ -> error "Unreachable!")
      )
      (\_ -> error "Unreachable!")

-- | The difference between families
difference :: (Ord l, Hashable l) => Family l s -> Family l s -> Diagram l s (Family l s)
difference p q | p == q = return empty
difference p q | p == empty = return empty
difference p q | q == empty = return p
-- base minus q keeps the empty set exactly when q does not contain it
difference p q | p == base = getLeftMost q >>= flipLeftMost
difference p q | q == base = setLeftMost False p
difference p q = memo (Difference p q) $
  Diagram $
    D.fromFree
      (unFamily p)
      ( \n ->
          D.fromFree
            (unFamily q)
            ( \m ->
                unDiag $ case compare (label n) (label m) of
                  LT -> p `difference` Family (lo m)
                  EQ -> do
                    lo' <- Family (lo n) `difference` Family (lo m)
                    hi' <- Family (hi n) `difference` Family (hi m)
                    mkFamily (label n) lo' hi'
                  GT -> do
                    lo' <- Family (lo n) `difference` q
                    mkFamily (label n) lo' (Family $ hi n)
            )
            (\_ -> error "Unreachable!")
      )
      (\_ -> error "Unreachable!")

-- | Determine whether the family is non-empty, i.e. whether some path
-- reaches an accepting terminal
--
-- > anySat = fold (\_ p q -> return (p || q)) return
anySat :: (Hashable l, Eq l) => Family l s -> Diagram l s Bool
anySat = fold (\_ p q -> return (p || q)) return

-- | Create a summary value of a family. The node function receives an
-- element label followed by the summaries of the lo and hi cases; the
-- terminal function receives the value of a terminal.
fold ::
  (Hashable l, Eq l) =>
  (l -> b -> b -> Diagram l s b) ->
  (Bool -> Diagram l s b) ->
  Family l s ->
  Diagram l s b
fold f g (Family p) =
  Diagram $ D.fold (\n -> unDiag $ f (label n) (lo n) (hi n)) (unDiag . g) p
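
-- Example (an illustrative sketch, not part of the module's API): build the
-- family {{1}, {2}} from 'base' using 'change' and 'union', then check with
-- 'anySat' that it is non-empty. The name example and the element type Int
-- are arbitrary choices, and the sketch assumes 'empty' denotes the family
-- with no sets and 'base' the family {{}} containing only the empty set.
--
-- > example :: Bool
-- > example = runDiagram $ do
-- >   a <- change (1 :: Int) base -- {{1}}
-- >   b <- change 2 base          -- {{2}}
-- >   ab <- a `union` b           -- {{1}, {2}}
-- >   anySat ab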