{-# LANGUAGE QuasiQuotes #-}

{- |
Module      : Language.Egison.Math.Rewrite
Licence     : MIT

This module implements rewrite rules for common mathematical functions.
-}

module Language.Egison.Math.Rewrite
  ( rewriteSymbol
  ) where

import           Control.Egison

import           Language.Egison.Math.Arith
import           Language.Egison.Math.Expr
import           Language.Egison.Math.Normalize


-- | Apply every symbol-specific rewrite rule to a scalar expression.
--
-- The rules run as a pipeline: 'rewriteI' is applied first and its result
-- is fed to 'rewriteW', and so on, with 'rewriteDd' applied last.
rewriteSymbol :: ScalarData -> ScalarData
rewriteSymbol =
  rewriteDd
    . rewriteRtu
    . rewriteRt
    . rewriteSqrt
    . rewritePower
    . rewriteExp
    . rewriteSinCos
    . rewriteLog
    . rewriteW
    . rewriteI

-- | Rewrite every term of both the numerator and the denominator with a
-- term-to-term function.
--
-- The polynomials are rebuilt directly with 'Plus' and 'Div'; no
-- re-normalization is performed here.
mapTerms :: (TermExpr -> TermExpr) -> ScalarData -> ScalarData
mapTerms f (Div (Plus ts1) (Plus ts2)) = Div (onPoly ts1) (onPoly ts2)
 where
  onPoly :: [TermExpr] -> PolyExpr
  onPoly = Plus . map f

-- | Apply a term transformation whose result is a whole scalar expression.
--
-- Each term of the numerator and of the denominator is transformed. The
-- results on each side are summed with 'mathPlus', and the two sums are
-- divided with 'mathDiv'.
--
-- An empty sum (a side with no terms) is treated as @0/1@ instead of being
-- passed to 'foldl1', which would crash with "empty list".
mapTerms' :: (TermExpr -> ScalarData) -> ScalarData -> ScalarData
mapTerms' f (Div (Plus ts1) (Plus ts2)) = mathDiv (sumMapped ts1) (sumMapped ts2)
 where
  sumMapped :: [TermExpr] -> ScalarData
  sumMapped [] = Div (Plus []) (Plus [Term 1 []])
  sumMapped ts = foldl1 mathPlus (map f ts)

-- | Rewrite the numerator and the denominator of a scalar expression with
-- the same polynomial-to-polynomial function.
mapPolys :: (PolyExpr -> PolyExpr) -> ScalarData -> ScalarData
mapPolys f expr =
  case expr of
    Div numer denom -> Div (f numer) (f denom)

-- | Reduce powers of the imaginary unit @i@ in every term.
--
-- A factor @i^k@ contributes the sign @(-1)^(k `quot` 2)@ to the
-- coefficient. The factor @i@ itself is kept (with power 1) only when @k@
-- is odd.
rewriteI :: ScalarData -> ScalarData
rewriteI = mapTerms reduceI
 where
  reduceI :: TermExpr -> TermExpr
  reduceI original@(Term coeff mono) =
    match dfs mono (Multiset (Pair SymbolM Eql))
      [ [mc| (symbol #"i", $k) : $rest ->
               Term (coeff * (-1) ^ (quot k 2))
                    (if even k then rest else (Symbol "" "i" [], 1) : rest) |]
      , [mc| _ -> original |]
      ]

-- | Simplify expressions in the symbol @w@, treated as satisfying
-- @w^3 = 1@ and @w^2 = -1 - w@ (a primitive cube root of unity).
--
-- The rewrite runs in two passes:
--
-- * Term pass: a factor @w^k@ with @k >= 3@ becomes @w^(k `mod` 3)@.
-- * Polynomial pass: a pair of terms @a*w^2*r + b*w*r@ (same remaining
--   monomial @r@) becomes @-a*r + (b-a)*w*r@. This repeats until no such
--   pair remains.
rewriteW :: ScalarData -> ScalarData
rewriteW = mapPolys collapseSquares . mapTerms reducePower
 where
  -- Term pass: lower the exponent of w modulo 3.
  reducePower :: TermExpr -> TermExpr
  reducePower original@(Term coeff mono) =
    match dfs mono (Multiset (Pair SymbolM Eql))
      [ [mc| (symbol #"w", $k & ?(>= 3)) : $rest ->
               Term coeff ((Symbol "" "w" [], k `mod` 3) : rest) |]
      , [mc| _ -> original |]
      ]

  -- Polynomial pass: eliminate w^2 against a matching w term.
  collapseSquares :: PolyExpr -> PolyExpr
  collapseSquares original@(Plus terms) =
    match dfs terms (Multiset TermM)
      [ [mc| term $sqCoeff ((symbol #"w", #2) : $rest) :
             term $linCoeff ((symbol #"w", #1) : #rest) : $others ->
               collapseSquares
                 (Plus (Term (-sqCoeff) rest :
                        Term (linCoeff - sqCoeff) ((Symbol "" "w" [], 1) : rest) : others)) |]
      , [mc| _ -> original |]
      ]

-- | Simplify logarithm factors in every term.
--
-- * A factor @log 0@ (to any power) turns the whole term into 0.
-- * A factor @(log (e^n))^m@ is replaced by the integer @n^m@.
--
-- NOTE(review): this assumes exponents stored in a monomial are
-- non-negative, as other rules in this module do; @(^)@ throws on a
-- negative exponent.
rewriteLog :: ScalarData -> ScalarData
rewriteLog = mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (Pair SymbolM Eql))
      -- NOTE(review): log 0 is not finite; mapping it to 0 is the existing
      -- convention of this rule and is kept unchanged.
      [ [mc| (apply #"log" [zero], _) : _ -> Term 0 [] |]
      -- Only a bare e^n qualifies. Since log (c*e^n) = log c + n, a
      -- coefficient other than 1 must not be silently dropped. The power m
      -- of the log factor itself must be applied to the result.
      , [mc| (apply #"log" [singleTerm #1 #1 [(symbol #"e", $n)]], $m) : $xss ->
               Term (n ^ m * a) xss |]
      , [mc| _ -> term |]
      ]

-- | Build the symbolic application of the function named @f@ to @args@.
--
-- The function head is the plain symbol @Symbol "" f []@ (empty first
-- field, no indices), wrapped with 'SingleSymbol'.
makeApply :: String -> [ScalarData] -> SymbolExpr
makeApply f args = Apply headSym args
 where
  headSym = SingleSymbol (Symbol "" f [])

-- | Simplify exponential factors, one at a time, until no rule applies.
--
-- * @exp 0@ (to any power) is dropped.
-- * @(exp 1)^k@ becomes @e^k@.
-- * @(exp (n*i*π))^k@ becomes the sign @(-1)^(n*k)@.
-- * @(exp x)^k@ with @k >= 2@ becomes @exp (k*x)@.
-- * @exp x * exp y@ becomes @exp (x + y)@.
rewriteExp :: ScalarData -> ScalarData
rewriteExp = mapTerms f
 where
  f :: TermExpr -> TermExpr
  f term@(Term a xs) =
    match dfs xs (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"exp" [zero], _) : $xss ->
               f (Term a xss) |]
      -- Keep the power k: (exp 1)^k = e^k, not e.
      , [mc| (apply #"exp" [singleTerm #1 #1 []], $k) : $xss ->
               f (Term a ((Symbol "" "e" [], k) : xss)) |]
      -- (exp (n*i*π))^k = (-1)^(n*k). The sign is decided by parity rather
      -- than with (^), which throws "Negative exponent" for e.g. exp (-i*π).
      , [mc| (apply #"exp" [singleTerm $n #1 [(symbol #"i", #1), (symbol #"π", #1)]], $k) : $xss ->
               f (Term (if even (n * k) then a else negate a) xss) |]
      , [mc| (apply #"exp" [$x], $n & ?(>= 2)) : $xss ->
               f (Term a ((makeApply "exp" [mathScalarMult n x], 1) : xss)) |]
      , [mc| (apply #"exp" [$x], #1) : (apply #"exp" [$y], #1) : $xss ->
               f (Term a ((makeApply "exp" [mathPlus x y], 1) : xss)) |]
      , [mc| _ -> term |]
      ]

-- | Simplify symbolic powers built with @^@, one factor at a time, until no
-- rule applies.
--
-- * A factor @1^y@ (to any power) is dropped.
-- * @(x^y)^k@ with @k >= 2@ becomes @x^(k*y)@.
-- * @x^y * x^z@ becomes @x^(y + z)@.
rewritePower :: ScalarData -> ScalarData
rewritePower = mapTerms simplify
 where
  simplify :: TermExpr -> TermExpr
  simplify original@(Term coeff mono) =
    match dfs mono (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"^" [singleTerm #1 #1 [], _], _) : $rest ->
               simplify (Term coeff rest) |]
      , [mc| (apply #"^" [$base, $ex], $k & ?(>= 2)) : $rest ->
               simplify (Term coeff ((raise base (mathScalarMult k ex), 1) : rest)) |]
      , [mc| (apply #"^" [$base, $ex1], #1) : (apply #"^" [#base, $ex2], #1) : $rest ->
               simplify (Term coeff ((raise base (mathPlus ex1 ex2), 1) : rest)) |]
      , [mc| _ -> original |]
      ]

  -- The symbol for base ^ ex.
  raise :: ScalarData -> ScalarData -> SymbolExpr
  raise base ex = makeApply "^" [base, ex]

-- | Evaluate @sin@ and @cos@ at special multiples of @π@, then expand
-- squared cosines.
--
-- Per term, the passes run in this order:
--
-- 1. sin rules: a factor @sin 0@ or @sin (n*π)@ turns the term into 0. A
--    factor @(sin (n*π/2))^m@ is replaced by the sign
--    @(-1)^(((n-1) `div` 2) * m)@.
-- 2. cos rules: a factor @cos 0@ is dropped. A factor @cos (n*π/2)@ turns
--    the term into 0. A factor @(cos (n*π))^m@ is replaced by
--    @(-1)^(|n|*m)@.
-- 3. Every factor @(cos x)^2@ is expanded to @1 - (sin x)^2@, and the
--    resulting scalars are multiplied back together.
--
-- NOTE(review): the @n*π/2@ rules assume @n@ is odd, i.e. that the
-- argument arrives as a reduced fraction; confirm against the normalizer.
rewriteSinCos :: ScalarData -> ScalarData
rewriteSinCos = mapTerms' expandCosSquares . mapTerms (cosRules . sinRules)
 where
  sinRules :: TermExpr -> TermExpr
  sinRules term@(Term a xs) =
    match dfs xs (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"sin" [zero], _) : _ -> Term 0 [] |]
      , [mc| (apply #"sin" [singleTerm _ #1 [(symbol #"π", #1)]], _) : _ ->
               Term 0 [] |]
      -- sin (n*π/2) = (-1)^((n-1)/2) for odd n.
      -- * sin is odd, so the sign of n must be kept (no abs).
      -- * The power m belongs in the exponent of (-1); it must not multiply
      --   the coefficient.
      -- * div floors, so the formula also holds for negative n.
      -- * Testing parity avoids a negative exponent for (^).
      , [mc| (apply #"sin" [singleTerm $n #2 [(symbol #"π", #1)]], $m) : $xss ->
               Term (if even (div (n - 1) 2 * m) then a else negate a) xss |]
      , [mc| _ -> term |]
      ]

  cosRules :: TermExpr -> TermExpr
  cosRules term@(Term a xs) =
    match dfs xs (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"cos" [zero], _) : $xss -> Term a xss |]
      , [mc| (apply #"cos" [singleTerm _ #2 [(symbol #"π", #1)]], _) : _ ->
               Term 0 [] |]
      , [mc| (apply #"cos" [singleTerm $n #1 [(symbol #"π", #1)]], $m) : $xss ->
               Term (a * (-1) ^ (abs n * m)) xss |]
      , [mc| _ -> term |]
      ]

  -- Rewrite each factor with exponent exactly 2 via cos^2 x = 1 - sin^2 x;
  -- other powers of cos are left alone.
  expandCosSquares :: TermExpr -> ScalarData
  expandCosSquares (Term a xs) =
    match dfs xs (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"cos" [$x], #2) : $mr ->
               mathMult
                 (mathMinus (SingleTerm 1 []) (SingleTerm 1 [(makeApply "sin" [x], 2)]))
                 (expandCosSquares (Term a mr)) |]
      , [mc| _ -> SingleTerm a xs |]
      ]

-- | Simplify square-root factors in every term.
--
-- * @(sqrt x)^k@ with @k > 1@ becomes @(sqrt x)^(k `mod` 2) * x^(k `div` 2)@,
--   and the result is rewritten again.
-- * @sqrt t1 * sqrt t2@ (both single terms with denominator 1) becomes
--   @g * sqrt (t1/g * t2/g)@, where @g = termsGcd [t1, t2]@. The square
--   root is omitted when what remains under it is the constant 1.
rewriteSqrt :: ScalarData -> ScalarData
rewriteSqrt = mapTerms' simplifySqrt
 where
  simplifySqrt :: TermExpr -> ScalarData
  simplifySqrt (Term coeff mono) =
    match dfs mono (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"sqrt" [$x], ?(> 1) & $k) : $rest ->
               rewriteSqrt
                 (mathMult (SingleTerm coeff ((makeApply "sqrt" [x], k `mod` 2) : rest))
                           (mathPower x (div k 2))) |]
      , [mc| (apply #"sqrt" [singleTerm $n #1 $x], #1) :
               (apply #"sqrt" [singleTerm $m #1 $y], #1) : $rest ->
             mergeSqrts coeff rest (Term n x) (Term m y) |]
      , [mc| _ -> SingleTerm coeff mono |]
      ]

  -- coeff * rest * sqrt t1 * sqrt t2, with the common factor of t1 and t2
  -- pulled out of the square root.
  mergeSqrts coeff rest t1 t2 =
    let common@(Term gc gm) = termsGcd [t1, t2]
        Term n' x' = mathDivideTerm t1 common
        Term m' y' = mathDivideTerm t2 common
        radicand = n' * m'
        outside = SingleTerm gc gm
     in case (radicand, x', y') of
          (1, [], []) -> mathMult outside (SingleTerm coeff rest)
          _ -> mathMult outside
                 (SingleTerm coeff ((makeApply "sqrt" [SingleTerm radicand (x' ++ y')], 1) : rest))

-- | Reduce powers of @n@-th roots in every term.
--
-- A factor @(rt n x)^k@ with @k >= n@ (where @n@ is an integer literal)
-- becomes @(rt n x)^(k `mod` n) * x^(k `div` n)@.
rewriteRt :: ScalarData -> ScalarData
rewriteRt = mapTerms' reduceRoot
 where
  reduceRoot :: TermExpr -> ScalarData
  reduceRoot (Term coeff mono) =
    match dfs mono (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"rt" [singleTerm $n #1 [], $x] & $root, ?(>= n) & $k) : $rest ->
               mathMult (SingleTerm coeff ((root, mod k n) : rest))
                        (mathPower x (k `div` n)) |]
      , [mc| _ -> SingleTerm coeff mono |]
      ]

-- | Simplify powers of @rtu n@ in every term.
--
-- The rewrite runs in two passes:
--
-- * Term pass: a factor @(rtu n)^k@ with @k >= n@ becomes
--   @(rtu n)^(k `mod` n)@.
-- * Expansion pass: a factor @(rtu n)^(n-1)@ is replaced by
--   @-1 - r - r^2 - ... - r^(n-2)@ (with @r = rtu n@). This uses the
--   identity @1 + r + ... + r^(n-1) = 0@, which holds when @rtu n@ is a
--   primitive n-th root of unity.
rewriteRtu :: ScalarData -> ScalarData
rewriteRtu = mapTerms' expandTop . mapTerms reducePower
 where
  reducePower :: TermExpr -> TermExpr
  reducePower original@(Term coeff mono) =
    match dfs mono (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"rtu" [singleTerm $n #1 []] & $root, ?(>= n) & $k) : $rest ->
               Term coeff ((root, k `mod` n) : rest) |]
      , [mc| _ -> original |]
      ]

  expandTop :: TermExpr -> ScalarData
  expandTop (Term coeff mono) =
    match dfs mono (Multiset (Pair SymbolM Eql))
      [ [mc| (apply #"rtu" [singleTerm $n #1 []] & $root, ?(== n - 1)) : $rest ->
               mathMult (lowerPowers root n) (expandTop (Term coeff rest)) |]
      , [mc| _ -> SingleTerm coeff mono |]
      ]

  -- -1 - root - root^2 - ... - root^(n-2), subtracted left to right.
  lowerPowers :: SymbolExpr -> Integer -> ScalarData
  lowerPowers root n =
    foldl mathMinus (SingleTerm (-1) []) [SingleTerm 1 [(root, j)] | j <- [1 .. n - 2]]

-- | Merge like terms that share a function-symbol factor.
--
-- Within the numerator and, separately, within the denominator, two terms
-- @a * s^k * r@ and @b * s^k * r@ are combined into @(a + b) * s^k * r@.
-- Here @s@ is a symbol matched by @func@, the two factors have the same
-- function and argument, and @r@ is the identical remaining monomial.
-- Merging repeats until no such pair remains.
rewriteDd :: ScalarData -> ScalarData
rewriteDd (Div (Plus numer) (Plus denom)) =
  Div (Plus (mergeLike numer)) (Plus (mergeLike denom))
 where
  mergeLike :: [TermExpr] -> [TermExpr]
  mergeLike terms =
    match dfs terms (Multiset TermM)
      [ [mc| term $c1 (($sym & func $fn $arg, $k) : $rest) :
               term $c2 ((func #fn #arg, #k) : #rest) : $others ->
                 mergeLike (Term (c1 + c2) ((sym, k) : rest) : others) |]
      , [mc| _ -> terms |]
      ]