{-# LANGUAGE FlexibleContexts #-}

-- |
-- Module      :   Grisette.Lib.Data.List
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Lib.Data.List
  ( -- * Symbolic versions of 'Data.List' operations
    (!!~),
    symFilter,
    symTake,
    symDrop,
  )
where

import Control.Exception
import Control.Monad.Except
import Grisette.Core.Control.Monad.Union
import Grisette.Core.Data.Class.Bool
import Grisette.Core.Data.Class.Error
import Grisette.Core.Data.Class.Mergeable
import Grisette.Core.Data.Class.SOrd
import Grisette.Core.Data.Class.SafeArith
import Grisette.Core.Data.Class.SimpleMergeable
import Grisette.IR.SymPrim.Data.SymPrim
import Grisette.Lib.Control.Monad

-- | Symbolic version of 'Data.List.!!': looks up the element of a concrete
-- list at a symbolic index.
--
-- The index is compared against the positions @0, 1, 2, ...@ of the list in
-- order, and every comparison is combined with 'mrgIf' / 'mrgReturn'. The
-- result is therefore merged and propagates the 'Mergeable' knowledge of the
-- element type.
--
-- On the path where the index equals none of the positions (for example, a
-- negative index or one past the end), the lookup fails. It throws
-- @'IndexOutOfBounds' "!!~"@, converted into the monad's error type with
-- 'transformError'.
(!!~) ::
  ( MonadUnion uf,
    MonadError e uf,
    TransformError ArrayException e,
    Mergeable a
  ) =>
  [a] ->
  SymInteger ->
  uf a
xs !!~ idx = walk 0 xs
  where
    -- Ran off the end: no position matched the index on this path.
    walk _ [] = throwError (transformError (IndexOutOfBounds "!!~"))
    -- @pos@ is the position of the current head element, counted from 0.
    walk pos (y : ys) =
      mrgIf (idx ==~ pos) (mrgReturn y) (walk (pos + 1) ys)

-- | Symbolic version of 'Data.List.filter' with a symbolic predicate.
--
-- The list is processed from the right. For each element, the already
-- filtered tail is computed first, and then the element is either prepended
-- or skipped under an 'mrgIf' on the predicate. Because every branch is built
-- with 'mrgReturn' / 'mrgIf', the resulting list is merged and propagates the
-- 'Mergeable' knowledge.
symFilter :: (MonadUnion u, Mergeable a) => (a -> SymBool) -> [a] -> u [a]
symFilter keep = foldr step (mrgReturn [])
  where
    -- Bind the filtered remainder, then decide symbolically whether @x@ stays.
    step x filteredRest =
      filteredRest >>= \rest ->
        mrgIf (keep x) (mrgReturn (x : rest)) (mrgReturn rest)

-- | Symbolic version of 'Data.List.take' with a symbolic count.
--
-- Taking from the empty list always yields the empty list. Otherwise, a count
-- that is @<= 0@ yields the empty list. Any other count keeps the head and
-- takes @n - 1@ elements from the tail. The branches are combined with
-- 'mrgIf' / 'mrgFmap', so the result is merged and propagates the 'Mergeable'
-- knowledge.
symTake :: (MonadUnion u, Mergeable a) => SymInteger -> [a] -> u [a]
symTake n xs = case xs of
  [] -> mrgReturn []
  (y : ys) ->
    mrgIf
      (n <=~ 0)
      (mrgReturn [])
      (mrgFmap (y :) (symTake (n - 1) ys))

-- | Symbolic version of 'Data.List.drop' with a symbolic count.
--
-- Dropping from the empty list always yields the empty list. Otherwise, a
-- count that is @<= 0@ returns the list unchanged. Any other count drops the
-- head and then drops @n - 1@ elements from the tail. The branches are
-- combined with 'mrgIf', so the result is merged and propagates the
-- 'Mergeable' knowledge.
symDrop :: (MonadUnion u, Mergeable a) => SymInteger -> [a] -> u [a]
symDrop n xs = case xs of
  [] -> mrgReturn []
  (_ : rest) ->
    mrgIf
      (n <=~ 0)
      (mrgReturn xs)
      (symDrop (n - 1) rest)