{-# LANGUAGE FlexibleContexts #-}

-- |
-- Module      :   Grisette.Lib.Data.List
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Lib.Data.List
  ( -- * Symbolic versions of 'Data.List' operations
    (!!~),
    symFilter,
    symTake,
    symDrop,
  )
where

import Control.Exception
import Control.Monad.Except
import Grisette.Core.Control.Monad.Union
import Grisette.Core.Data.Class.Bool
import Grisette.Core.Data.Class.Error
import Grisette.Core.Data.Class.Integer
import Grisette.Core.Data.Class.Mergeable
import Grisette.Core.Data.Class.SOrd
import Grisette.Core.Data.Class.SimpleMergeable
import Grisette.IR.SymPrim.Data.SymPrim
import Grisette.Lib.Control.Monad

-- | Symbolic counterpart of 'Data.List.!!': look up the element of a concrete
-- list at a symbolic position.
--
-- Every position @i@ of the list contributes one branch, guarded by
-- @idx ==~ i@, that yields the element stored there. The branches are chained
-- with 'mrgIf', so the result is merged using the element's 'Mergeable'
-- instance. If the index equals none of the positions (it is negative, or not
-- smaller than the list length), the computation throws
-- @'IndexOutOfBounds' "!!~"@, converted to the monad's error type with
-- 'transformError'.
(!!~) ::
  ( MonadUnion uf,
    MonadError e uf,
    TransformError ArrayException e,
    Mergeable a
  ) =>
  [a] ->
  SymInteger ->
  uf a
xs !!~ idx = foldr step (const outOfBounds) xs 0
  where
    -- One guarded branch per element; @rest@ handles the remaining suffix,
    -- starting at the next position.
    step y rest pos = mrgIf (idx ==~ pos) (mrgReturn y) (rest (pos + 1))
    -- Reached only under the condition that the index matched no position.
    outOfBounds = throwError $ transformError (IndexOutOfBounds "!!~")

-- | Symbolic counterpart of 'Data.List.filter': keep the elements for which a
-- symbolic predicate holds.
--
-- The tail is filtered first. The current element is then prepended under the
-- guard returned by the predicate, or skipped otherwise. Both alternatives are
-- combined with 'mrgIf' and 'mrgReturn', so the resulting list is merged.
symFilter :: (MonadUnion u, Mergeable a) => (a -> SymBool) -> [a] -> u [a]
symFilter keep = foldr step (mrgReturn [])
  where
    -- Decide about one element once the rest of the list has been filtered.
    step y filteredTail = do
      kept <- filteredTail
      mrgIf (keep y) (mrgReturn (y : kept)) (mrgReturn kept)

-- | Symbolic counterpart of 'Data.List.take': keep at most the first @n@
-- elements of a concrete list, where @n@ is a symbolic integer.
--
-- At each element, the result is the empty list when @n <=~ 0@. Otherwise it
-- is the element prepended to taking @n - 1@ from the rest. A non-positive
-- count therefore yields @[]@, and a count at least the list length yields the
-- whole list. Branches are combined with 'mrgIf' and 'mrgFmap', so the result
-- is merged.
symTake :: (MonadUnion u, Mergeable a) => SymInteger -> [a] -> u [a]
symTake budget items = case items of
  [] -> mrgReturn []
  y : rest ->
    mrgIf
      (budget <=~ 0)
      (mrgReturn [])
      (mrgFmap (y :) (symTake (budget - 1) rest))

-- | Symbolic counterpart of 'Data.List.drop': remove at most the first @n@
-- elements of a concrete list, where @n@ is a symbolic integer.
--
-- At each step, the remaining list is returned unchanged when @n <=~ 0@.
-- Otherwise one element is dropped and the count becomes @n - 1@. A
-- non-positive count therefore yields the whole list, and a count at least the
-- list length yields @[]@. Branches are combined with 'mrgIf', so the result is
-- merged.
symDrop :: (MonadUnion u, Mergeable a) => SymInteger -> [a] -> u [a]
symDrop budget items = case items of
  [] -> mrgReturn []
  _ : rest ->
    mrgIf (budget <=~ 0) (mrgReturn items) (symDrop (budget - 1) rest)