{-# language DeriveTraversable #-}
{-# language StandaloneDeriving #-}
{-# language LambdaCase #-}
{-# language TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
module Data.SRTree.EqSat ( simplifyEqSat ) where
import Control.Applicative (liftA2)
import Control.Monad (unless)
import Data.AEq ( AEq((~==)) )
import Data.Eq.Deriving ( deriveEq1 )
import Data.Equality.Analysis ( Analysis(..) )
import Data.Equality.Graph ( ClassId, Language, ENode(unNode) )
import Data.Equality.Graph.Lens hiding ((^.))
import Data.Equality.Graph.Lens qualified as L
import Data.Equality.Matching
import Data.Equality.Matching.Database ( Subst )
import Data.Equality.Saturation
import Data.Equality.Saturation.Scheduler ( BackoffScheduler(BackoffScheduler) )
import Data.Foldable qualified as F
import Data.IntMap.Strict qualified as IM
import Data.Maybe (isJust, isNothing)
import Data.Ord.Deriving ( deriveOrd1 )
import Data.SRTree hiding (Fix(..))
import Data.SRTree.Recursion qualified as R
import Data.Set qualified as S
import Text.Show.Deriving ( deriveShow1 )
-- Standalone-derived 'Foldable' and 'Traversable' instances for 'SRTree'.
deriving instance Foldable SRTree
deriving instance Traversable SRTree
-- Template Haskell splices that generate the lifted 'Eq1', 'Ord1' and 'Show1'
-- instances for 'SRTree'. Top-level splices split the module into declaration
-- groups, so their position relative to the code that needs them matters.
deriveEq1 ''SRTree
deriveOrd1 ''SRTree
deriveShow1 ''SRTree
-- | Arithmetic on patterns builds the corresponding 'SRTree' node wrapped in
-- 'NonVariablePattern', so rewrite rules can be written with ordinary
-- numeric syntax.
instance Num (Pattern SRTree) where
    l + r = NonVariablePattern $ Bin Add l r
    l - r = NonVariablePattern $ Bin Sub l r
    l * r = NonVariablePattern $ Bin Mul l r
    abs   = NonVariablePattern . Uni Abs

    -- Negation is encoded as a multiplication by the constant @-1@.
    negate t = fromInteger (-1) * t

    -- No pattern is built for 'signum'; fail with a descriptive message
    -- instead of a bare 'undefined'.
    signum _ = error "Num (Pattern SRTree): signum is not supported"

    -- Integer literals become 'Const' pattern nodes.
    fromInteger = NonVariablePattern . Const . fromInteger
-- | Division and rational literals for patterns: @a / b@ builds a
-- @'Bin' 'Div'@ pattern node and a rational literal becomes a 'Const' node.
instance Fractional (Pattern SRTree) where
    (/) a b      = NonVariablePattern $ Bin Div a b
    fromRational = NonVariablePattern . Const . fromRational
-- | Transcendental functions on patterns. Each unary function wraps its
-- argument in the matching @'Uni'@ node, and @('**')@ builds a
-- @'Bin' 'Power'@ node.
instance Floating (Pattern SRTree) where
    pi    = NonVariablePattern $ Const pi
    exp   = NonVariablePattern . Uni Exp
    log   = NonVariablePattern . Uni Log
    sqrt  = NonVariablePattern . Uni Sqrt
    sin   = NonVariablePattern . Uni Sin
    cos   = NonVariablePattern . Uni Cos
    tan   = NonVariablePattern . Uni Tan
    asin  = NonVariablePattern . Uni ASin
    acos  = NonVariablePattern . Uni ACos
    atan  = NonVariablePattern . Uni ATan
    sinh  = NonVariablePattern . Uni Sinh
    cosh  = NonVariablePattern . Uni Cosh
    tanh  = NonVariablePattern . Uni Tanh
    asinh = NonVariablePattern . Uni ASinh
    acosh = NonVariablePattern . Uni ACosh
    atanh = NonVariablePattern . Uni ATanh

    l ** r = NonVariablePattern (Bin Power l r)

    -- Previously 'undefined'. Follows the 'Floating' class convention
    -- @logBase b x = log x / log b@, expressed with the pattern-building
    -- 'log' and '/' so the result is an ordinary pattern.
    logBase b x = log x / log b
-- | Constant-folding analysis: the data attached to an e-class is
-- @'Just' c@ when the class is known to evaluate to the constant @c@ and
-- 'Nothing' otherwise.
instance Analysis (Maybe Double) SRTree where
    -- The value of a node is computed from the values of its children.
    makeA = evalConstant

    -- Runs in the 'Maybe' monad, so the result is 'Nothing' as soon as either
    -- side is 'Nothing'. When both sides are known they must agree (within
    -- an absolute tolerance of 1e-6, approximately via '~==', or as
    -- differently signed zeros); otherwise an 'error' is raised.
    joinA ma mb = do
        a <- ma
        b <- mb
        let agree =  abs (a - b) <= 1e-6
                  || a ~== b
                  || (a == 0 && b == (-0))
                  || (a == (-0) && b == 0)
        -- The bang pattern forces the result of 'unless', so the 'error' is
        -- actually raised when the constants disagree.
        !_ <- unless agree
                (error $ "Merged non-equal constants!" <> show a <> " " <> show b <> " " <> show (a==b))
        pure a

    -- For a class with a known constant, keep only the nodes that have no
    -- children and return the constant as an expression next to the class.
    -- NOTE(review): presumably the library adds the returned expressions to
    -- the class -- confirm against the 'Analysis' documentation.
    modifyA cl = case cl L.^. _data of
        Nothing -> (cl, [])
        Just d  -> ((_nodes %~ S.filter (F.null . unNode)) cl, [Fix (Const d)])
-- | Evaluate one 'SRTree' node whose children have already been reduced to
-- optional constants. Binary and unary nodes yield a value only when all of
-- their children are known; 'Var' and 'Param' nodes are never constant.
evalConstant :: SRTree (Maybe Double) -> Maybe Double
evalConstant = \case
    Bin Div e1 e2   -> liftA2 (/) e1 e2
    Bin Sub e1 e2   -> liftA2 (-) e1 e2
    Bin Mul e1 e2   -> liftA2 (*) e1 e2
    Bin Add e1 e2   -> liftA2 (+) e1 e2
    Bin Power e1 e2 -> liftA2 (**) e1 e2
    Uni f e1        -> evalFun f <$> e1
    Var _           -> Nothing
    Const x         -> Just x
    Param _         -> Nothing
-- Declares 'SRTree' as a 'Language' for the e-graph library; the instance is
-- given with no method definitions.
instance Language SRTree
-- | Cost of a node given the costs of its children: a variable costs 1,
-- constants and parameters cost 5, and every operator or function adds 1 to
-- the summed cost of its arguments.
cost :: CostFunction SRTree Int
cost = \case
    Const _     -> 5
    Var _       -> 1
    Bin _ c1 c2 -> c1 + c2 + 1
    Uni _ c     -> c + 1
    Param _     -> 5
-- | Look up the e-class bound to a pattern variable in a substitution.
--
-- Partial: calls 'error' when given a 'NonVariablePattern' or when the
-- variable has no binding in the substitution.
unsafeGetSubst :: Pattern SRTree -> Subst -> ClassId
unsafeGetSubst (NonVariablePattern _) _ =
    error "unsafeGetSubst: NonVariablePattern; expecting VariablePattern"
unsafeGetSubst (VariablePattern v) subst =
    case IM.lookup v subst of
        Nothing       -> error "Searching for non existent bound var in conditional"
        Just class_id -> class_id
-- | Rewrite condition: holds unless the class bound to the pattern variable
-- is known to be the constant 0. Classes with no known constant pass.
is_not_zero :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_not_zero v subst egr = classData /= Just 0
  where
    -- Analysis data of the class the variable is bound to.
    classData = egr L.^. (_class (unsafeGetSubst v subst) . _data)
-- | Rewrite condition: holds when at least one of the two pattern variables
-- is bound to a class whose known constant is greater than or equal to 0.
-- A class with no known constant does not satisfy its side of the test.
is_not_neg_consts :: Pattern SRTree -> Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_not_neg_consts v1 v2 subst egr = knownNonNeg v1 || knownNonNeg v2
  where
    -- True only for a class carrying a constant that is >= 0.
    knownNonNeg v =
        fmap (>= 0) (egr L.^. (_class (unsafeGetSubst v subst) . _data)) == Just True
-- | Rewrite condition: holds only when the class bound to the pattern
-- variable has a known constant that is strictly less than 0.
is_negative :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_negative v subst egr = fmap (< 0) classData == Just True
  where
    -- Analysis data of the class the variable is bound to.
    classData = egr L.^. (_class (unsafeGetSubst v subst) . _data)
-- | Rewrite condition: holds when the class bound to the pattern variable
-- has a known constant value.
is_const :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_const v subst egr = isJust classData
  where
    -- Analysis data of the class the variable is bound to.
    classData = egr L.^. (_class (unsafeGetSubst v subst) . _data)
-- | Rewrite condition: holds when the class bound to the pattern variable
-- has no known constant value.
is_not_const :: Pattern SRTree -> RewriteCondition (Maybe Double) SRTree
is_not_const v subst egr = isNothing classData
  where
    -- Analysis data of the class the variable is bound to.
    classData = egr L.^. (_class (unsafeGetSubst v subst) . _data)
-- | Basic algebraic rewrite rules: commutativity, associativity, merging of
-- powers, and moving constants through subtraction and division.
-- String literals are pattern variables; @lhs := rhs@ is a rewrite and
-- @:|@ attaches a condition that must hold for the rule to apply.
rewritesBasic :: [Rewrite (Maybe Double) SRTree]
rewritesBasic =
    [ -- commutativity
      "x" + "y" := "y" + "x"
    , "x" * "y" := "y" * "x"
      -- products of equal bases become powers
    , "x" * "x" := "x" ** 2
    , ("x" ** "a") * "x" := "x" ** ("a" + 1)
    , ("x" ** "a") * ("x" ** "b") := "x" ** ("a" + "b")
      -- associativity
    , ("x" + "y") + "z" := "x" + ("y" + "z")
    , ("x" * "y") * "z" := "x" * ("y" * "z")
    , ("x" * "y") / "z" := "x" * ("y" / "z")
      -- distributing subtraction and negation
    , "x" - ("y" + "z") := ("x" - "y") - "z"
    , "x" - ("y" - "z") := ("x" - "y") + "z"
    , negate ("x" + "y") := negate "x" - "y"
      -- subtraction of a constant (or constant multiple) becomes addition
    , ("x" - "a") := "x" + negate "a" :| is_const "a" :| is_not_const "x"
    , ("x" - ("a" * "y")) := "x" + (negate "a" * "y") :| is_const "a" :| is_not_const "y"
    , (1 / "x") * (1 / "y") := 1 / ("x" * "y")
      -- NOTE(review): the next rule is not an identity for fixed values of
      -- "a" and "b"; presumably it is meant to hold only up to re-fitting of
      -- the constants -- confirm the intent before relying on it.
    , ("a" * "x") / sqrt (1 + ("b" * "y") ** 2) := ("a" / "b" * "x") / sqrt (1 + "y" ** 2) :| is_const "a" :| is_const "b"
    ]
-- | Rewrite rules involving 'log', 'exp', 'sqrt', 'abs' and powers.
-- String literals are pattern variables; @lhs := rhs@ is a rewrite and
-- @:|@ attaches a condition that must hold for the rule to apply.
rewritesFun :: [Rewrite (Maybe Double) SRTree]
rewritesFun =
    [ -- logarithm of products, quotients and powers
      log ("x" * "y") := log "x" + log "y" :| is_not_neg_consts "x" "x" :| is_not_zero "x"
    , "x" ** "a" * "x" ** "b" := "x" ** ("a" + "b")
    , log ("x" / "y") := log "x" - log "y" :| is_not_neg_consts "x" "x" :| is_not_zero "x"
    , log ("x" ** "y") := "y" * log "x" :| is_not_neg_consts "y" "y" :| is_not_zero "y"
    , log (sqrt "x") := 0.5 * log "x" :| is_not_const "x"
      -- log and exp cancel each other
    , log (exp "x") := "x" :| is_not_const "x"
    , exp (log "x") := "x" :| is_not_const "x"
      -- The exponent is the pattern @1 / 2@ (a division node built by the
      -- 'Fractional' instance), not the literal constant 0.5.
    , "x" ** (1/2) := sqrt "x"
      -- square root of products and quotients
    , sqrt ("a" * "x") := sqrt "a" * sqrt "x" :| is_not_neg_consts "a" "x"
    , sqrt ("a" * ("x" - "y")) := sqrt (negate "a") * sqrt ("y" - "x") :| is_negative "a"
      -- NOTE(review): for fixed "a" and "b" the next rule is not an
      -- identity (a * (b + y) equals (-a) * (-b - y), not (-a) * (b - y));
      -- presumably it is meant to hold only up to re-fitting of the
      -- constants -- confirm the intent.
    , sqrt ("a" * ("b" + "y")) := sqrt (negate "a") * sqrt ("b" - "y") :| is_negative "a" :| is_negative "b"
    , sqrt ("a" / "x") := sqrt "a" / sqrt "x" :| is_not_neg_consts "a" "x"
      -- absolute value distributes over multiplication
    , abs ("x" * "y") := abs "x" * abs "y"
    ]
-- | Rewrite rules that drop neutral and absorbing constants, cancel identical
-- terms, normalise negation and regroup constant factors.
--
-- String literals (@"x"@, @"a"@, ...) are pattern variables; @:=@ builds a
-- rule from a left and a right pattern and @:|@ attaches a side condition to
-- the rule on its left.
--
-- NOTE(review): @0 / "x"@ and @0 ** "x"@ are rewritten to @0@ without any
-- side condition, although neither holds when "x" is the constant 0 --
-- confirm this is intended.
constReduction :: [Rewrite (Maybe Double) SRTree]
constReduction =
  [ -- neutral and absorbing elements
    0 + "x" := "x"
  , "x" + 0 := "x"
  , "x" - 0 := "x"
  , 1 * "x" := "x"
  , "x" * 1 := "x"
  , 0 * "x" := 0
  , "x" * 0 := 0
  , 0 / "x" := 0
    -- cancellation
  , "x" - "x" := 0
  , "x" / "x" := 1 :| is_not_zero "x"
  , "x" ** 1 := "x"
  , 0 ** "x" := 0
  , 1 ** "x" := 1
    -- product with the reciprocal, and factoring out a common term
  , "x" * (1 / "x") := 1 :| is_not_zero "x"
  , ("x" * "y") + ("x" * "z") := "x" * ("y" + "z")
    -- negation
  , "x" - ((-1) * "y") := "x" + "y" :| is_not_const "y"
  , "x" + negate "y" := "x" - "y" :| is_not_const "y"
  , 0 - "x" := negate "x" :| is_not_const "x"
    -- bring the two constant factors next to each other
  , ("a" * "x") * ("b" * "y") := ("a" * "b") * ("x" * "y")
      :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
  , "a" / ("b" * "x") := ("a" / "b") / "x"
      :| is_const "a" :| is_const "b" :| is_not_const "x"
  ]
-- | Rewrite rules that pull a constant factor out of a sum or difference so
-- that the two constants "a" and "b" end up combined in a single subterm.
--
-- In every rule "a" and "b" are guarded by 'is_const' and "x" / "y" by
-- 'is_not_const'. Each right-hand side is algebraically equal to its
-- left-hand side (for non-zero "a" where the rule divides by it).
--
-- NOTE(review): the rules that divide by "a" carry no 'is_not_zero' guard --
-- confirm whether a zero constant can reach them.
constFusion :: [Rewrite (Maybe Double) SRTree]
constFusion =
  [ -- a*x + b  and  a*x +/- b/y : factor out a
    "a" * "x" + "b" := "a" * ("x" + ("b" / "a"))
      :| is_const "a" :| is_const "b" :| is_not_const "x"
  , "a" * "x" + "b" / "y" := "a" * ("x" + ("b" / "a") / "y")
      :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
  , "a" * "x" - "b" / "y" := "a" * ("x" - ("b" / "a") / "y")
      :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
    -- x / (b*y) : move the constant out of the denominator
  , "x" / ("b" * "y") := (1 / "b") * "x" / "y"
      :| is_const "b" :| is_not_const "x" :| is_not_const "y"
    -- x/a +/- b  and  b - x/a : factor out 1/a
  , "x" / "a" + "b" := (1 / "a") * ("x" + ("b" * "a"))
      :| is_const "a" :| is_const "b" :| is_not_const "x"
  , "x" / "a" - "b" := (1 / "a") * ("x" - ("b" * "a"))
      :| is_const "a" :| is_const "b" :| is_not_const "x"
  , "b" - "x" / "a" := (1 / "a") * (("b" * "a") - "x")
      :| is_const "a" :| is_const "b" :| is_not_const "x"
    -- x/a + b*y  and  x/a + y/b
  , "x" / "a" + "b" * "y" := (1 / "a") * ("x" + ("b" * "a") * "y")
      :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
    -- (1/a) * (y / (b/a)) = y/b, so the divisor must be b/a (not b*a).
  , "x" / "a" + "y" / "b" := (1 / "a") * ("x" + "y" / ("b" / "a"))
      :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
    -- x/a - b*y  and  x/a - b/y
  , "x" / "a" - "b" * "y" := (1 / "a") * ("x" - ("b" * "a") * "y")
      :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
    -- (1/a) * ((b*a) / y) = b/y, matching the constant-over-term left side.
  , "x" / "a" - "b" / "y" := (1 / "a") * ("x" - ("b" * "a") / "y")
      :| is_const "a" :| is_const "b" :| is_not_const "x" :| is_not_const "y"
  ]
-- | Run equality saturation on a tree and return the extracted expression.
--
-- The tree @t@ is saturated with @rules@ using a 'BackoffScheduler' built
-- from @n@ and @coolOff@, and the result is extracted with cost function @c@.
-- The final e-graph returned by 'equalitySaturation'' is discarded.
--
-- NOTE(review): @n@ and @coolOff@ are presumably the scheduler's match limit
-- and ban length -- confirm against the hegg documentation.
rewriteTree :: (Analysis a l, Language l, Ord cost) => [Rewrite a l] -> Int -> Int -> CostFunction l cost -> Fix l -> Fix l
rewriteTree rules n coolOff c t =
  fst (equalitySaturation' (BackoffScheduler n coolOff) t rules c)
-- | 'rewriteAll' saturates with every rule set of this module
-- ('rewritesBasic', 'constReduction', 'constFusion' and 'rewritesFun') using
-- scheduler parameters 2500 and 30; 'rewriteConst' applies only
-- 'constReduction' with scheduler parameters 100 and 10. Both extract with
-- 'cost'.
rewriteAll, rewriteConst :: Fix SRTree -> Fix SRTree
rewriteAll = rewriteTree allRules 2500 30 cost
  where
    allRules = rewritesBasic <> constReduction <> constFusion <> rewritesFun
rewriteConst = rewriteTree constReduction 100 10 cost
-- | Apply the rewriters in round-robin order until one of them leaves the
-- tree unchanged or the step budget runs out.
--
-- At each step the first rewriter is applied; if its result equals its input
-- that result is returned, otherwise the rewriter is moved to the back of
-- the list and the budget is decremented. A budget of @0@ returns the tree
-- as is. An empty rewriter list also returns the tree unchanged (previously
-- this case crashed on @head []@).
rewriteUntilNoChange :: [Fix SRTree -> Fix SRTree] -> Int -> Fix SRTree -> Fix SRTree
rewriteUntilNoChange _ 0 t = t
rewriteUntilNoChange [] _ t = t
rewriteUntilNoChange (r : rs) n t
  | t == t'   = t'
  | otherwise = rewriteUntilNoChange (rs <> [r]) (n - 1) t'
  where
    t' = r t
-- | Simplify an expression tree with equality saturation.
--
-- Pipeline, in order of application: convert to the e-graph library's 'Fix'
-- ('toEqFix'), run 'rewriteConst' once, run 'rewriteAll' for at most two
-- rounds or until it stops changing the tree, convert back ('fromEqFix'),
-- and finally apply 'relabelParams'.
simplifyEqSat :: R.Fix SRTree -> R.Fix SRTree
simplifyEqSat =
  relabelParams
    . fromEqFix
    . rewriteUntilNoChange [rewriteAll] 2
    . rewriteConst
    . toEqFix
-- | Convert a tree built with the e-graph library's 'Fix' into the same tree
-- built with "Data.SRTree.Recursion"'s 'R.Fix'.
--
-- Every layer is re-wrapped unchanged, so the fold's algebra is simply the
-- 'R.Fix' constructor; this also stays total if 'SRTree' gains constructors.
fromEqFix :: Fix SRTree -> R.Fix SRTree
fromEqFix = cata R.Fix
-- | Convert a tree built with "Data.SRTree.Recursion"'s 'R.Fix' into the same
-- tree built with the e-graph library's 'Fix'.
--
-- Every layer is re-wrapped unchanged, so the fold's algebra is simply the
-- 'Fix' constructor; this also stays total if 'SRTree' gains constructors.
toEqFix :: R.Fix SRTree -> Fix SRTree
toEqFix = R.cata Fix