{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      :   Grisette.IR.SymPrim.Data.Prim.PartialEval.PartialEval
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.IR.SymPrim.Data.Prim.PartialEval.PartialEval
  ( PartialFun,
    PartialRuleUnary,
    TotalRuleUnary,
    PartialRuleBinary,
    TotalRuleBinary,
    totalize,
    totalize2,
    UnaryPartialStrategy (..),
    unaryPartial,
    BinaryCommPartialStrategy (..),
    BinaryPartialStrategy (..),
    binaryPartial,
  )
where

import Control.Monad.Except
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term

-- | A partial function: 'Nothing' means no result is produced for the input.
-- Use 'totalize' to turn one into a total function with a fallback.
type PartialFun a b = a -> Maybe b

-- | A unary rule on terms that may produce no result ('Nothing').
type PartialRuleUnary a b = PartialFun (Term a) (Term b)

-- | A unary rule on terms that always produces a result.
type TotalRuleUnary a b = Term a -> Term b

-- | A binary rule on terms that may produce no result. It is written as a
-- function from the first operand to a partial unary rule on the second
-- operand.
type PartialRuleBinary a b c = Term a -> PartialRuleUnary b c

-- | A binary rule on terms that always produces a result. It is written as a
-- function from the first operand to a total unary rule on the second
-- operand.
type TotalRuleBinary a b c = Term a -> TotalRuleUnary b c

-- | Turn a partial function into a total one by supplying a fallback.
--
-- @totalize partial fallback x@ returns the value inside @partial x@ when it
-- is 'Just', and @fallback x@ otherwise. The fallback is only evaluated when
-- the partial function yields 'Nothing'.
totalize :: PartialFun a b -> (a -> b) -> a -> b
totalize partial fallback a =
  case partial a of
    Just b -> b
    Nothing -> fallback a

-- | Two-argument variant of 'totalize'.
--
-- @totalize2 partial fallback x y@ returns the value inside @partial x y@
-- when it is 'Just', and @fallback x y@ otherwise. The fallback is only
-- evaluated when the partial function yields 'Nothing'.
totalize2 :: (a -> PartialFun b c) -> (a -> b -> c) -> a -> b -> c
totalize2 partial fallback a b =
  case partial a b of
    Just c -> c
    Nothing -> fallback a b

-- | Strategy describing how 'unaryPartial' should handle a single operand.
--
-- The functional dependency @tag a -> b@ fixes the result type once the
-- strategy tag and the operand type are known.
class UnaryPartialStrategy tag a b | tag a -> b where
  -- | Try to pull a concrete value out of the operand term. 'unaryPartial'
  -- dispatches to 'constantHandler' on 'Just' and to 'nonConstantHandler' on
  -- 'Nothing'.
  extractor ::
    tag ->
    Term a ->
    Maybe a

  -- | Handle an operand whose value was extracted by 'extractor'.
  constantHandler ::
    tag ->
    a ->
    Maybe (Term b)

  -- | Handle an operand for which 'extractor' returned 'Nothing'.
  nonConstantHandler ::
    tag ->
    Term a ->
    Maybe (Term b)

-- | Build a partial unary rule from a 'UnaryPartialStrategy'.
--
-- The strategy's 'extractor' is applied to the operand first. If it yields a
-- value, that value is passed to 'constantHandler'; otherwise the original
-- term is passed to 'nonConstantHandler'. Either handler may return
-- 'Nothing', in which case the rule produces no result.
unaryPartial ::
  forall tag a b.
  (UnaryPartialStrategy tag a b) =>
  tag ->
  PartialRuleUnary a b
unaryPartial tag term =
  maybe
    (nonConstantHandler tag term)
    (constantHandler tag)
    (extractor tag term)

-- | Strategy for binary operations whose two operands share the type @a@.
-- A single handler covers the case where exactly one operand is a
-- concrete value. 'BinaryPartialStrategy' uses it as the default for both
-- 'leftConstantHandler' and 'rightConstantHandler' (the latter with its
-- arguments flipped).
--
-- NOTE(review): the name suggests this is intended for commutative
-- operators; that property is not checked here.
class BinaryCommPartialStrategy tag a c | tag a -> c where
  -- | Handle the case where one operand is the concrete value (first
  -- argument after the tag) and the other is a term.
  singleConstantHandler ::
    tag ->
    a ->
    Term a ->
    Maybe (Term c)

-- | Strategy describing how 'binaryPartial' should handle a pair of operands.
--
-- The functional dependency @tag a b -> c@ fixes the result type from the
-- tag and the operand types. Because @b@ (resp. @a@) does not appear in the
-- types of 'extractora' (resp. 'extractorb'), callers must disambiguate
-- those methods with type applications (see 'binaryPartial').
class BinaryPartialStrategy tag a b c | tag a b -> c where
  -- | Try to extract a concrete value from the left operand. 'Just' makes
  -- 'binaryPartial' treat the left operand as a constant.
  extractora :: tag -> Term a -> Maybe a

  -- | Try to extract a concrete value from the right operand. 'Just' makes
  -- 'binaryPartial' treat the right operand as a constant.
  extractorb :: tag -> Term b -> Maybe b

  -- | Handle the case where both operands were extracted as values.
  allConstantHandler :: tag -> a -> b -> Maybe (Term c)

  -- | Handle the case where only the left operand was extracted as a value.
  leftConstantHandler :: tag -> a -> Term b -> Maybe (Term c)

  -- When both operands have the same type and a 'BinaryCommPartialStrategy'
  -- instance exists, delegate to its 'singleConstantHandler'.
  default leftConstantHandler ::
    (a ~ b, BinaryCommPartialStrategy tag a c) =>
    tag ->
    a ->
    Term b ->
    Maybe (Term c)
  leftConstantHandler = singleConstantHandler @tag @a

  -- | Handle the case where only the right operand was extracted as a value.
  rightConstantHandler :: tag -> Term a -> b -> Maybe (Term c)

  -- Same delegation as 'leftConstantHandler', with the arguments flipped so
  -- the constant comes first as 'singleConstantHandler' expects.
  default rightConstantHandler ::
    (a ~ b, BinaryCommPartialStrategy tag a c) =>
    tag ->
    Term a ->
    b ->
    Maybe (Term c)
  rightConstantHandler tag = flip (singleConstantHandler @tag @a tag)

  -- | Handle the case where neither operand was extracted as a value. It is
  -- also used by 'binaryPartial' as the fallback when a single-constant
  -- handler returns 'Nothing'.
  nonBinaryConstantHandler :: tag -> Term a -> Term b -> Maybe (Term c)

-- | Build a partial binary rule from a 'BinaryPartialStrategy'.
--
-- Both operands are run through the strategy's extractors, and the
-- combination of results selects a handler:
--
--   * neither extracted: 'nonBinaryConstantHandler';
--   * only the left extracted: 'leftConstantHandler', falling back to
--     'nonBinaryConstantHandler' if it returns 'Nothing';
--   * only the right extracted: 'rightConstantHandler', with the same
--     fallback;
--   * both extracted: 'allConstantHandler' (no fallback).
--
-- Every method call carries explicit type applications because the class
-- methods are ambiguous in some of @tag a b c@ (hence @AllowAmbiguousTypes@).
binaryPartial ::
  forall tag a b c.
  (BinaryPartialStrategy tag a b c) =>
  tag ->
  PartialRuleBinary a b c
binaryPartial tag lhs rhs =
  case (extractora @tag @a @b @c tag lhs, extractorb @tag @a @b @c tag rhs) of
    (Nothing, Nothing) -> fallback
    (Just lhsVal, Nothing) ->
      -- 'catchError' on 'Maybe' (error type ()) replaces 'Nothing' with the
      -- handler's result, giving the fallback behaviour.
      leftConstantHandler @tag @a @b @c tag lhsVal rhs
        `catchError` const fallback
    (Nothing, Just rhsVal) ->
      rightConstantHandler @tag @a @b @c tag lhs rhsVal
        `catchError` const fallback
    (Just lhsVal, Just rhsVal) ->
      allConstantHandler @tag @a @b @c tag lhsVal rhsVal
  where
    -- Generic handler on the original terms. It is only evaluated in the
    -- branches that need it.
    fallback :: Maybe (Term c)
    fallback = nonBinaryConstantHandler @tag @a @b @c tag lhs rhs