{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Hedgehog.Internal.Tree (
Tree
, pattern Tree
, TreeT(..)
, runTree
, mapTreeT
, treeValue
, treeChildren
, Node
, pattern Node
, NodeT(..)
, fromNodeT
, unfold
, unfoldForest
, expand
, prune
, catMaybes
, filter
, filterMaybeT
, filterT
, depth
, interleave
, render
, renderT
) where
import Control.Applicative (Alternative(..), liftA2)
import Control.Monad (MonadPlus(..), join)
import Control.Monad.Base (MonadBase(..))
import Control.Monad.Trans.Control ()
import Control.Monad.Catch (MonadThrow(..), MonadCatch(..), Exception)
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Morph (MFunctor(..), MMonad(..), generalize)
import Control.Monad.Primitive (PrimMonad(..))
import Control.Monad.Reader.Class (MonadReader(..))
import Control.Monad.State.Class (MonadState(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Maybe (MaybeT(..))
import Control.Monad.Trans.Resource (MonadResource(..))
import Control.Monad.Writer.Class (MonadWriter(..))
import Control.Monad.Zip (MonadZip(..))
import Data.Functor.Identity (Identity(..))
import Data.Functor.Classes (Eq1(..))
import Data.Functor.Classes (Show1(..), showsPrec1)
import Data.Functor.Classes (showsUnaryWith, showsBinaryWith)
import qualified Data.Maybe as Maybe
import Hedgehog.Internal.Distributive
import Prelude hiding (filter)
-- | A tree whose node-producing actions are pure: a 'TreeT' over 'Identity'.
type Tree =
  TreeT Identity

-- | Build or deconstruct a pure 'Tree' directly from its root 'NodeT',
--   without mentioning the 'Identity' wrapper.
pattern Tree :: NodeT Identity a -> Tree a
pattern Tree root <- TreeT (Identity root)
  where
    Tree root = TreeT (Identity root)

{-# COMPLETE Tree #-}
-- | A rose tree in which the root node is obtained by running an action in
--   @m@. The node's children are themselves 'TreeT's, so each of them has
--   its own action that must be run to reach its node.
newtype TreeT m a = TreeT
  { runTreeT :: m (NodeT m a)
    -- ^ The action producing the root node.
  }
-- | A node of a pure 'Tree'.
type Node =
  NodeT Identity

-- | Build or deconstruct a pure 'Node' from its value and child trees.
pattern Node :: a -> [Tree a] -> Node a
pattern Node value children <- NodeT value children
  where
    Node value children = NodeT value children

{-# COMPLETE Node #-}
-- | One node of a 'TreeT': a value plus the subtrees directly beneath it.
data NodeT m a = NodeT
  { nodeValue :: a
    -- ^ The value stored at this node.
  , nodeChildren :: [TreeT m a]
    -- ^ The subtrees hanging directly below this node.
  }
  deriving (Eq)
-- | Expose the root 'Node' of a pure 'Tree'.
runTree :: Tree a -> Node a
runTree (TreeT root) =
  runIdentity root
-- | Apply a transformation to the action that produces a tree's root node.
--   The children inside the resulting node are not touched.
mapTreeT :: (m (NodeT m a) -> m (NodeT m a)) -> TreeT m a -> TreeT m a
mapTreeT f (TreeT rootAction) =
  TreeT (f rootAction)
-- | Turn an already-built node into a tree whose root action just returns it.
fromNodeT :: Applicative m => NodeT m a -> TreeT m a
fromNodeT node =
  TreeT (pure node)
-- | The value at the root of a pure 'Tree'.
treeValue :: Tree a -> a
treeValue tree =
  case runTree tree of
    NodeT value _ ->
      value

-- | The immediate subtrees of a pure 'Tree'.
treeChildren :: Tree a -> [Tree a]
treeChildren tree =
  case runTree tree of
    NodeT _ children ->
      children
-- | Grow a tree from a seed. The seed becomes the root value, and @f@ yields
--   the seeds of its children, recursively.
unfold :: Monad m => (a -> [a]) -> a -> TreeT m a
unfold f seed =
  TreeT (pure (NodeT seed (unfoldForest f seed)))

-- | The children 'unfold' would attach beneath @seed@: one unfolded tree for
--   each element of @f seed@.
unfoldForest :: Monad m => (a -> [a]) -> a -> [TreeT m a]
unfoldForest f seed =
  [unfold f child | child <- f seed]
-- | Recursively extend every node of a tree. Each node keeps its existing
--   (recursively expanded) children and gains the children 'unfoldForest'
--   generates from its value, appended after them.
expand :: Monad m => (a -> [a]) -> TreeT m a -> TreeT m a
expand f tree =
  TreeT $
    runTreeT tree >>= \(NodeT value children) ->
      let
        existing =
          map (expand f) children
        generated =
          unfoldForest f value
      in
        pure (NodeT value (existing ++ generated))
-- | Keep at most @n@ levels of children below the root.
--   For @n <= 0@ only the root value survives, with no children.
prune :: Monad m => Int -> TreeT m a -> TreeT m a
prune n tree
  | n <= 0 =
      rebuildWith (const [])
  | otherwise =
      rebuildWith (fmap (prune (n - 1)))
  where
    -- Run the root action and replace its children using @g@.
    rebuildWith g =
      TreeT $ do
        NodeT value children <- runTreeT tree
        pure (NodeT value (g children))
-- | The number of levels in a pure tree. A lone root has depth 1.
depth :: Tree a -> Int
depth tree =
  case runTree tree of
    NodeT _ children ->
      -- The 0 seed stands in for "no children"; every real subtree is >= 1.
      1 + maximum (0 : map depth children)
-- | Remove every 'Nothing' node from a tree of optional values.
--
--   A 'Just' root is kept with its children cleaned recursively. When the
--   root is 'Nothing', the first surviving child is promoted to be the new
--   root. Its own children come first, followed by the remaining surviving
--   siblings. The result is 'Nothing' when no child survives.
catMaybes :: Tree (Maybe a) -> Maybe (Tree a)
catMaybes tree =
  case runTree tree of
    NodeT (Just value) children ->
      Just (Tree (Node value (survivors children)))
    NodeT Nothing children ->
      case survivors children of
        [] ->
          Nothing
        Tree (NodeT value grandchildren) : siblings ->
          Just (Tree (Node value (grandchildren ++ siblings)))
  where
    survivors =
      Maybe.mapMaybe catMaybes
-- | Filter a pure tree by a predicate.
--
--   The tree is lifted into 'MaybeT' and filtered with 'filterMaybeT'.
--   The result is turned back into a tree of optional values with
--   'runTreeMaybeT', and the empty nodes are removed with 'catMaybes'.
filter :: (a -> Bool) -> Tree a -> Maybe (Tree a)
filter p tree =
  catMaybes $ runTreeMaybeT $ filterMaybeT p $ hoist lift tree
-- | Move the 'MaybeT' layer from inside the tree to its values: a node whose
--   action failed appears as a 'Nothing' value.
runTreeMaybeT :: Monad m => TreeT (MaybeT m) a -> TreeT m (Maybe a)
runTreeMaybeT tree =
  runMaybeT (distributeT tree)
-- | Keep only those nodes of a 'MaybeT' tree whose values satisfy @p@.
--
--   The tree fails as a whole (a single 'MaybeT' failure) in two cases:
--   the root action itself failed, or the root value does not satisfy @p@.
--   Otherwise the root is kept. Below it, every node that failed or was
--   rejected is removed, and its surviving descendants take its place.
filterMaybeT :: (a -> Bool) -> TreeT (MaybeT Identity) a -> TreeT (MaybeT Identity) a
filterMaybeT p t =
  case runTreeMaybeT t of
    Tree (Node (Just x) xs) | p x ->
      hoist generalize $
        Tree . Node x $
          concatMap (flattenTree (maybe False p)) xs
    _ ->
      -- The root is missing or rejected by @p@. Nothing can stand in for
      -- it, so the whole tree fails. (Previously a rejected root was kept.)
      TreeT . MaybeT . Identity $ Nothing
-- | Flatten a tree of optional values into the forest of nodes that are
--   present and accepted by @p@.
--
--   A node that is rejected by @p@, or that has no value at all, is removed.
--   Its (already flattened) children are promoted into its place.
flattenTree :: (Maybe a -> Bool) -> Tree (Maybe a) -> [Tree a]
flattenTree p (Tree (Node mx mxs0)) =
  let
    mxs =
      concatMap (flattenTree p) mxs0
  in
    case mx of
      Just x | p mx ->
        [Tree (Node x mxs)]
      _ ->
        -- A missing value can never be kept, even if @p@ accepts 'Nothing'.
        -- Promote its children just like those of a rejected node, instead
        -- of silently discarding the whole subtree.
        mxs
-- | Filter an effectful tree. A node whose value fails @p@ is replaced by
--   'empty' in @m@. Children are filtered lazily, when their own actions
--   are run.
filterT :: (Monad m, Alternative m) => (a -> Bool) -> TreeT m a -> TreeT m a
filterT p tree =
  TreeT $
    runTreeT tree >>= \(NodeT value children) ->
      if not (p value) then
        empty
      else
        pure (NodeT value (map (filterT p) children))
-- | Every way to pick one element out of a list. Each result holds the
--   elements before the pick, the pick itself, and the elements after it,
--   in list order.
splits :: [a] -> [([a], a, [a])]
splits [] =
  []
splits (x : rest) =
  ([], x, rest) :
    [(x : before, focus, after) | (before, focus, after) <- splits rest]
-- | One child per position: the list with that single element removed.
dropOne :: Monad m => [NodeT m a] -> [TreeT m [a]]
dropOne ts =
  [ TreeT (pure (interleave (before ++ after)))
  | (before, _, after) <- splits ts
  ]

-- | One child per (position, child-of-that-element) pair: the list with that
--   element replaced by the node its child tree produces.
shrinkOne :: Monad m => [NodeT m a] -> [TreeT m [a]]
shrinkOne ts =
  [ TreeT $
      runTreeT child >>= \replacement ->
        pure (interleave (before ++ replacement : after))
  | (before, focus, after) <- splits ts
  , child <- nodeChildren focus
  ]

-- | Combine a list of nodes into one node holding the list of their values.
--   Its children are first every 'dropOne' variant, then every 'shrinkOne'
--   variant.
interleave :: forall m a. Monad m => [NodeT m a] -> NodeT m [a]
interleave ts =
  NodeT (map nodeValue ts) (dropOne ts ++ shrinkOne ts)
-- | Fold a pure tree through its root node.
instance Foldable Tree where
  foldMap f tree =
    foldMap f (runTree tree)

-- | Fold the node's value first, then each child subtree in order.
instance Foldable Node where
  foldMap f (NodeT value children) =
    mappend (f value) (foldMap (foldMap f) children)
-- | Traverse a pure tree through its root node.
instance Traversable Tree where
  traverse f (Tree root) =
    Tree <$> traverse f root

-- | Traverse the node's value first, then each child subtree in order.
instance Traversable Node where
  traverse f (NodeT value children) =
    liftA2 NodeT (f value) (traverse (traverse f) children)
-- | Trees are equal when their root actions are equal under 'Eq1' for @m@,
--   comparing the produced nodes structurally.
instance (Eq1 m, Eq a) => Eq (TreeT m a) where
  (==) (TreeT lhs) (TreeT rhs) =
    liftEq (==) lhs rhs
-- | Map over the node's value and, recursively, every child subtree.
instance Functor m => Functor (NodeT m) where
  fmap f (NodeT value children) =
    NodeT (f value) [fmap f child | child <- children]

-- | Map over every node produced by the tree's actions.
instance Functor m => Functor (TreeT m) where
  fmap f (TreeT rootAction) =
    TreeT (fmap (fmap f) rootAction)
-- | 'pure' is a childless node. In @nf \<*\> nx@, the children are:
--   first, each child of the function node applied to the whole argument
--   node; then, the function applied to each child of the argument node.
instance Applicative m => Applicative (NodeT m) where
  pure x =
    NodeT x []
  (<*>) (NodeT f fChildren) argNode@(NodeT x xChildren) =
    let
      viaFunction =
        [fChild <*> fromNodeT argNode | fChild <- fChildren]
      viaArgument =
        [fmap f xChild | xChild <- xChildren]
    in
      NodeT (f x) (viaFunction ++ viaArgument)

-- | Run both root actions with @m@'s applicative and combine the resulting
--   nodes with the 'NodeT' applicative.
instance Applicative m => Applicative (TreeT m) where
  pure x =
    TreeT (pure (pure x))
  TreeT mf <*> TreeT mx =
    TreeT (liftA2 (<*>) mf mx)
-- | Binding a node gives the root of @k x@. Its children are the original
--   node's children, each re-bound with @k@, followed by @k x@'s own
--   children.
instance Monad m => Monad (NodeT m) where
  return =
    pure
  NodeT x children >>= k =
    case k x of
      NodeT y ys ->
        let
          rebind (TreeT action) =
            TreeT (fmap (>>= k) action)
        in
          NodeT y (map rebind children ++ ys)

-- | Run the root action, then the root action of @k@ applied to its value.
--   The resulting children are the original children bound with @k@,
--   followed by @k x@'s own children.
instance Monad m => Monad (TreeT m) where
  return =
    pure
  m >>= k =
    TreeT $
      runTreeT m >>= \(NodeT x xs) ->
        runTreeT (k x) >>= \(NodeT y ys) ->
          pure (NodeT y (map (>>= k) xs ++ ys))
-- | Choice between trees is choice between their root actions in @m@.
instance Alternative m => Alternative (TreeT m) where
  empty =
    TreeT empty
  TreeT lhs <|> TreeT rhs =
    TreeT (lhs <|> rhs)

-- | 'MonadPlus' delegates to @m@'s 'mzero' and 'mplus' on the root actions.
instance MonadPlus m => MonadPlus (TreeT m) where
  mzero =
    TreeT mzero
  mplus (TreeT lhs) (TreeT rhs) =
    TreeT (mplus lhs rhs)
-- | Pair two trees. The root pairs both root values. Its children first
--   step the left tree into each of its children (keeping the whole right
--   tree), then step the right tree (keeping the whole left tree).
zipTreeT :: forall f a b. Applicative f => TreeT f a -> TreeT f b -> TreeT f (a, b)
zipTreeT l0 r0 =
  TreeT (fmap zipNodeT (runTreeT l0) <*> runTreeT r0)
  where
    zipNodeT :: NodeT f a -> NodeT f b -> NodeT f (a, b)
    zipNodeT (NodeT a ls) (NodeT b rs) =
      NodeT (a, b) $
        [zipTreeT l1 r0 | l1 <- ls] ++ [zipTreeT l0 r1 | r1 <- rs]

-- | 'mzip' is 'zipTreeT'.
instance Monad m => MonadZip (TreeT m) where
  mzip lhs rhs =
    zipTreeT lhs rhs
-- | Lifting an action yields a tree whose root is the action's result, with
--   no children.
instance MonadTrans TreeT where
  lift action =
    TreeT (fmap (flip NodeT []) action)

-- | Change the base monad of every subtree below a node.
instance MFunctor NodeT where
  hoist f (NodeT value children) =
    NodeT value [hoist f child | child <- children]

-- | Change the base monad of the root action and, recursively, of all
--   subtrees.
instance MFunctor TreeT where
  hoist f (TreeT rootAction) =
    TreeT (f (fmap (hoist f) rootAction))
-- | Embed every child subtree of a node using 'embedTreeT'.
embedNodeT :: Monad m => (t (NodeT t b) -> TreeT m (NodeT t b)) -> NodeT t b -> NodeT m b
embedNodeT f (NodeT value children) =
  NodeT value [embedTreeT f child | child <- children]

-- | Interpret the root action with @f@, which produces a 'TreeT' of nodes.
--   Then convert each resulting node, recursively embedding its children.
embedTreeT :: Monad m => (t (NodeT t b) -> TreeT m (NodeT t b)) -> TreeT t b -> TreeT m b
embedTreeT f (TreeT rootAction) =
  f rootAction >>= \node ->
    TreeT (pure (embedNodeT f node))

-- | 'embed' is 'embedTreeT' (eta-expanded to accept the rank-2 argument).
instance MMonad TreeT where
  embed f tree =
    embedTreeT f tree
-- | Swap the transformer @t@ outward for a single node. The value becomes
--   the root of a tree over @m@ whose children are the distributed
--   children. That tree is lifted into @t@ and flattened with 'join'.
distributeNodeT :: Transformer t TreeT m => NodeT (t m) a -> t (TreeT m) a
distributeNodeT (NodeT value children) =
  join . lift . fromNodeT $
    NodeT (pure value) [pure (distributeTreeT child) | child <- children]

-- | Swap @t@ outward for a whole tree. The root action is moved into
--   @t (TreeT m)@ with @hoist lift@, then its node is distributed.
distributeTreeT :: Transformer t TreeT m => TreeT (t m) a -> t (TreeT m) a
distributeTreeT tree =
  hoist lift (runTreeT tree) >>= distributeNodeT

-- | 'distributeT' is 'distributeTreeT'.
instance MonadTransDistributive TreeT where
  distributeT tree =
    distributeTreeT tree
-- | Primitive state actions are lifted into a childless root.
instance PrimMonad m => PrimMonad (TreeT m) where
  type PrimState (TreeT m) =
    PrimState m
  primitive f =
    lift (primitive f)

-- | IO actions are lifted into a childless root.
instance MonadIO m => MonadIO (TreeT m) where
  liftIO io =
    lift (liftIO io)

-- | Base-monad actions are lifted into a childless root.
instance MonadBase b m => MonadBase b (TreeT m) where
  liftBase action =
    lift (liftBase action)

-- | Throwing is delegated to @m@ via 'lift'.
instance MonadThrow m => MonadThrow (TreeT m) where
  throwM err =
    lift (throwM err)
-- | Install the handler on every child subtree of a node.
handleNodeT :: (Exception e, MonadCatch m) => (e -> TreeT m a) -> NodeT m a -> NodeT m a
handleNodeT onErr (NodeT value children) =
  NodeT value [handleTreeT onErr child | child <- children]

-- | Catch exceptions thrown by the root action, falling back to the
--   handler's root action. The handler is then installed on the children of
--   whichever node results.
handleTreeT :: (Exception e, MonadCatch m) => (e -> TreeT m a) -> TreeT m a -> TreeT m a
handleTreeT onErr (TreeT rootAction) =
  TreeT (fmap (handleNodeT onErr) (catch rootAction (runTreeT . onErr)))

-- | 'catch' installs the handler throughout the tree via 'handleTreeT'.
instance MonadCatch m => MonadCatch (TreeT m) where
  catch tree onErr =
    handleTreeT onErr tree
-- | Apply the environment modification to every child subtree of a node.
localNodeT :: MonadReader r m => (r -> r) -> NodeT m a -> NodeT m a
localNodeT f (NodeT value children) =
  NodeT value [localTreeT f child | child <- children]

-- | Run the root action under @local f@. Then apply the same modification
--   to each child subtree.
localTreeT :: MonadReader r m => (r -> r) -> TreeT m a -> TreeT m a
localTreeT f (TreeT rootAction) =
  TreeT (local f rootAction >>= pure . localNodeT f)

-- | 'ask' is lifted; 'local' modifies the environment throughout the tree.
instance MonadReader r m => MonadReader r (TreeT m) where
  ask =
    lift ask
  local f tree =
    localTreeT f tree
-- | State operations are lifted from @m@ into a childless root.
instance MonadState s m => MonadState s (TreeT m) where
  get =
    lift get
  put s =
    lift (put s)
  state f =
    lift (state f)
-- | Pair a node's value with the output @w@ accumulated so far. Children
--   continue listening from that same accumulated output.
listenNodeT :: MonadWriter w m => w -> NodeT m a -> NodeT m (a, w)
listenNodeT w (NodeT value children) =
  NodeT (value, w) [listenTreeT w child | child <- children]

-- | Listen to the root action and append its output to @w0@. The combined
--   output is recorded on the node and passed down to its children.
listenTreeT :: MonadWriter w m => w -> TreeT m a -> TreeT m (a, w)
listenTreeT w0 (TreeT rootAction) =
  TreeT $
    listen rootAction >>= \(node, w) ->
      pure (listenNodeT (mappend w0 w) node)
-- | Strip the output-modifying function from a node's value and recursively
--   apply 'passTreeT' to its children.
--
--   A bare node has no effect of its own, so the function has nothing to
--   act on here. 'passTreeT' applies it to the action that produced the
--   node.
passNodeT :: MonadWriter w m => NodeT m (a, w -> w) -> NodeT m a
passNodeT (NodeT (x, _) xs) =
  NodeT x $
    fmap passTreeT xs

-- | 'pass' for trees. The function stored in the root's value is applied,
--   via the underlying 'pass', to the output of the root action. Each child
--   subtree does the same with its own root action.
--
--   The previous version discarded the function, so 'pass' (and hence
--   'censor') had no effect on 'TreeT'.
passTreeT :: MonadWriter w m => TreeT m (a, w -> w) -> TreeT m a
passTreeT (TreeT m) =
  TreeT . pass $ do
    node@(NodeT (_, f) _) <- m
    pure (passNodeT node, f)
-- | 'writer' and 'tell' are lifted. 'listen' starts from 'mempty' via
--   'listenTreeT', and 'pass' is 'passTreeT'.
instance MonadWriter w m => MonadWriter w (TreeT m) where
  writer entry =
    lift (writer entry)
  tell w =
    lift (tell w)
  listen tree =
    listenTreeT mempty tree
  pass tree =
    passTreeT tree
-- | Install the error handler on every child subtree of a node.
handleErrorNodeT :: MonadError e m => (e -> TreeT m a) -> NodeT m a -> NodeT m a
handleErrorNodeT onErr (NodeT value children) =
  NodeT value [handleErrorTreeT onErr child | child <- children]

-- | Catch errors from the root action with 'catchError', falling back to the
--   handler's root action. The handler is then installed on the children of
--   whichever node results.
handleErrorTreeT :: MonadError e m => (e -> TreeT m a) -> TreeT m a -> TreeT m a
handleErrorTreeT onErr (TreeT rootAction) =
  TreeT (fmap (handleErrorNodeT onErr) (catchError rootAction (runTreeT . onErr)))

-- | 'throwError' is lifted; 'catchError' handles errors throughout the tree.
instance MonadError e m => MonadError e (TreeT m) where
  throwError err =
    lift (throwError err)
  catchError tree onErr =
    handleErrorTreeT onErr tree
-- | Resource actions are lifted from @m@ into a childless root.
instance MonadResource m => MonadResource (TreeT m) where
  liftResourceT action =
    lift (liftResourceT action)
-- | 'Show' for nodes is defined through their 'Show1' instance.
instance (Show1 m, Show a) => Show (NodeT m a) where
  showsPrec d node =
    showsPrec1 d node

-- | 'Show' for trees is defined through their 'Show1' instance.
instance (Show1 m, Show a) => Show (TreeT m a) where
  showsPrec d tree =
    showsPrec1 d tree
-- | Show a node as @NodeT value children@. The children are a list of
--   'TreeT's, each shown via the 'TreeT' 'Show1' instance.
instance Show1 m => Show1 (NodeT m) where
  liftShowsPrec sp sl d (NodeT x xs) =
    showsBinaryWith sp showChildren "NodeT" d x xs
    where
      showChildren =
        liftShowsPrec (liftShowsPrec sp sl) (liftShowList sp sl)

-- | Show a tree as @TreeT action@. The action's node is shown via the
--   'NodeT' 'Show1' instance, lifted through @m@.
instance Show1 m => Show1 (TreeT m) where
  liftShowsPrec sp sl d (TreeT m) =
    showsUnaryWith showRoot "TreeT" d m
    where
      showRoot =
        liftShowsPrec (liftShowsPrec sp sl) (liftShowList sp sl)
-- | Render a tree to lines. The root label comes first (passed through
--   'renderNodeT' and split on newlines), followed by the rendered
--   children.
renderTreeTLines :: Monad m => TreeT m String -> m [String]
renderTreeTLines (TreeT rootAction) =
  rootAction >>= \(NodeT label children) -> do
    rendered <- renderForestLines children
    pure (lines (renderNodeT label) ++ rendered)
-- | Prefix a single-character label with one space; leave any other label
--   unchanged. NOTE(review): presumably this is for visual alignment with
--   the tree connectors — unconfirmed.
renderNodeT :: String -> String
renderNodeT [c] =
  [' ', c]
renderNodeT label =
  label
-- | Render sibling subtrees as box-drawn, indented lines.
--
--   Every child except the last is introduced by @" ├╼"@ and its further
--   lines are prefixed by @" │ "@. The last child is introduced by @" └╼"@
--   and its further lines are prefixed by three spaces. All four prefixes
--   are the same width, so every subtree's lines stay aligned.
renderForestLines :: Monad m => [TreeT m String] -> m [String]
renderForestLines xs0 =
  let
    -- Prefix the first line with @hd@ and every following line with @other@.
    shift hd other =
      zipWith (++) (hd : repeat other)
  in
    case xs0 of
      [] ->
        pure []

      [x] -> do
        s <- renderTreeTLines x
        pure $
          -- Three spaces, matching the width of " └╼" (was a single space,
          -- which shifted this subtree's lines two columns left).
          shift " └╼" "   " s

      x : xs -> do
        s <- renderTreeTLines x
        ss <- renderForestLines xs
        pure $
          shift " ├╼" " │ " s ++ ss
-- | Render a pure tree of labels as a multi-line string.
render :: Tree String -> String
render tree =
  runIdentity (renderT tree)

-- | Render an effectful tree of labels as a multi-line string, with each
--   line terminated by a newline ('unlines').
renderT :: Monad m => TreeT m String -> m String
renderT tree =
  unlines <$> renderTreeTLines tree