{-# LANGUAGE QuasiQuotes #-}

{- |
Module      : Language.Egison.Math.Normalize
Licence     : MIT

This module implements the normalization of polynomials. Normalization rules
for particular mathematical functions (such as sqrt and sin/cos) are defined
in Rewrite.hs.
-}

module Language.Egison.Math.Normalize
  ( mathNormalize'
  , termsGcd
  , mathDivideTerm
  ) where

import           Control.Egison

import           Language.Egison.Math.Expr


-- | Normalize a fraction of polynomials.
--
-- The stages, applied innermost first, are:
--
--   1. 'mathRemoveZeroSymbol' drops symbols raised to the power 0;
--   2. 'mathFold' merges repeated symbols inside a term, then like terms;
--   3. 'mathRemoveZero' drops terms whose coefficient is 0;
--   4. 'mathDivide' divides numerator and denominator by their common term.
--
-- NOTE(review): zero exponents are stripped before folding, so a zero
-- exponent that folding itself produces (e.g. @x * x^-1@, if negative
-- exponents can reach this point) survives -- confirm whether that can occur.
mathNormalize' :: ScalarData -> ScalarData
mathNormalize' expr =
  mathDivide (mathRemoveZero (mathFold (mathRemoveZeroSymbol expr)))

-- | Greatest common divisor of a non-empty list of terms.
--
-- The coefficient is the 'gcd' of all coefficients. The monomial keeps the
-- symbols of the first term that are also found in every other term, each
-- raised to the smallest exponent seen; symbols whose resulting exponent is 0
-- are dropped. A quoted symbol @'x@ also matches @'y@ when
-- @x == mathNegate y@; the quote from the earlier term is the one kept.
--
-- An empty list is a caller error; it is reported with an explicit message
-- instead of a bare non-exhaustive-pattern failure.
termsGcd :: [TermExpr] -> TermExpr
termsGcd [] = error "termsGcd: empty list of terms"
termsGcd ts =
  foldl1 (\(Term a xs) (Term b ys) -> Term (gcd a b) (monoGcd xs ys)) ts
 where
  -- Common part of two monomials, following the symbol order of the first.
  monoGcd :: Monomial -> Monomial -> Monomial
  monoGcd [] _ = []
  monoGcd ((x, n):xs) ys =
    case common (x, n) ys of
      (_, 0) -> monoGcd xs ys
      (z, m) -> (z, m) : monoGcd xs ys

  -- Find the first symbol of the monomial matching the given one and return
  -- the smaller exponent; an exponent of 0 means "no match".
  common :: (SymbolExpr, Integer) -> Monomial -> (SymbolExpr, Integer)
  common (x, _) [] = (x, 0)
  common (Quote x, n) ((Quote y, m):ys)
    | x == y            = (Quote x, min n m)
    | x == mathNegate y = (Quote x, min n m)
    | otherwise         = common (Quote x, n) ys
  common (x, n) ((y, m):ys)
    | x == y    = (x, min n m)
    | otherwise = common (x, n) ys

-- | Divide numerator and denominator by the 'termsGcd' of all their terms.
--
-- When the denominator is a single term with a negative coefficient, the
-- divisor's coefficient is negated too, flipping the signs of both sides.
-- A fraction whose numerator or denominator has no terms is returned as is.
mathDivide :: ScalarData -> ScalarData
mathDivide expr@(Div (Plus _) (Plus [])) = expr
mathDivide expr@(Div (Plus []) (Plus _)) = expr
mathDivide (Div (Plus numer) (Plus denom)) =
  Div (Plus (map (`mathDivideTerm` divisor) numer))
      (Plus (map (`mathDivideTerm` divisor) denom))
 where
  common@(Term c mono) = termsGcd (numer ++ denom)
  divisor = case denom of
    [Term a _] | a < 0 -> Term (-c) mono
    _                  -> common

-- | Divide the first term by the second, which is assumed to divide it
-- (as a 'termsGcd' result does for each of its inputs).
--
-- Coefficients are divided with integer 'div'. Each divisor symbol is
-- cancelled against the first matching dividend symbol: equal exponents
-- remove it, otherwise its exponent is reduced. Dividing @'(-s)^n@ by
-- @'s^m@ keeps the dividend's symbol and multiplies the coefficient by
-- @(-1)^m@. Divisor symbols with no match in the dividend are skipped.
mathDivideTerm :: TermExpr -> TermExpr -> TermExpr
mathDivideTerm (Term a xs) (Term b ys) = Term (sign * (a `div` b)) quotient
 where
  (sign, quotient) = divMonomial xs ys

  divMonomial :: Monomial -> Monomial -> (Integer, Monomial)
  divMonomial xs [] = (1, xs)
  divMonomial xs ((y, m):ys) =
    match dfs (y, xs) (Pair SymbolM (Multiset (Pair SymbolM Eql)))
      -- mathFold has already merged repeated symbols, so only the first
      -- matching symbol of the dividend has to be divided.
      [ [mc| (quote $s, ($x & negQuote #s, $n) : $xss) ->
               let (sgn, xs') = divMonomial xss ys in
               let sgn' = if even m then 1 else -1 in
               if n == m then (sgn * sgn', xs')
                         else (sgn * sgn', (x, n - m) : xs') |]
      , [mc| (_, (#y, $n) : $xss) ->
               let (sgn, xs') = divMonomial xss ys in
               if n == m then (sgn, xs') else (sgn, (y, n - m) : xs') |]
      , [mc| _ -> divMonomial xs ys |]
      ]

-- | Remove every symbol raised to the power 0 from each term of the
-- numerator and the denominator. Coefficients are left untouched.
mathRemoveZeroSymbol :: ScalarData -> ScalarData
mathRemoveZeroSymbol (Div (Plus numer) (Plus denom)) =
  Div (Plus (map stripTerm numer)) (Plus (map stripTerm denom))
 where
  stripTerm :: TermExpr -> TermExpr
  stripTerm (Term coeff mono) =
    Term coeff [ factor | factor@(_, e) <- mono, e /= 0 ]

-- | Remove terms whose coefficient is 0 from numerator and denominator.
--
-- When no numerator term survives, the result is @0 / 1@: an empty
-- numerator over the constant term 1.
mathRemoveZero :: ScalarData -> ScalarData
mathRemoveZero (Div (Plus numer) (Plus denom)) =
  case filter nonZero numer of
    []     -> Div (Plus []) (Plus [Term 1 []])
    numer' -> Div (Plus numer') (Plus (filter nonZero denom))
 where
  nonZero :: TermExpr -> Bool
  nonZero (Term coeff _) = coeff /= 0

-- | Merge repeated symbols inside each term ('mathSymbolFold'), then merge
-- like terms ('mathTermFold').
mathFold :: ScalarData -> ScalarData
mathFold expr = mathTermFold (mathSymbolFold expr)

-- | Merge repeated occurrences of a symbol inside each term of the
-- numerator and the denominator, e.g. @x^2 y x -> x^3 y@.
--
-- A quoted symbol also absorbs the quote of its negation:
-- @'s^m * '(-s)^n@ becomes @'s^(m+n)@, and the term's coefficient is
-- negated when @n@ is odd.
mathSymbolFold :: ScalarData -> ScalarData
mathSymbolFold (Div (Plus numer) (Plus denom)) =
  Div (Plus (map foldTerm numer)) (Plus (map foldTerm denom))
 where
  foldTerm :: TermExpr -> TermExpr
  foldTerm (Term a xs) = Term (sign * a) mono
    where
      (sign, mono) = g xs

  -- Returns the sign picked up while merging, and the merged monomial.
  g :: Monomial -> (Integer, Monomial)
  g [] = (1, [])
  g ((x, m):xs) =
    match dfs (x, xs) (Pair SymbolM (Multiset (Pair SymbolM Eql)))
      [ [mc| (quote $s, (negQuote #s, $n) : $xs) ->
               let (sgn, ys) = g ((x, m + n) : xs) in
               if even n then (sgn, ys) else (- sgn, ys) |]
      , [mc| (_, (#x, $n) : $xs) -> g ((x, m + n) : xs) |]
      , [mc| _ -> let (sgn', ys) = g xs in (sgn', (x, m):ys) |]
      ]

-- | Merge like terms within the numerator and within the denominator,
-- e.g. @x^2 y + x^2 y -> 2 x^2 y@.
--
-- Two terms are merged when 'equalMonomial' reports that their monomials are
-- equal up to a sign @sgn@ (±1). Since @ys = sgn * xs@, the sum is
-- @a * xs + b * ys = (a + sgn * b) * xs@; the merged term keeps the first
-- monomial @xs@, so it is the second coefficient that gets scaled by @sgn@.
--
-- NOTE(review): this relies on @sgn@ being the ±1 factor relating the two
-- monomials as products, matching the quote/negQuote sign handling in
-- 'mathSymbolFold' and 'mathDivideTerm' -- confirm against the definition of
-- 'equalMonomial' in Language.Egison.Math.Expr.
mathTermFold :: ScalarData -> ScalarData
mathTermFold (Div (Plus ts1) (Plus ts2)) = Div (Plus (f ts1)) (Plus (f ts2))
 where
  f :: [TermExpr] -> [TermExpr]
  f [] = []
  f (t:ts) =
    match dfs (t, ts) (Pair TermM (Multiset TermM))
      -- Keep xs and scale b by sgn: writing sgn * a + b here would give
      -- the wrong sign whenever sgn = -1 and a /= b.
      [ [mc| (term $a $xs, term $b (equalMonomial $sgn #xs) : $tss) ->
               f (Term (a + sgn * b) xs : tss) |]
      , [mc| _ -> t : f ts |]
      ]