{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | Extensible ADT
--
-- An 'EADT' is a recursive data type whose "constructors" are given as a
-- type-level list of functors (@fs@). Each functor describes one constructor,
-- and its type parameter marks the recursive positions (it is instantiated
-- with the 'EADT' itself). Values are built and matched with the 'VF'
-- pattern synonym.
module Data.Variant.EADT
   ( EADT (..)
   , (:<:)
   , (:<<:)
   , pattern VF
   , appendEADT
   , liftEADT
   , popEADT
   , contToEADT
   , contToEADTM
   , EADTShow (..)
   , eadtShow
   -- * Reexport
   , module Data.Variant.Functor
   , module Data.Variant.VariantF
   )
where

import Data.Variant
import Data.Variant.VariantF
import Data.Variant.Types
import Data.Variant.ContFlow
import Data.Variant.Functor

import GHC.TypeLits

-- $setup
-- >>> :seti -XDataKinds
-- >>> :seti -XTypeApplications
-- >>> :seti -XTypeOperators
-- >>> :seti -XFlexibleContexts
-- >>> :seti -XTypeFamilies
-- >>> :seti -XPatternSynonyms
-- >>> :seti -XDeriveFunctor
-- >>>
-- >>> import Data.Functor.Classes
-- >>>
-- >>> data ConsF a e = ConsF a e deriving (Eq,Ord,Show,Functor)
-- >>> data NilF    e = NilF      deriving (Eq,Ord,Show,Functor)
-- >>>
-- >>> instance Eq a => Eq1 (ConsF a) where liftEq cmp (ConsF a e1) (ConsF b e2) = a == b && cmp e1 e2
-- >>> instance Eq1 NilF where liftEq _ _ _ = True
-- >>>
-- >>> :{
-- >>> pattern Cons :: ConsF a :<: xs => a -> EADT xs -> EADT xs
-- >>> pattern Cons a l = VF (ConsF a l)
-- >>> pattern Nil :: NilF :<: xs => EADT xs
-- >>> pattern Nil = VF NilF
-- >>> type ListF a = VariantF '[NilF, ConsF a]
-- >>> type List a = EADT '[NilF, ConsF a]
-- >>> :}
--
-- >>>
-- >>> let a = Cons "Hello" (Cons "World" Nil) :: List String
-- >>> let b = Cons "Bonjour" (Cons "Monde" Nil) :: List String
-- >>> a == b
-- False
-- >>> a == a
-- True


-- | An extensible ADT.
--
-- @fs@ is a type-level list of functors, one per "constructor". The recursive
-- knot is tied by applying every functor in @fs@ to @EADT fs@ itself.
newtype EADT fs = EADT (VariantF fs (EADT fs))

-- | A single unrolled layer of an 'EADT' is a 'VariantF' over the same
-- functor list.
type instance Base (EADT fs) = VariantF fs

-- | 'project' peels one layer off an 'EADT' by unwrapping the newtype,
-- exposing the underlying 'VariantF' whose children are still 'EADT's.
instance Functor (VariantF fs) => Recursive (EADT fs) where
   project (EADT a) = a

-- | 'embed' wraps one 'VariantF' layer back into an 'EADT'. It is simply the
-- newtype constructor, i.e. the inverse of 'project'.
instance Functor (VariantF fs) => Corecursive (EADT fs) where
   embed = EADT

-- | Structural equality. The top layers are compared with 'eq1', which
-- compares the children with this same instance (recursively).
instance Eq1 (VariantF fs) => Eq (EADT fs) where
   EADT a == EADT b = eq1 a b

-- | Structural ordering. The top layers are compared with 'compare1', which
-- orders the children with this same instance (recursively).
instance Ord1 (VariantF fs) => Ord (EADT fs) where
   compare (EADT a) (EADT b) = compare1 a b

-- | Renders as @EADT \<layer\>@. The output is parenthesised when the
-- surrounding precedence is 11 or more (i.e. when it is the argument of a
-- constructor application), and the inner layer is shown at precedence 11
-- via 'showsPrec1'.
instance Show1 (VariantF fs) => Show (EADT fs) where
   showsPrec d (EADT a) =
      showParen (d >= 11)
         $ showString "EADT "
         . showsPrec1 11 a

-- | Constraint stating that the constructor functor @g@ belongs to the
-- constructor list @gs@ of an @EADT gs@ (expands to 'EADTF'').
type family (:<:) g gs where
   (:<:) g gs = EADTF' g (EADT gs) gs

-- | Lifts ':<:' over a list: every functor in the first list must be a
-- constructor of the second one. The empty list trivially satisfies it.
type family gs :<<: hs :: Constraint where
   (:<<:) '[]         hs = ()
   (:<<:) (g ': rest) hs = (g :<: hs, rest :<<: hs)

-- | Constraint bundle behind ':<:'.
--
-- @f@ is a constructor functor, @e@ is the type its parameter is applied to
-- (':<:' instantiates it with @EADT xs@), and @cs@ is the full constructor
-- list.
type EADTF' f e cs =
   ( Member f cs
     -- presumably asserts that @f@ occurs in @cs@ (see Data.Variant.Types)
   , Index (IndexOf (f e) (ApplyAll e cs)) (ApplyAll e cs) ~ f e
     -- looking @f e@ up at its own index in the applied list yields @f e@
   , PopVariant (f e) (ApplyAll e cs)
     -- presumably what allows an @f e@ to be popped out of the variant
   , KnownNat (IndexOf (f e) (ApplyAll e cs))
     -- the type-level position of @f e@ can be reflected to a runtime value
   , Remove (f e) (ApplyAll e cs) ~ ApplyAll e (Remove f cs)
     -- removing @f e@ after applying @e@ == applying @e@ after removing @f@
   )

-- | Pattern-match in an extensible ADT
--
-- Bidirectional. As an expression, @VF x@ builds an 'EADT' from a single
-- constructor value @x :: f (EADT cs)@. As a pattern, it can fail, and it
-- binds @x@ when the 'EADT' currently holds an @f@ layer.
pattern VF :: forall e f cs.
   ( e ~ EADT cs  -- allow easy use of TypeApplication to set the EADT type
   , f :<: cs     -- constraint synonym ensuring `f` is in `cs`
   ) => f (EADT cs) -> EADT cs
-- `VSilent` matches a variant value without checking the membership: we
-- already do it with :<:
pattern VF x = EADT (VariantF (VSilent x))

-- | Append new "constructors" to the EADT
--
-- Widens the constructor list from @xs@ to @Concat xs ys@. Because @ys@ cannot
-- be inferred from the result type, it is normally given with a type
-- application, e.g. @appendEADT \@'[SomeF] e@. Every layer is converted, so
-- the children are re-typed as well.
appendEADT :: forall ys xs zs.
   ( zs ~ Concat xs ys
   , ApplyAll (EADT zs) zs ~ Concat (ApplyAll (EADT zs) xs) (ApplyAll (EADT zs) ys)
   , Functor (VariantF xs)
   ) => EADT xs -> EADT zs
appendEADT (EADT v) = EADT (appendVariantF @ys (fmap (appendEADT @ys) v))

-- | Lift an EADT into another
--
-- Runs a single bottom-up fold ('cata'). Each layer is converted with
-- 'liftVariantF' and re-wrapped with 'EADT', so the result is an @EADT bs@
-- all the way down.
liftEADT :: forall e as bs.
   ( e ~ EADT bs
   , LiftVariantF as bs e
   , Functor (VariantF as)
   ) => EADT as -> EADT bs
liftEADT = cata (EADT . liftVariantF)

-- | Pop an EADT value
--
-- Inspects only the top layer, via 'popVariantF':
--
-- * 'Right' carries the @f@ constructor value.
-- * 'Left' carries the layer re-typed over the remaining constructors
--   (@Remove f xs@).
--
-- The children remain @EADT xs@ in both cases.
popEADT :: forall f xs e.
   ( f :<: xs
   , e ~ EADT xs
   , f e :< ApplyAll e xs
   ) => EADT xs -> Either (VariantF (Remove f xs) (EADT xs)) (f (EADT xs))
popEADT (EADT v) = popVariantF v

-- | MultiCont instance
--
-- Continuation-style elimination of the top layer. 'toCont' hands the active
-- constructor value to the matching handler in a tuple of handlers, as in the
-- example below.
--
-- >>> let f x = toCont x >::> (const "[]", \(ConsF u us) -> u ++ ":" ++ f us)
-- >>> f a
-- "Hello:World:[]"
instance (Functor (VariantF xs), ContVariant (ApplyAll (EADT xs) xs)) => MultiCont (EADT xs) where
   type MultiContTypes (EADT xs) = ApplyAll (EADT xs) xs
   -- Unwrap the newtype and delegate to the VariantF converter.
   toCont (EADT v) = variantFToCont v
   -- Peel one layer inside the monad first. 'project' relies on the
   -- @Functor (VariantF xs)@ constraint of this instance.
   toContM mv = variantFToContM (project <$> mv)

-- | Convert a multi-continuation into an EADT
--
-- Only the top layer is produced from the continuation. The values it
-- carries are already of type @EADT xs@, so the result is simply wrapped
-- with 'EADT'.
contToEADT ::
   ( ContVariant (ApplyAll (EADT xs) xs)
   ) => ContFlow (ApplyAll (EADT xs) xs)
                 (V (ApplyAll (EADT xs) xs))
     -> EADT xs
contToEADT c = EADT (contToVariantF c)

-- | Convert a monadic multi-continuation into an EADT
--
-- Monadic counterpart of 'contToEADT'. The continuation's result is computed
-- in the monad @f@, and the top layer is wrapped with 'EADT' inside it.
contToEADTM ::
   ( ContVariant (ApplyAll (EADT xs) xs)
   , Monad f
   ) => ContFlow (ApplyAll (EADT xs) xs)
                 (f (V (ApplyAll (EADT xs) xs)))
     -> f (EADT xs)
-- The argument is named @c@ (as in 'contToEADT') rather than @f@, so it is
-- not confused with the monad type variable @f@ of the signature.
contToEADTM c = EADT <$> contToVariantFM c


-- | Render a single constructor layer whose children have already been
-- rendered to 'String'.
class EADTShow g where
   eadtShow' :: g String -> String

-- | Show an EADT
--
-- Folds the EADT bottom-up ('bottomUp'). Each layer is rendered with
-- 'eadtShow'' after its children have been turned into 'String's. The
-- 'BottomUpF' constraint presumably requires an 'EADTShow' instance for
-- every constructor functor in @xs@.
eadtShow :: forall xs. BottomUpF EADTShow xs => EADT xs -> String
eadtShow = bottomUp (toBottomUp @EADTShow eadtShow')