{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE CPP #-}
-- | Symbolic-algebra tests for equality saturation: a small expression
-- language, a constant-folding analysis, a rule set for algebra and
-- differentiation/integration, and the "Symbolic" test group.
module Sym where

import GHC.Generics
import Test.Tasty
import Test.Tasty.HUnit

import qualified Data.IntMap.Strict as IM
import qualified Data.Set as S
import Data.String
import Data.Maybe (isJust)
import qualified Data.Foldable as F
#if MIN_VERSION_base(4,18,0)
#else
import Control.Applicative (liftA2)
#endif
import Control.Monad (unless)
import Data.Function ((&))

import Data.Equality.Graph.Lens
import Data.Equality.Graph
import Data.Equality.Extraction
import Data.Equality.Analysis
import Data.Equality.Matching
import Data.Equality.Matching.Database
import Data.Equality.Saturation

-- | The expression language as a base functor: children have type @a@.
-- Terms are built by closing it with 'Fix'; 'evalConstant' instantiates it
-- with analysis values (@Expr (Maybe Double)@).
data Expr a = Sym   !String
            | Const !Double
            | UnOp  !UOp !a
            | BinOp !BOp !a !a
            deriving ( Eq, Ord, Show
                     , Functor, Foldable, Traversable
                     , Generic
                     )

-- | Binary operators. 'Diff' (derivative) and 'Integral' are symbolic
-- calculus operators; 'symCost' makes them very expensive so extraction
-- prefers terms in which the rules have rewritten them away.
data BOp = Add | Sub | Mul | Div | Pow | Diff | Integral
    deriving (Eq, Ord, Show, Generic)

-- | Unary operators.
data UOp = Sin | Cos | Sqrt | Ln
    deriving (Eq, Ord, Show, Generic)

-- | String literals denote symbols, e.g. @"x" :: Fix Expr@.
instance IsString (Fix Expr) where
    fromString = Fix . Sym

-- | Arithmetic syntax for building terms. 'negate', 'abs' and 'signum' are
-- deliberately unsupported (so a literal like @-2 :: Fix Expr@, which
-- desugars to 'negate', errors); write @fromInteger (-1) * e@ instead.
instance Num (Fix Expr) where
    (+) a b = Fix (BinOp Add a b)
    (-) a b = Fix (BinOp Sub a b)
    (*) a b = Fix (BinOp Mul a b)
    fromInteger = Fix . Const . fromInteger
    negate = error "DONT USE"
    abs = error "abs"
    signum = error "signum"

-- | Division and fractional literals for terms.
instance Fractional (Fix Expr) where
    (/) a b = Fix (BinOp Div a b)
    fromRational = Fix . Const . fromRational

-- | Extraction cost: leaves cost 1 and every operator adds a fixed penalty
-- to the cost of its children. The huge 'Diff' and 'Integral' penalties make
-- any equivalent term without them win.
symCost :: CostFunction Expr Int
symCost = \case
    BinOp Pow e1 e2 -> e1 + e2 + 6
    BinOp Div e1 e2 -> e1 + e2 + 5
    BinOp Sub e1 e2 -> e1 + e2 + 4
    BinOp Mul e1 e2 -> e1 + e2 + 4
    BinOp Add e1 e2 -> e1 + e2 + 2
    BinOp Diff e1 e2 -> e1 + e2 + 500
    BinOp Integral e1 e2 -> e1 + e2 + 20000
    UnOp Sin e1 -> e1 + 20
    UnOp Cos e1 -> e1 + 20
    UnOp Sqrt e1 -> e1 + 30
    UnOp Ln e1 -> e1 + 30
    Sym _ -> 1
    Const _ -> 1

-- | Arithmetic syntax for rewrite-rule patterns; same restrictions as the
-- 'Fix' 'Expr' instance.
instance Num (Pattern Expr) where
    (+) a b = NonVariablePattern $ BinOp Add a b
    (-) a b = NonVariablePattern $ BinOp Sub a b
    (*) a b = NonVariablePattern $ BinOp Mul a b
    fromInteger = NonVariablePattern . Const . fromInteger
    negate = error "DONT USE" -- NonVariablePattern. BinOp Mul (fromInteger $ -1)
    abs = error "abs"
    signum = error "signum"

-- | Division and fractional literals for patterns.
instance Fractional (Pattern Expr) where
    (/) a b = NonVariablePattern $ BinOp Div a b
    fromRational = NonVariablePattern . Const . fromRational

-- | Define analysis for the @Expr@ language over domain @Maybe Double@ for
-- constant folding: @Just c@ means the e-class is known to equal @c@,
-- 'Nothing' means its value is unknown.
instance Analysis (Maybe Double) Expr where
    makeA = evalConstant

    -- joinA = (<|>)
    -- NOTE(review): with this do-block, merging a constant class with a
    -- non-constant one yields 'Nothing' and drops the known constant; the
    -- commented-out '(<|>)' would keep it. Confirm which is intended.
    joinA ma mb = do
        a <- ma
        b <- mb
        -- this assertion only seemed to be triggering when using bogus
        -- constant assignments for "Fold all classes with x:=c".
        -- The 0 vs -0 case found by property checking is covered by '==':
        -- IEEE equality already treats 0.0 and -0.0 as equal.
        !_ <- unless (a == b) (error "Merged non-equal constants!")
        return a

    -- When the class has a known constant, add that constant as an e-node,
    -- merge it into the class, and keep only the class's leaf e-nodes.
    modifyA cl eg0 =
        case eg0 ^. _class cl . _data of
            Nothing -> eg0
            Just d ->
                -- Add constant as e-node
                let (new_c, eg1) = represent (Fix (Const d)) eg0
                    (rep, eg2)   = merge cl new_c eg1
                -- Prune all except leaf e-nodes
                in eg2 & _class rep . _nodes %~ S.filter (F.null . unNode)

-- | Constant-folding transfer function: the value of an e-node given the
-- (possibly unknown) values of its children. Returns 'Nothing' when a needed
-- child value is unknown or the operator is not folded.
evalConstant :: Expr (Maybe Double) -> Maybe Double
evalConstant = \case
    BinOp Pow e1 e2 -> do
        base <- e1
        ex   <- e2
        let n = round ex :: Integer
        -- Only integral exponents are folded: '^^' accepts negative ones
        -- (plain '^' throws "Negative exponent", and the div-cannon rule
        -- introduces b^(-1) for every division). Fractional exponents are
        -- left unfolded rather than rounded, and 0 to a negative power is
        -- left unfolded rather than becoming Infinity.
        if fromInteger n /= ex || (base == 0 && n < 0)
            then Nothing
            else Just (base ^^ n)
    BinOp Div e1 e2 -> liftA2 (/) e1 e2
    BinOp Sub e1 e2 -> liftA2 (-) e1 e2
    BinOp Mul e1 e2 -> liftA2 (*) e1 e2
    BinOp Add e1 e2 -> liftA2 (+) e1 e2
    BinOp Diff _ _ -> Nothing
    BinOp Integral _ _ -> Nothing
    UnOp Sin e1 -> sin <$> e1
    UnOp Cos e1 -> cos <$> e1
    UnOp Sqrt e1 -> sqrt <$> e1
    UnOp Ln _ -> Nothing
    Sym _ -> Nothing
    Const x -> Just x

-- | The e-class bound to a pattern variable in a substitution. Errors if the
-- pattern is not a variable or the variable is unbound.
unsafeGetSubst :: Pattern Expr -> Subst -> ClassId
unsafeGetSubst (NonVariablePattern _) _ = error "unsafeGetSubst: NonVariablePattern; expecting VariablePattern"
unsafeGetSubst (VariablePattern v) subst = case IM.lookup v subst of
    Nothing -> error "Searching for non existent bound var in conditional"
    Just class_id -> class_id

-- | Rule condition: the class bound to @v@ is not known to be the constant
-- 0. A class whose value is unknown ('Nothing') passes.
is_not_zero :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
is_not_zero v subst egr =
    egr ^. _class (unsafeGetSubst v subst) . _data /= Just 0

-- | Rule condition: the class bound to @v@ contains a 'Sym' e-node.
is_sym :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
is_sym v subst egr =
    any ((\case (Sym _) -> True; _ -> False) . unNode) (egr ^. _class (unsafeGetSubst v subst) . _nodes)

-- | Rule condition: the class bound to @v@ has a known constant value.
is_const :: Pattern Expr -> RewriteCondition (Maybe Double) Expr
is_const v subst egr =
    isJust (egr ^. _class (unsafeGetSubst v subst) . _data)

-- | Rule condition: @v@ and @w@ are bound to different classes, and @v@'s
-- class is either a known constant or contains a 'Sym' e-node.
is_const_or_distinct_var :: Pattern Expr -> Pattern Expr -> RewriteCondition (Maybe Double) Expr
is_const_or_distinct_var v w subst egr =
    let v' = unsafeGetSubst v subst
        w' = unsafeGetSubst w subst
     in (eClassId (egr ^. _class v') /= eClassId (egr ^. _class w'))
            && (isJust (egr ^. _class v' . _data)
                || any ((\case (Sym _) -> True; _ -> False) . unNode) (egr ^. _class v' . _nodes))

-- | The full rule set: algebraic identities, canonicalisation of @-@ and
-- @/@, differentiation and integration rules.
rewrites :: [Rewrite (Maybe Double) Expr]
rewrites =
    [ "a"+"b" := "b"+"a" -- comm add
    , "a"*"b" := "b"*"a" -- comm mul
    , "a"+("b"+"c") := ("a"+"b")+"c" -- assoc add
    , "a"*("b"*"c") := ("a"*"b")*"c" -- assoc mul

    , "a"-"b" := "a"+(fromInteger (-1) * "b") -- sub cannon
    , "a"/"b" := "a"*powP "b" (fromInteger $ -1) :| is_not_zero "b" -- div cannon

    -- identities
    , "a"+0 := "a"
    , "a"*0 := 0
    , "a"*1 := "a"

    -- TODO This causes many problems
    -- , "a" := "a"+0
    -- This already works
    , "a" := "a"*1

    , "a"-"a" := 0 -- cancel sub
    , "a"/"a" := 1 :| is_not_zero "a" -- cancel div

    , "a"*("b"+"c") := ("a"*"b")+("a"*"c") -- distribute
    , ("a"*"b")+("a"*"c") := "a"*("b"+"c") -- factor

    , powP "a" "b"*powP "a" "c" := powP "a" ("b" + "c") -- pow mul
    , powP "a" 0 := 1 :| is_not_zero "a"
    , powP "a" 1 := "a"
    , powP "a" 2 := "a"*"a"
    , powP "a" (fromInteger $ -1) := 1/"a" :| is_not_zero "a"

    , "x"*(1/"x") := 1 :| is_not_zero "x"

    , diffP "x" "x" := 1 :| is_sym "x"
    , diffP "x" "c" := 0 :| is_sym "x" :| is_const_or_distinct_var "c" "x"

    , diffP "x" ("a" + "b") := diffP "x" "a" + diffP "x" "b"
    , diffP "x" ("a" * "b") := ("a"*diffP "x" "b") + ("b"*diffP "x" "a")

    , diffP "x" (sinP "x") := cosP "x"
    , diffP "x" (cosP "x") := fromInteger (-1) * sinP "x"

    , diffP "x" (lnP "x") := 1/"x" :| is_not_zero "x"

    -- diff-power
    , diffP "x" (powP "f" "g") := powP "f" "g" * ((diffP "x" "f" * ("g" / "f")) + (diffP "x" "g" * lnP "f")) :| is_not_zero "f" :| is_not_zero "g"

    -- i-one
    , intP 1 "x" := "x"

    -- i power const
    , intP (powP "x" "c") "x" := (/) (powP "x" ((+) "c" 1)) ((+) "c" 1) :| is_const "c"

    , intP (cosP "x") "x" := sinP "x"
    , intP (sinP "x") "x" := fromInteger (-1)*cosP "x"

    , intP ("f" + "g") "x" := intP "f" "x" + intP "g" "x"
    , intP ("f" - "g") "x" := intP "f" "x" - intP "g" "x"

    , intP ("a" * "b") "x" := (-) ((*) "a" (intP "b" "x")) (intP ((*) (diffP "x" "a") (intP "b" "x")) "x")

    -- Additional ad-hoc: because of negate representations?
    , "a"-(fromInteger (-1)*"b") := "a"+"b"
    ]

-- | Run equality saturation with all 'rewrites' and return the resulting
-- term selected with the cost function 'symCost'.
rewrite :: Fix Expr -> Fix Expr
rewrite e = fst $ equalitySaturation e rewrites symCost

-- | The "Symbolic" test group.
symTests :: TestTree
symTests = testGroup "Symbolic"
    [ testCase "(a*2)/2 = a (custom rules)" $
        fst (equalitySaturation @(Maybe Double) (("a"*2)/2)
                [ ("x"*"y")/"z" := "x"*("y"/"z")
                , "y"/"y" := 1
                , "x"*1 := "x"
                ]
                symCost)
            @?= "a"

    , testCase "(a/2)*2 = a (all rules)" $
        rewrite (("a"/2)*2) @?= "a"

    , testCase "(a+a)/2 = a (extra rules)" $
        rewrite (("a"+"a")/2) @?= "a"

    -- without backoff scheduler this will loop forever
    , testCase "x/y (custom rules)" $
        fst (equalitySaturation @(Maybe Double) ("x"/"y")
                [ "x"/"y" := "x"*(1/"y")
                , "x"*("y"*"z") := ("x"*"y")*"z"
                ]
                symCost)
            @?= ("x"/"y")

    , testCase "0+1 = 1 (all rules)" $
        fst (equalitySaturation (0+1) rewrites symCost) @?= 1

    , testCase "b*(1/b) = 1 (custom rules)" $
        fst (equalitySaturation @(Maybe Double) ("b"*(1/"b")) [ "a"*(1/"a") := 1 ] symCost) @?= 1

    , testCase "1+1=2 (constant folding)" $
        fst (equalitySaturation @(Maybe Double) (1+1) [] symCost) @?= 2

    , testCase "a*(2-1) (1 rule + constant folding)" $
        fst (equalitySaturation @(Maybe Double) ("a" * (2-1)) ["x"*1:="x"] symCost) @?= "a"

    , testCase "1+a*(2-1) = 1+a (all + constant folding)" $
        rewrite (1+("a"*(2-1))) @?= (1+"a")

    , testCase "(-3)+(-3)-6 = -12 (all + constant f.)" $
        rewrite (fromInteger (-3) + fromInteger (-3) - 6) @?= Fix (Const $ -12)

    , testCase "1+a-a*(2-1) = 1 (all + constant f.)" $
        rewrite (1 + "a" - "a"*(2-1)) @?= 1

    , testCase "a-a*(4-1) = a*(-2) (all + constant f.)" $
        rewrite ("a" - "a"*(4-1)) @?= "a"*(Fix . Const $ -2)

    , testCase "a + a + a + a = a*4" $
        rewrite ("a"+"a"+"a"+"a") @?= "a"*4

    , testCase "math powers" $
        rewrite (Fix (BinOp Pow 2 "x")*Fix (BinOp Pow 2 "y")) @?= Fix (BinOp Pow 2 ("x" + "y"))

    , testCase "d1" $
        rewrite (Fix $ BinOp Diff "a" "a") @?= 1

    , testCase "d2" $
        rewrite (Fix $ BinOp Diff "a" "b") @?= 0

    , testCase "d3" $
        rewrite (Fix $ BinOp Diff "x" (1 + 2*"x")) @?= 2

    , testCase "d4" $
        rewrite (Fix $ BinOp Diff "x" (1 + "y"*"x")) @?= "y"

    , testCase "d5" $
        rewrite (Fix $ BinOp Diff "x" (Fix $ UnOp Ln "x")) @?= 1/"x"

    , testCase "i1" $
        rewrite (Fix $ BinOp Integral 1 "x") @?= "x"

    , testCase "i2" $
        rewrite (Fix $ BinOp Integral (Fix $ UnOp Cos "x") "x") @?= Fix (UnOp Sin "x")

    , testCase "i3" $
        rewrite (Fix $ BinOp Integral (Fix $ BinOp Pow "x" 1) "x") @?= "x"*("x"*0.5)

    , testCase "i4" $
        rewrite (_i ((*) "x" (_cos "x")) "x") @?= (+) (_cos "x") ((*) "x" (_sin "x"))

    , testCase "i5" $
        rewrite (_i ((*) (_cos "x") "x") "x") @?= (+) (_cos "x") ((*) "x" (_sin "x"))

    -- TODO: How does this even work ?
    , testCase "i6" $
        rewrite (_i (_ln "x") "x") @?= "x"*(_ln "x" + fromInteger (-1))

    -- TODO: Require ability to fine tune parameters
    -- , testCase "diff_power_harder" $
    --     rewrite (_d "x" ((_pow "x" 3) - 7*(_pow "x" 2))) @?= "x"*(3*"x"-14)
    ]

-- | Term constructors for binary operators without 'Num' syntax:
-- integral of @a@ with respect to @b@, derivative with respect to @a@ of
-- @b@, and power.
_i, _d, _pow :: Fix Expr -> Fix Expr -> Fix Expr
_i a b = Fix (BinOp Integral a b)
_d a b = Fix (BinOp Diff a b)
_pow a b = Fix (BinOp Pow a b)

-- | Term constructors for unary operators.
_ln, _cos, _sin :: Fix Expr -> Fix Expr
_ln a = Fix (UnOp Ln a)
_cos a = Fix (UnOp Cos a)
_sin a = Fix (UnOp Sin a)

-- | Pattern for @a@ raised to the power @b@.
powP :: Pattern Expr -> Pattern Expr -> Pattern Expr
powP a b = NonVariablePattern (BinOp Pow a b)

-- | Pattern for the derivative with respect to @a@ of @b@.
diffP :: Pattern Expr -> Pattern Expr -> Pattern Expr
diffP a b = NonVariablePattern (BinOp Diff a b)

-- | Pattern for the integral of @a@ with respect to @b@.
intP :: Pattern Expr -> Pattern Expr -> Pattern Expr
intP a b = NonVariablePattern (BinOp Integral a b)

-- | Pattern for @cos a@.
cosP :: Pattern Expr -> Pattern Expr
cosP a = NonVariablePattern (UnOp Cos a)

-- | Pattern for @sin a@.
sinP :: Pattern Expr -> Pattern Expr
sinP a = NonVariablePattern (UnOp Sin a)

-- | Pattern for @ln a@.
lnP :: Pattern Expr -> Pattern Expr
lnP a = NonVariablePattern (UnOp Ln a)