{-|
Copyright  :  (C) 2015-2016, University of Twente,
                  2017     , QBayLogic B.V.
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

= SOP: Sum-of-Products, sorta

The arithmetic operations for 'GHC.TypeLits.Nat' are addition
(@'GHC.TypeLits.+'@), subtraction (@'GHC.TypeLits.-'@), multiplication
(@'GHC.TypeLits.*'@), and exponentiation (@'GHC.TypeLits.^'@). This means we
cannot write expressions in a canonical SOP normal form. We can get rid of
subtraction by working with integers, and translating @a - b@ to @a + (-1)*b@.
Exponentiation cannot be gotten rid of that way. So we define the following
grammar for our canonical SOP-like normal form of arithmetic expressions:

@
SOP      ::= Product \'+\' SOP | Product
Product  ::= Symbol \'*\' Product | Symbol
Symbol   ::= Integer
          |  Var
          |  Var \'^\' Product
          |  SOP \'^\' ProductE

ProductE ::= SymbolE \'*\' ProductE | SymbolE
SymbolE  ::= Var
          |  Var \'^\' Product
          |  SOP \'^\' ProductE
@

So valid SOP terms are:

@
x*y + y^2
(x+y)^(k*z)
@

but

@
(x*y)^2
@

is not, and should be:

@
x^2 * y^2
@

Exponents are thus not allowed to contain sums, so, for example, the expression:

@
(x + 2)^(y + 2)
@

in valid SOP form is:

@
4*x*(2 + x)^y + 4*(2 + x)^y + (2 + x)^y*x^2
@

Also, exponents may only contain integer factors when the base is a variable
(@ProductE@ has no integer symbols). Although not enforced by the grammar,
exponentials are flattened as far as possible in SOP form. So:

@
(x^y)^z
@

is flattened to:

@
x^(y*z)
@
-}

{-# LANGUAGE CPP #-}

module GHC.TypeLits.Normalise.SOP
  ( -- * SOP types
    Symbol (..)
  , Product (..)
  , SOP (..)
    -- * Simplification
  , reduceExp
  , mergeS
  , mergeP
  , mergeSOPAdd
  , mergeSOPMul
  , normaliseExp
  , simplifySOP
  )
where

-- External
import Data.Either (partitionEithers)
import Data.List   (sort)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Utils.Outputable (Outputable (..), (<+>), text, hcat, integer, punctuate)
#else
import Outputable (Outputable (..), (<+>), text, hcat, integer, punctuate)
#endif

-- | A single factor of a 'Product' term.
--
-- The constructor order is significant: the derived 'Ord' instance places
-- integer constants ('I') before every other symbol. Sorting the symbols of
-- a product therefore moves its integer coefficient to the front, which is
-- the shape 'mergeP' matches on.
data Symbol v c
  = I Integer                 -- ^ Integer constant
  | C c                       -- ^ Non-integer constant
  | E (SOP v c) (Product v c) -- ^ Exponentiation: a sum raised to a product
  | V v                       -- ^ Variable
  deriving (Eq, Ord)

-- | A product of 'Symbol's, stored as the list of its factors.
newtype Product v c = P { unP :: [Symbol v c] }
  deriving Eq

-- Products made of a single symbol compare by that symbol and sort before
-- products with two or more symbols; every other combination (including the
-- empty product) falls back to lexicographic comparison of the symbol lists.
instance (Ord v, Ord c) => Ord (Product v c) where
  compare (P lhs) (P rhs) = case (lhs, rhs) of
    ([l], [r]) -> compare l r
    ([_], _:_) -> LT
    (_:_, [_]) -> GT
    _          -> compare lhs rhs

-- | A sum of 'Product' terms, stored as the list of its summands.
--
-- The empty sum and the literal @0@ are identified by the 'Eq' instance.
-- NOTE(review): the derived 'Ord' below does not make that identification
-- (@compare (S []) (S [P [I 0]])@ is 'LT'), so 'Ord' and 'Eq' disagree on
-- that one pair; confirm no caller relies on them agreeing.
newtype SOP v c = S { unS :: [Product v c] }
  deriving (Ord)

-- Structural equality on the list of products, except that the empty sum and
-- the single-product literal @0@ are treated as the same value.
instance (Eq v, Eq c) => Eq (SOP v c) where
  S xs == S ys
    | isZero xs && isZero ys = True
    | otherwise              = xs == ys
    where
      -- Both representations of zero that this instance equates.
      isZero []        = True
      isZero [P [I 0]] = True
      isZero _         = False

-- Renders the summands side by side, separated by " + ".
instance (Outputable v, Outputable c) => Outputable (SOP v c) where
  ppr (S products) = hcat (punctuate (text " + ") (map ppr products))

-- Renders the factors side by side, separated by " * ".
instance (Outputable v, Outputable c) => Outputable (Product v c) where
  ppr (P symbols) = hcat (punctuate (text " * ") (map ppr symbols))

-- Renders a single symbol. In an exponential, the base and the exponent are
-- printed bare when they are a lone integer or variable and parenthesised
-- otherwise.
instance (Outputable v, Outputable c) => Outputable (Symbol v c) where
  ppr sym = case sym of
      I i          -> integer i
      C c          -> ppr c
      V x          -> ppr x
      E base power -> atom base <+> text "^" <+> atom (S [power])
    where
      atom (S [P [I i]]) = integer i
      atom (S [P [V x]]) = ppr x
      atom sop           = text "(" <+> ppr sop <+> text ")"

-- | Repeatedly merge the elements of a list with a pairwise merge function.
--
-- The head @f@ is offered to every remaining element @x@ as @op x f@.
--
-- * If every result is 'Right', @f@ is kept and merging continues on the
--   original tail.
-- * Otherwise @f@ is dropped (it has been absorbed into the 'Left' results).
--   The merged results, followed by the untouched ('Right') elements, are
--   processed again.
mergeWith :: (a -> a -> Either a a) -> [a] -> [a]
mergeWith _  []     = []
mergeWith op (f:fs)
  | null merged = f : mergeWith op fs
  | otherwise   = mergeWith op (merged ++ kept)
  where
    (merged, kept) = partitionEithers [x `op` f | x <- fs]

-- | Reduce exponentials
--
-- Performs the following rewrites:
--
-- @
-- x^0          ==>  1
-- 0^x          ==>  0
-- 2^3          ==>  8
-- (k ^ i) ^ j  ==>  k ^ (i * j)
-- @
--
-- The rewrites are tried in the order listed, so @0^0@ reduces to @1@.
-- Integer powers are only evaluated for non-negative exponents. A symbol
-- that matches none of the rewrites is returned unchanged.
reduceExp :: (Ord v, Ord c) => Symbol v c -> Symbol v c
-- x^0 ==> 1
reduceExp (E _ (P [I 0])) = I 1
-- 0^x ==> 0
reduceExp (E (S [P [I 0]]) _) = I 0
-- 2^3 ==> 8
reduceExp (E (S [P [I base]]) (P [I power]))
  | power >= 0 = I (base ^ power)
-- (k ^ i) ^ j ==> k ^ (i * j)
reduceExp (E (S [P [E k inner]]) outer) =
    case normaliseExp k (S [combined]) of
      S [P [single]] -> single
      _              -> E k combined
  where
    -- The two exponent products are concatenated and then simplified:
    -- symbols are merged, exponentials reduced, and the result sorted.
    combined =
      P (sort (map reduceExp (mergeWith mergeS (unP inner ++ unP outer))))
reduceExp other = other

-- | Merge two symbols of a Product term
--
-- Performs the following rewrites:
--
-- @
-- 8 * 7          ==>  56
-- 1 * x          ==>  x
-- x * 1          ==>  x
-- 0 * x          ==>  0
-- x * 0          ==>  0
-- x * x^4        ==>  x^5
-- x^4 * x        ==>  x^5
-- 4^x * 2^x      ==>  8^x
-- y*y            ==>  y^2
-- x^y * x^(-y)   ==>  1
-- x^(-y) * x^y   ==>  1
-- @
--
-- Returns 'Left' with the merged symbol when a rewrite applies, and 'Right'
-- with the first argument unchanged otherwise.
mergeS :: (Ord v, Ord c) => Symbol v c -> Symbol v c
       -> Either (Symbol v c) (Symbol v c)
-- 8 * 7 ==> 56
mergeS (I m) (I n) = Left (I (m * n))
-- 1 * x ==> x
mergeS (I 1) rhs = Left rhs
-- x * 1 ==> x
mergeS lhs (I 1) = Left lhs
-- 0 * x ==> 0
mergeS (I 0) _ = Left (I 0)
-- x * 0 ==> 0
mergeS _ (I 0) = Left (I 0)

-- x * x^4 ==> x^5
mergeS s (E base@(S [P [s']]) (P [I n]))
  | s == s' = Left (E base (P [I (n + 1)]))

-- x^4 * x ==> x^5
mergeS (E base@(S [P [s']]) (P [I n])) s
  | s == s' = Left (E base (P [I (n + 1)]))

-- 4^x * 2^x ==> 8^x
mergeS (E (S [P [I m]]) p) (E (S [P [I n]]) q)
  | p == q = Left (E (S [P [I (m * n)]]) p)

-- y*y ==> y^2
mergeS lhs rhs
  | lhs == rhs
  = case normaliseExp (S [P [lhs]]) (S [P [I 2]]) of
      S [P [squared]] -> Left squared
      _               -> Right lhs

-- x^y * x^(-y) ==> 1
mergeS (E b1 (P e1)) (E b2 (P (I (-1) : e2)))
  | b1 == b2 && e1 == e2 = Left (I 1)

-- x^(-y) * x^y ==> 1
mergeS (E b1 (P (I (-1) : e1))) (E b2 (P e2))
  | b1 == b2 && e1 == e2 = Left (I 1)

mergeS lhs _ = Right lhs

-- | Merge two products of a SOP term
--
-- Performs the following rewrites:
--
-- @
-- 2xy + 3xy  ==>  5xy
-- 2xy + xy   ==>  3xy
-- xy + 2xy   ==>  3xy
-- xy + xy    ==>  2xy
-- @
--
-- A leading 'I' symbol is treated as the coefficient of the product, and the
-- remaining symbols must be equal for the products to merge. Returns 'Left'
-- with the merged product, or 'Right' with the first argument unchanged.
mergeP :: (Eq v, Eq c) => Product v c -> Product v c
       -> Either (Product v c) (Product v c)
-- 2xy + 3xy ==> 5xy
mergeP (P (I m : xs)) (P (I n : ys))
  | xs == ys = Left (P (I (m + n) : xs))
-- 2xy + xy ==> 3xy
mergeP (P (I m : xs)) (P ys)
  | xs == ys = Left (P (I (m + 1) : xs))
-- xy + 2xy ==> 3xy
mergeP (P xs) (P (I n : ys))
  | xs == ys = Left (P (I (n + 1) : xs))
-- xy + xy ==> 2xy
mergeP lhs@(P xs) (P ys)
  | xs == ys  = Left (P (I 2 : xs))
  | otherwise = Right lhs

-- | Expand or simplify \'complex\' exponentials
--
-- Performs the following rewrites:
--
-- @
-- b^0              ==>  1
-- b^1              ==>  b
-- x^(2xy)          ==>  x^(2xy)
-- 2^3              ==>  8
-- (x + 2)^2        ==>  x^2 + 4x + 4
-- (x + 2)^(2x)     ==>  (x^2 + 4x + 4)^x
-- (x + 2)^(xy)     ==>  (x + 2)^(xy)
-- (x + 2)^(y + 2)  ==>  4x(2 + x)^y + 4(2 + x)^y + (2 + x)^yx^2
-- @
normaliseExp :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
-- b^0 ==> 1
--
-- A zero exponent is handled up front. Both representations of zero are
-- covered: the empty sum and the literal 0.
--
-- * Without these equations, a literal 0 exponent on a base that is not a
--   single symbol (e.g. (x + 2)^0) reaches the (x + 2)^(2x) equation with
--   i = 0. That equation calls normaliseExp again with the very same
--   arguments, so it never terminates.
-- * The empty sum would reach the final equation, which calls foldr1 on an
--   empty list.
--
-- Like reduceExp, this treats 0^0 as 1.
normaliseExp _ (S [])        = S [P [I 1]]
normaliseExp _ (S [P [I 0]]) = S [P [I 1]]

-- b^1 ==> b
normaliseExp b (S [P [I 1]]) = b

-- x^(2xy) ==> x^(2xy)
normaliseExp b@(S [P [V _]]) (S [e]) = S [P [E b e]]

-- 2^3 ==> 8 (single-symbol base and exponent; delegated to reduceExp)
normaliseExp b@(S [P [_]]) (S [e@(P [_])]) = S [P [reduceExp (E b e)]]

-- (x + 2)^2 ==> x^2 + 4x + 4
normaliseExp b (S [P [I i]]) | i > 0 =
  foldr1 mergeSOPMul (replicate (fromInteger i) b)

-- (x + 2)^(2x) ==> (x^2 + 4x + 4)^x
normaliseExp b (S [P (e@(I i):es)]) | i >= 0 =
  -- Without the "| i >= 0" guard, normaliseExp can loop with itself
  -- for exponentials such as: 2^(n-k)
  normaliseExp (normaliseExp b (S [P [e]])) (S [P es])

-- (x + 2)^(xy) ==> (x + 2)^(xy)
normaliseExp b (S [e]) = S [P [reduceExp (E b e)]]

-- (x + 2)^(y + 2) ==> 4x(2 + x)^y + 4(2 + x)^y + (2 + x)^yx^2
--
-- A sum in the exponent becomes a product of powers. The exponent here has
-- at least two summands, because the zero and single-summand cases are
-- matched above, so foldr1 is safe.
normaliseExp b (S e) = foldr1 mergeSOPMul (map (normaliseExp b . S . (:[])) e)

-- | Whether the leading factor of a product is the literal @0@.
zeroP :: Product v c -> Bool
zeroP product = case unP product of
  I 0 : _ -> True
  _       -> False

-- | Replace the empty sum by the explicit literal @0@; any other sum is
-- returned unchanged.
mkNonEmpty :: SOP v c -> SOP v c
mkNonEmpty sop
  | null (unS sop) = S [P [I 0]]
  | otherwise      = sop

-- | Simplifies SOP terms using
--
-- * 'mergeS'
-- * 'mergeP'
-- * 'reduceExp'
--
-- A single pass does the following:
--
-- 1. Within every product, merges the symbols with 'mergeS', reduces
--    exponentials with 'reduceExp' and sorts the symbols.
-- 2. Merges like products with 'mergeP' and drops products whose leading
--    factor is @0@.
-- 3. Sorts the products and replaces an empty sum by @0@.
--
-- Passes are repeated until one leaves the term unchanged, according to the
-- 'Eq' instance of 'SOP'. The input to that last pass is what is returned.
simplifySOP :: (Ord v, Ord c) => SOP v c -> SOP v c
simplifySOP = fixpoint pass
  where
    pass (S products) =
      let simplifyProduct (P syms) =
            P (sort (map reduceExp (mergeWith mergeS syms)))
          merged = mergeWith mergeP (map simplifyProduct products)
      in  mkNonEmpty (S (sort (filter (not . zeroP) merged)))

    fixpoint step current
      | next == current = current
      | otherwise       = fixpoint step next
      where
        next = step current
{-# INLINEABLE simplifySOP #-}

-- | Merge two SOP terms by addition: the summands of both terms are
-- concatenated and the result is simplified with 'simplifySOP'.
mergeSOPAdd :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
mergeSOPAdd lhs rhs = simplifySOP (S (unS lhs ++ unS rhs))
{-# INLINEABLE mergeSOPAdd #-}

-- | Merge two SOP terms by multiplication.
--
-- Distributes the product over both sums: every summand of the second term
-- is paired with every summand of the first, and each pair is multiplied by
-- concatenating their symbol lists. The pairs are generated with the second
-- term in the outer position. The resulting sum is simplified with
-- 'simplifySOP'.
mergeSOPMul :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
mergeSOPMul (S sop1) (S sop2) =
  simplifySOP (S [ P (unP p1 ++ unP p2) | p2 <- sop2, p1 <- sop1 ])
{-# INLINEABLE mergeSOPMul #-}