module Math.MFSolve
(SimpleExpr(..), Expr, LinExpr(..), UnaryOp(..), BinaryOp(..),
Dependencies, DepError(..), SimpleVar(..),
getKnown, knownVars, varDefined, nonlinearEqs, dependendVars,
simpleExpr, emptyDeps, makeVariable, makeConstant,
(===), (=&=), solveEqs, showVars)
where
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet as H
import GHC.Generics
import Data.Hashable
import Data.Maybe
import Data.List
import Data.Function(on)
import Control.Monad
-- | Both equation operators use precedence 1 (very loose) and associate
-- to the right, so arbitrary arithmetic can appear on either side.
infixr 1 ===
infixr 1 =&=
-- | A unary operator tag paired with the function that evaluates it on
-- constants.  Hashing and printing only look at the 'UnaryOp' tag.
data UnaryFun n
  = UnaryFun UnaryOp (n -> n)
-- | Binary operators that can appear in a 'SimpleExpr'.
data BinaryOp
  = Add
  | Mul
  deriving (Eq)
-- | Unary functions that may appear in an expression tree.
-- Constructor order is kept stable because the Generic-derived hash
-- depends on it.
data UnaryOp
  = Sin | Abs | Recip | Signum
  | Exp | Log | Cos | Cosh | Atanh
  | Tan | Sinh | Asin | Acos | Asinh | Acosh | Atan
  deriving (Eq, Generic)
-- | A plain expression tree, used for printing the solver's internal
-- 'Expr' representation.
data SimpleExpr v n
  = SEBin BinaryOp (SimpleExpr v n) (SimpleExpr v n)
  | SEUn UnaryOp (SimpleExpr v n)
  | Var v
  | Const n
-- | A variable identified by its name.
newtype SimpleVar
  = SimpleVar String
  deriving (Eq, Ord, Generic)
-- | An expression kept as the sum of three parts: a linear part, sine
-- terms grouped by period, and nonlinear terms that could not be
-- simplified.
data Expr v n
  = Expr (LinExpr v n) [TrigTerm v n] [NonLinExpr v n]
  deriving (Generic)
-- | A linear expression: a constant plus a list of
-- @(variable, coefficient)@ terms.
data LinExpr v n
  = LinExpr n [(v, n)]
  deriving (Generic, Eq, Show)
-- | Angle of a sine term: a combination of variables without a constant.
type Period v n = [(v, n)]

-- | Constant offset added to a 'Period' inside a sine.
type Phase n = n

-- | Linear expression multiplying a sine term.
type Amplitude v n = LinExpr v n

-- | Sine terms sharing one period, read as
-- @sum [amp * sin (period + phase) | (phase, amp) <- terms]@.
type TrigTerm v n = (Period v n, [(Phase n, Amplitude v n)])
-- | Terms that the solver keeps symbolically.
data NonLinExpr v n
  = UnaryApp (UnaryFun n) (Expr v n)
    -- ^ a unary function applied to a non-constant expression
  | MulExp (Expr v n) (Expr v n)
    -- ^ a product kept unexpanded
  | SinExp (Expr v n)
    -- ^ the sine of an argument that is not purely linear
  deriving (Generic)
-- | Linear definitions: each solved variable mapped to its expression.
type LinearMap v n = M.HashMap v (LinExpr v n)

-- | The equation @amplitude * sin (period + phase) + constant = 0@.
type TrigEq v n = (Period v n, Amplitude v n, Phase n, n)

-- | For each period, the variables defined by an expression that contains
-- a sine of that period.
type TrigEq2 v n =
  M.HashMap (Period v n) (M.HashMap v (Expr v n))
-- Structural hashes for the expression types.  Except for 'UnaryFun',
-- they use the Generic-based default methods.
-- NOTE(review): hashable >= 1.4 makes 'Eq' a superclass of 'Hashable';
-- Expr, NonLinExpr and UnaryFun have no Eq instance in this module, so
-- confirm the pinned hashable version.
instance (Hashable v, Hashable n) => Hashable (LinExpr v n)

instance (Hashable v, Hashable n) => Hashable (NonLinExpr v n)

-- | A function value cannot be hashed, so only the operator tag counts.
instance (Hashable n) => Hashable (UnaryFun n) where
  hashWithSalt salt (UnaryFun op _) = hashWithSalt salt op

instance Hashable UnaryOp

instance (Hashable v, Hashable n) => Hashable (Expr v n)

instance Hashable SimpleVar
-- | A variable is shown as its bare name, without quotes.
instance Show SimpleVar where
  show (SimpleVar name) = name
-- | Solver state: everything learned from the equations added so far.
data Dependencies v n = Dependencies
  (M.HashMap v (H.HashSet v))
  -- ^ for each variable, the variables whose linear definition refers to it
  (LinearMap v n)
  -- ^ linear definitions of solved variables
  [TrigEq v n]
  -- ^ equations consisting of one sine term and a constant
  (TrigEq2 v n)
  -- ^ per-period definitions of variables involving a sine term
  [Expr v n]
  -- ^ remaining nonlinear equations, each meaning @expr = 0@
-- | Why an equation could not be added.
data DepError n
  = InconsistentEq n
    -- ^ the equation reduced to a nonzero constant (stored here)
  | RedundantEq
    -- ^ the equation reduced to @0 = 0@
-- | Print an 'Expr' by converting it to a 'SimpleExpr' first.
instance (Ord n, Num n, Eq n, Show v, Show n) => Show (Expr v n) where
  show = show . simpleExpr
-- | Render a sub-expression.  It is wrapped in parentheses when it is a
-- binary node whose operator is listed in @ops@.
withParens :: (Show t1, Show t, Ord t1, Num t1, Eq t1) => SimpleExpr t t1 -> [BinaryOp] -> String
withParens e ops = case e of
  SEBin op _ _ | op `elem` ops -> "(" ++ show e ++ ")"
  _ -> show e
-- | Pretty-print an expression in infix notation.
--
-- Clause order matters.  The special cases must be tried before the
-- generic 'Add' and 'Mul' clauses:
--
--   * adding a negatively scaled term is printed as a subtraction;
--   * multiplication by the constant 1 is omitted;
--   * multiplication by a reciprocal is printed as a division;
--   * @exp (log a * b)@ is printed as a power.
instance (Show v, Ord n, Show n, Num n, Eq n) => Show (SimpleExpr v n) where
  show (Var v) = show v
  show (Const n) = show n
  -- a + (c * e) with c < 0 is printed as a - ((-c) * e); if the guard
  -- fails, the next clause applies.
  show (SEBin Add e1 (SEBin Mul (Const e2) e3))
    | e2 < 0 =
      show e1 ++ " - " ++ show (SEBin Mul (Const (negate e2)) e3)
  show (SEBin Add e1 e2) =
    show e1 ++ " + " ++ show e2
  -- multiplication by 1 on either side is dropped
  show (SEBin Mul (Const 1) e) = show e
  show (SEBin Mul e (Const 1)) = show e
  -- e1 * recip e2 is shown as e1/e2
  show (SEBin Mul e1 (SEUn Recip e2)) =
    withParens e1 [Add] ++ "/" ++ withParens e2 [Add, Mul]
  show (SEBin Mul e1 e2) =
    withParens e1 [Add] ++ "*" ++ withParens e2 [Add]
  -- exp (log e1 * e2) is shown as e1**e2
  show (SEUn Exp (SEBin Mul (SEUn Log e1) e2)) =
    withParens e1 [Add, Mul] ++ "**" ++ withParens e2 [Add, Mul]
  show (SEUn op e) = show op ++ "(" ++ show e ++ ")"
-- | Binary operators are shown as their infix symbols.
instance Show BinaryOp where
  show op = case op of
    Add -> "+"
    Mul -> "*"
-- | Unary operators are shown as the function name used when printing
-- @name(argument)@.
instance Show UnaryOp where
  show op = case op of
    Sin -> "sin"
    Abs -> "abs"
    Recip -> "1/"
    Signum -> "sign"
    Exp -> "exp"
    Log -> "log"
    Cos -> "cos"
    Cosh -> "cosh"
    Atanh -> "atanh"
    Tan -> "tan"
    Sinh -> "sinh"
    Asin -> "asin"
    Acos -> "acos"
    Asinh -> "asinh"
    Acosh -> "acosh"
    Atan -> "atan"
-- | Arithmetic on expressions.  A product or unary function that cannot
-- be simplified is kept as a nonlinear term.
instance (Floating n, Ord n, Ord v) => Num (Expr v n) where
  (+) = addExpr
  (*) = mulExpr
  -- Negation multiplies by the constant -1.  The default (-) is
  -- @x + negate y@, so this must really negate.
  negate = mulExpr (makeConstant (-1))
  abs = unExpr (UnaryFun Abs abs)
  signum = unExpr (UnaryFun Signum signum)
  fromInteger = makeConstant . fromInteger
-- | Division goes through 'recip', which folds constants and otherwise
-- keeps a symbolic reciprocal.
instance (Floating n, Ord n, Ord v) => Fractional (Expr v n) where
  recip e = unExpr (UnaryFun Recip (\x -> 1.0 / x)) e
  fromRational r = makeConstant (fromRational r)
-- | Transcendental functions.  Sine (and cosine, written as a shifted
-- sine) get a dedicated representation.  Every other function is folded
-- on constants and otherwise kept symbolically.
instance (Floating n, Ord n, Ord v) => Floating (Expr v n) where
  pi = makeConstant pi
  sin = sinExpr
  -- cos x = sin (x + pi/2)
  cos a = sinExpr (a + makeConstant (pi/2))
  exp e = unExpr (UnaryFun Exp exp) e
  log e = unExpr (UnaryFun Log log) e
  tan e = unExpr (UnaryFun Tan tan) e
  sinh e = unExpr (UnaryFun Sinh sinh) e
  cosh e = unExpr (UnaryFun Cosh cosh) e
  atanh e = unExpr (UnaryFun Atanh atanh) e
  asin e = unExpr (UnaryFun Asin asin) e
  acos e = unExpr (UnaryFun Acos acos) e
  atan e = unExpr (UnaryFun Atan atan) e
  asinh e = unExpr (UnaryFun Asinh asinh) e
  acosh e = unExpr (UnaryFun Acosh acosh) e
-- | One line per linear definition (@v = expr@), followed by one line per
-- remaining nonlinear equation (@expr = 0@).
instance (Show n, Floating n, Ord n, Ord v, Show v) => Show (Dependencies v n) where
  show deps@(Dependencies _ linMap _ _ _) =
    unlines (linearLines ++ nonlinearLines)
    where
      linearLines =
        [show v ++ " = " ++ show (linExpr e) | (v, e) <- M.toList linMap]
      nonlinearLines =
        [show e ++ " = 0" | e <- nonlinearEqs deps]
-- | Human-readable error messages.
instance (Show n) => Show (DepError n) where
  show err = case err of
    InconsistentEq a -> "Inconsistent equations, off by " ++ show a
    RedundantEq -> "Redundant Equation."
-- | Build a sum, dropping either operand when it is the constant 0.
addSimple :: (Num t1, Eq t1) => SimpleExpr t t1 -> SimpleExpr t t1 -> SimpleExpr t t1
addSimple e1 e2 = case (e1, e2) of
  (Const 0, _) -> e2
  (_, Const 0) -> e1
  _ -> SEBin Add e1 e2
-- | Convert a linear expression to a 'SimpleExpr'.  Coefficients equal
-- to 1 are not printed.
linToSimple :: (Num t1, Eq t1) => LinExpr t t1 -> SimpleExpr t t1
linToSimple (LinExpr constant terms) =
  addSimple (Const constant)
            (foldr (\term acc -> addSimple (termToSimple term) acc) (Const 0) terms)
  where
    termToSimple (var, 1) = Var var
    termToSimple (var, coef) = SEBin Mul (Const coef) (Var var)
-- | Convert a group of sine terms sharing one period into a sum of
-- @amplitude * sin(period + phase)@ nodes.  A zero phase is omitted.
trigToSimple :: (Num n, Eq n) => TrigTerm v n -> SimpleExpr v n
trigToSimple (theta, terms) =
  foldr (\term acc -> addSimple (sineTerm term) acc) (Const 0) terms
  where
    thetaExpr = linToSimple (LinExpr 0 theta)
    sineTerm (alpha, amplitude) =
      SEBin Mul (linToSimple amplitude) (SEUn Sin (angleFor alpha))
    angleFor alpha
      | alpha == 0 = thetaExpr
      | otherwise = SEBin Add thetaExpr (Const alpha)
-- | Convert a nonlinear term into the corresponding 'SimpleExpr' node.
nonlinToSimple :: (Num n, Eq n) => NonLinExpr v n -> SimpleExpr v n
nonlinToSimple term = case term of
  UnaryApp (UnaryFun op _) inner -> SEUn op (simpleExpr inner)
  MulExp lhs rhs -> SEBin Mul (simpleExpr lhs) (simpleExpr rhs)
  SinExp inner -> SEUn Sin (simpleExpr inner)
-- | Convert an 'Expr' into a 'SimpleExpr' tree: linear part, then sine
-- terms, then nonlinear terms, combined left to right.
simpleExpr :: (Num n, Eq n) => Expr v n -> SimpleExpr v n
simpleExpr (Expr lin trig nonlin) =
  (linPart `addSimple` trigPart) `addSimple` nonlinPart
  where
    linPart = linToSimple lin
    trigPart = foldr (addSimple . trigToSimple) (Const 0) trig
    nonlinPart = foldr (addSimple . nonlinToSimple) (Const 0) nonlin
-- | The linear expression 0: no constant and no terms.
zeroTerm :: (Num n) => LinExpr v n
zeroTerm = LinExpr 0 mempty
-- | Embed a linear expression as an 'Expr' with no sine or nonlinear terms.
linExpr :: LinExpr v n -> Expr v n
linExpr = \lt -> Expr lt [] []
-- | An expression equal to the given constant.
makeConstant :: n -> Expr v n
makeConstant value = linExpr $ LinExpr value []
-- | An expression consisting of a single variable with coefficient 1.
makeVariable :: Num n => v -> Expr v n
makeVariable var = linExpr $ LinExpr 0 [(var, 1)]
-- | An expression made only of the given sine terms.
trigExpr :: (Num n) => [TrigTerm v n] -> Expr v n
trigExpr terms = Expr zeroTerm terms []
-- | An expression made only of the given nonlinear terms.
nonlinExpr :: Num n => [NonLinExpr v n] -> Expr v n
nonlinExpr terms = Expr zeroTerm [] terms
-- | The constant value of a linear expression that has no variable terms.
getConst :: LinExpr t a -> Maybe a
getConst lt = case lt of
  LinExpr a [] -> Just a
  _ -> Nothing
-- | The linear part of an expression that has no sine or nonlinear terms.
getLin :: Expr t n -> Maybe (LinExpr t n)
getLin e = case e of
  Expr lt [] [] -> Just lt
  _ -> Nothing
-- | The value of an expression that is just a constant.
getConstExpr :: Expr t b -> Maybe b
getConstExpr e = do
  lt <- getLin e
  getConst lt
-- | Sum of two linear expressions.  Both term lists are expected to be
-- sorted by variable; terms that cancel to 0 are dropped.
addLin :: (Ord v, Num n, Eq n) => LinExpr v n -> LinExpr v n -> LinExpr v n
addLin (LinExpr c1 terms1) (LinExpr c2 terms2) =
  LinExpr (c1 + c2) nonZero
  where
    nonZero = [t | t@(_, coef) <- merge terms1 terms2 (+), coef /= 0]
-- | Sum of two expressions.  Linear parts are added, sine terms with the
-- same period are combined, and nonlinear terms are concatenated.
addExpr :: (Ord n, Ord v, Floating n) => Expr v n -> Expr v n -> Expr v n
addExpr (Expr lin1 trig1 nl1) (Expr lin2 trig2 nl2) =
  Expr (addLin lin1 lin2) (merge trig1 trig2 addTrigTerms) (nl1 ++ nl2)
-- | Merge two key-sorted association lists.  Values whose keys match are
-- combined with the given function (left value first).
merge :: Ord k => [(k, v)] -> [(k, v)] -> (v -> v -> v) -> [(k, v)]
merge [] ys _ = ys
merge xs [] _ = xs
merge xs@(x@(kx, vx) : xs') ys@(y@(ky, vy) : ys') combine =
  case compare kx ky of
    LT -> x : merge xs' ys combine
    GT -> y : merge xs ys' combine
    EQ -> (kx, combine vx vy) : merge xs' ys' combine
-- | Add two lists of @(phase, amplitude)@ sine terms that share a period.
--
-- Each term of the second list is folded into the first list.
-- 'addTrigTerm' tries to combine the incoming term with an existing one.
-- A combined term is retried against the remaining terms, and a result
-- whose amplitude is exactly @LinExpr 0 []@ removes the term.
addTrigTerms :: (Ord a, Ord t, Floating a) => [(a, LinExpr t a)] -> [(a, LinExpr t a)] -> [(a, LinExpr t a)]
addTrigTerms [] p = p
addTrigTerms terms terms2 =
  foldr mergeTerms terms terms2
  where
    mergeTerms (alpha, n) ((beta, m):rest) =
      case addTrigTerm alpha n beta m of
        -- the two terms cancelled out
        Just (_, LinExpr 0 []) -> rest
        -- combined; keep combining the result with the rest
        Just (gamma, o) ->
          mergeTerms (gamma, o) rest
        -- not combinable; keep the existing term and move on
        Nothing -> (beta, m) : mergeTerms (alpha, n) rest
    mergeTerms a [] = [a]
-- | Try to combine @n * sin(x + alpha)@ and @m * sin(x + beta)@ into a
-- single sine term @(gamma, amplitude)@.
--
-- This succeeds when the phases are equal (the amplitudes are added) or
-- when @n = r * m@ for a constant @r@.  In the second case
-- @r * sin(x + alpha) + sin(x + beta)@ is rewritten as one shifted sine,
-- using atan2-style quadrant correction.
addTrigTerm :: (Ord a, Ord t, Floating a) => a -> LinExpr t a -> a -> LinExpr t a -> Maybe (a, LinExpr t a)
addTrigTerm alpha n beta m
  | alpha == beta = Just (alpha, addLin n m)
  | Just r <- termIsMultiple n m =
      let sinPart = r * sin alpha + sin beta
          cosPart = r * cos alpha + cos beta
          gamma = atan (sinPart / cosPart) + (if cosPart < 0 then pi else 0)
          scale = sqrt (sinPart * sinPart + cosPart * cosPart)
      in Just (gamma, mulLinExpr scale m)
  | otherwise = Nothing
-- | If the first linear expression is a constant multiple of the second,
-- return the factor @r@ (first ~ r * second, with terms checked pairwise
-- by 'compareTerm').
--
-- Returns 'Nothing' when either expression is exactly the constant 0.
termIsMultiple :: (Ord a, Fractional a, Eq t) => LinExpr t a -> LinExpr t a -> Maybe a
termIsMultiple (LinExpr _ _) (LinExpr 0 []) = Nothing
termIsMultiple (LinExpr 0 []) (LinExpr _ _) = Nothing
-- Both constants are 0: take the ratio from the leading coefficients.
-- If the guard fails, evaluation falls through to the next clause.
termIsMultiple (LinExpr 0 r1@((_, d1):_)) (LinExpr 0 r2@((_, d2):_))
  | compareBy r1 r2 (compareTerm (d1/d2)) =
    Just (d1/d2)
-- Otherwise take the ratio from the constant parts.
termIsMultiple (LinExpr c1 r1) (LinExpr c2 r2)
  | compareBy r1 r2 (compareTerm (c1/c2)) =
    Just (c1/c2)
  | otherwise = Nothing
-- | True when both @(variable, coefficient)@ terms refer to the same
-- variable and the first coefficient equals the second times @ratio@,
-- within a relative tolerance of @2e-50@.
compareTerm :: (Ord a1, Fractional a1, Eq a) => a1 -> (a, a1) -> (a, a1) -> Bool
compareTerm ratio (v3, c3) (v4, c4) =
  v3 == v4 && abs (c3 - c4 * ratio) <= abs c3 * 2e-50
-- | True when both lists have the same length and the predicate holds
-- for every pair of corresponding elements (checked left to right).
compareBy :: [a] -> [b] -> (a -> b -> Bool) -> Bool
compareBy (x:xs) (y:ys) p = p x y && compareBy xs ys p
compareBy xs ys _ = null xs && null ys
-- | Scale a linear expression (constant and coefficients) by a number.
mulLinExpr :: Num n => n -> LinExpr v n -> LinExpr v n
mulLinExpr x (LinExpr e terms) =
  LinExpr (e * x) [(var, coef * x) | (var, coef) <- terms]
-- | Scale every amplitude of a sine-term group by a constant.
mulConstTrig :: (Ord n, Num n) => n -> TrigTerm v n -> TrigTerm v n
mulConstTrig c (theta, terms) =
  (theta, [(phase, mulLinExpr c amp) | (phase, amp) <- terms])
-- | Multiply a linear expression by a group of sine terms.  A term with a
-- constant amplitude stays a sine term; any other product is kept as a
-- nonlinear 'MulExp'.
mulLinTrig :: (Ord n, Ord v, Floating n) => LinExpr v n -> TrigTerm v n -> Expr v n
mulLinTrig lt (theta, terms) =
  foldr (\term acc -> scaleTerm term + acc) 0 terms
  where
    scaleTerm (alpha, LinExpr c []) =
      trigExpr [(theta, [(alpha, mulLinExpr c lt)])]
    scaleTerm term =
      nonlinExpr [MulExp (trigExpr [(theta, [term])]) (Expr lt [] [])]
-- | Multiply two expressions.
--
-- Clause order matters:
--
--   * a constant factor scales the other side's linear and sine parts;
--   * a linear factor times a constant-plus-sines expression is
--     distributed over the sine terms with 'mulLinTrig';
--   * anything else is kept as an unexpanded 'MulExp'.
mulExpr :: (Ord a, Ord t, Floating a) => Expr t a -> Expr t a -> Expr t a
-- constant * (linear + sines)
mulExpr (getConstExpr -> Just c) (Expr lt2 trig []) =
  Expr (mulLinExpr c lt2)
  (map (mulConstTrig c) trig) []
-- (linear + sines) * constant
mulExpr (Expr lt2 trig []) (getConstExpr -> Just c) =
  Expr (mulLinExpr c lt2)
  (map (mulConstTrig c) trig) []
-- linear * (constant + sines)
mulExpr (getLin -> Just lt) (Expr (getConst -> Just c) trig []) =
  linExpr (mulLinExpr c lt) +
  foldr ((+).mulLinTrig lt) 0 trig
-- (constant + sines) * linear
mulExpr (Expr (getConst -> Just c) trig []) (getLin -> Just lt) =
  linExpr (mulLinExpr c lt) +
  foldr ((+).mulLinTrig lt) 0 trig
mulExpr e1 e2 = nonlinExpr [MulExp e1 e2]
-- | Sine of an expression:
--
--   * a constant is evaluated directly;
--   * a linear argument becomes a sine term with amplitude 1;
--   * anything else stays a nonlinear 'SinExp'.
sinExpr :: Floating n => Expr v n -> Expr v n
sinExpr e = case e of
  Expr (LinExpr c terms) [] []
    | null terms -> makeConstant (sin c)
    | otherwise -> trigExpr [(terms, [(c, LinExpr 1 [])])]
  _ -> nonlinExpr [SinExp e]
-- | Apply a unary function: evaluate it on a constant argument, otherwise
-- keep it as a nonlinear 'UnaryApp'.
unExpr :: Num n => UnaryFun n -> Expr v n -> Expr v n
unExpr fun@(UnaryFun _ evalFn) e =
  maybe (nonlinExpr [UnaryApp fun e]) (makeConstant . evalFn) (getConstExpr e)
-- | Substitute variables in a linear expression.  Each variable for which
-- the lookup returns a definition is replaced by that definition, scaled
-- by its coefficient.
substVarLin :: (Ord v, Num n, Eq n) => (v -> Maybe (LinExpr v n)) -> LinExpr v n -> LinExpr v n
substVarLin s (LinExpr a terms) =
  foldr (addLin . substOne) (LinExpr a []) terms
  where
    substOne (var, coef) = case s var of
      Nothing -> LinExpr 0 [(var, coef)]
      Just def -> mulLinExpr coef def
-- | Substitute variables inside a nonlinear term and rebuild it with the
-- smart constructors, so that it may simplify.
substVarNonLin :: (Ord n, Ord v, Floating n) => (v -> Maybe (LinExpr v n)) -> NonLinExpr v n -> Expr v n
substVarNonLin s term = case term of
  UnaryApp fun inner -> unExpr fun (subst s inner)
  MulExp lhs rhs -> subst s lhs * subst s rhs
  SinExp inner -> sin (subst s inner)
-- | Substitute variables inside a group of sine terms.  Both the shared
-- period and each amplitude are substituted, and the sum
-- @amp * sin(phase + period)@ is rebuilt.
substVarTrig :: (Ord v, Ord n, Floating n) => (v -> Maybe (LinExpr v n)) -> ([(v, n)], [(n, LinExpr v n)]) -> Expr v n
substVarTrig s (period, terms) =
  foldr addTerm 0 substituted
  where
    period' = linExpr (substVarLin s (LinExpr 0 period))
    substituted = [(phase, linExpr (substVarLin s amp)) | (phase, amp) <- terms]
    addTerm (phase, amp) acc = acc + amp * sin (makeConstant phase + period')
-- | Substitute variables throughout an expression (linear, sine and
-- nonlinear parts) and re-add the pieces.
subst :: (Ord n, Ord v, Floating n) => (v -> Maybe (LinExpr v n)) -> Expr v n -> Expr v n
subst s (Expr lt trig nl) = linPart + trigPart + nlPart
  where
    linPart = linExpr (substVarLin s lt)
    trigPart = foldr ((+) . substVarTrig s) 0 trig
    nlPart = foldr ((+) . substVarNonLin s) 0 nl
-- | The initial solver state, before any equation has been added.
emptyDeps :: Dependencies v n
emptyDeps =
  Dependencies
    M.empty -- no variable dependencies
    M.empty -- no linear definitions
    []      -- no single-sine equations
    M.empty -- no per-period sine definitions
    []      -- no nonlinear equations
-- | A one-variable substitution: maps @x@ to @Just y@ and everything else
-- to 'Nothing'.
simpleSubst :: Eq a => a -> b -> a -> Maybe b
simpleSubst x y z = if x == z then Just y else Nothing
-- | Add the equation @lhs = rhs@ to the solver state.  It is stored as
-- @lhs - rhs = 0@.
(===) :: (Hashable n, Hashable v, RealFrac (Phase n), Ord v,
          Floating n) => Expr v n -> Expr v n
      -> Dependencies v n
      -> Either (DepError n) (Dependencies v n)
(===) lhs rhs deps = addEq deps (lhs - rhs)
-- | Add several equations (each meaning @expr = 0@) in order, stopping at
-- the first error.
addEqs :: (Hashable v, Hashable n, RealFrac (Phase n), Ord v, Floating n) => Dependencies v n -> [Expr v n] -> Either (DepError n) (Dependencies v n)
addEqs deps exprs = foldM addEq deps exprs
-- | Add the equation @expr = 0@.  The known linear definitions are
-- substituted into it first, then 'addEq0' classifies and stores it.
addEq :: (Hashable n, Hashable v, RealFrac (Phase n), Ord v,
          Floating n) =>
         Dependencies v n
      -> Expr v n -> Either (DepError n) (Dependencies v n)
addEq deps@(Dependencies _ lin _ _ _) expr =
  addEq0 deps (subst (`M.lookup` lin) expr)
-- | Every way to pick one element from a list, paired with the remaining
-- elements in their original order.
select :: [a] -> [(a, [a])]
select [] = []
select (x:xs) = (x, xs) : map (\(y, ys) -> (y, x : ys)) (select xs)
-- | Add an equation @expr = 0@ whose known linear definitions have already
-- been substituted.  The equation is classified by its shape and stored
-- in the matching part of the solver state.
addEq0 :: (Hashable v, Hashable n, RealFrac (Phase n), Ord v, Floating n)
       => Dependencies v n -> Expr v n
       -> Either (DepError n) (Dependencies v n)
-- A constant equation carries no information: 0 = 0 is redundant, and
-- any other constant is a contradiction.
addEq0 _ (getConstExpr -> Just c) =
  if c == 0 then Left RedundantEq
  else Left (InconsistentEq c)

-- A purely linear equation.  Solve it for the variable with the largest
-- coefficient, substitute that solution into every stored definition and
-- equation, and re-add the nonlinear equations (they may now be linear).
addEq0 (Dependencies vdep lin trig trig2 nonlin) (Expr lt [] []) =
  let (v, _, lt2) = splitMax lt
      -- variables whose definitions mention v
      depVars = fromMaybe H.empty (M.lookup v vdep)
      lin' = M.insert v lt2 $
             H.foldl' (flip $ M.adjust $ substVarLin $
                       simpleSubst v lt2) lin depVars
      ltVars = case lt2 of
        LinExpr _ vars -> map fst vars
      depVars2 = H.insert v depVars
      -- every variable in v's new definition now has v (and v's
      -- dependents) depending on it
      vdep' = H.foldl'
              (\mp k -> M.insertWith H.union k depVars2 mp)
              (M.delete v vdep) (H.fromList ltVars)
      nonlin' = map (subst (simpleSubst v lt2)) nonlin
      trigSubst (p, a, ph, c) =
        subst (simpleSubst v lt2) $
        sin (linExpr $ LinExpr ph p) *
        linExpr a + makeConstant c
      newTrig = map trigSubst trig
      -- a stored definition v2 = ex is turned back into v2 - ex = 0
      trigSubst2 (v2, ex) =
        subst (simpleSubst v lt2) $
        makeVariable v2 - ex
      newTrig2 =
        map trigSubst2 $
        concatMap M.toList $
        M.elems trig2
  in addEqs (Dependencies vdep' lin' [] M.empty [])
            (newTrig ++ newTrig2 ++ nonlin')

-- A linear part plus a single sine term with constant amplitude:
-- c + lt + n * sin(theta + alpha) = 0.
addEq0 deps@(Dependencies vdep lin trig trig2 nl)
  (Expr (LinExpr c lt) [(theta, [(alpha, getConst -> Just n)])] []) =
  if null lt then
    -- c + n*sin(theta + alpha) = 0 gives theta + alpha = asin(-c/n),
    -- added as the linear equation theta + alpha - asin(-c/n) = 0.
    addEq0 deps (linExpr $ LinExpr (alpha - asin (-c/n)) theta)
  else
    case M.lookup theta trig2 of
      Nothing -> addSin (LinExpr c lt) alpha n
      -- Definitions already exist for this period: substitute them first.
      Just map2 ->
        case foldr ((+).doSubst)
             (makeConstant c +
              makeConstant n *
              sin (linExpr $ LinExpr alpha theta))
             lt of
          Expr lt2 [(_, [(alpha2, getConst -> Just n2)])] []
            | isNothing (getConst lt2)
              -> addSin lt2 alpha2 n2
          e2 -> addEq0 deps e2
        where
          doSubst (v,c2) = case M.lookup v map2 of
            Nothing -> makeVariable v * makeConstant c2
            Just e2 -> e2 * makeConstant c2
  where
    -- Solve l' + n' * sin(theta + a') = 0 for the largest-coefficient
    -- variable of l' and record its definition under this period.
    addSin l' a' n' =
      let (v, c', r) = splitMax l'
          trig2' = M.insertWith M.union theta
                   (M.singleton v $
                    Expr r [(theta, [(a', LinExpr (n'/negate c') [])])] [])
                   trig2
      in Right $ Dependencies vdep lin trig trig2' nl

-- A lone sine term plus a constant, with no such equation stored yet.
addEq0 (Dependencies d lin [] trig2 nl)
  (Expr (getConst -> Just c) [(theta, [(alpha, n)])] []) =
  Right $ Dependencies d lin [(theta, n, alpha, c)] trig2 nl

-- A lone sine term plus a constant: x + n * sin(theta + a) = 0.  Look for
-- a stored equation y' + m * sin(theta + b) = 0 with m a multiple of n
-- and a sufficiently different phase.  If one exists, the pair is solved
-- for theta and the amplitude n, giving two linear equations.
addEq0 (Dependencies deps lin trig trig2 nl)
  (Expr (getConst -> Just x) [(theta, [(a, n)])] []) =
  case mapMaybe similarTrig $ select trig of
    [] -> Right $ Dependencies deps lin ((theta, n, a, x):trig) trig2 nl
    l -> addEqs (Dependencies deps lin rest trig2 nl) [lin1, lin2]
      where
        -- the stored equation whose phase differs most from a (mod pi)
        ((b,y), rest) = maximumBy (maxTrig `on` fst) l
        maxTrig (t1,_) (t2,_) =
          compare ((t1-a)`dmod`pi) ((t2-a)`dmod`pi)
        -- With psi = theta + b: n*sin psi = -y and
        -- n*cos psi * sin(a-b) = y*cos(a-b) - x.
        d = sin(a-b)
        e = y*cos(a-b)-x
        theta2 = atan (-y*d/e)-b +
                 (if (d*e) < 0 then pi else 0)
        n2 = sqrt(y*y + e*e/(d*d))
        lin1 = linExpr $ LinExpr (-theta2) theta
        lin2 = linExpr n - makeConstant n2
  where
    -- A stored equation with the same period whose amplitude is a
    -- multiple r of n, normalised to amplitude n (constant y/r).
    similarTrig ((t,m,b,y),rest)
      | Just r <- termIsMultiple m n,
        t == theta &&
        (b-a) `dmod` pi > pi/8 =
          Just ((b,y/r),rest)
      | otherwise = Nothing

-- Anything else is kept as a nonlinear equation.
addEq0 (Dependencies d lin trig trig2 nonlin) e =
  Right $ Dependencies d lin trig trig2 (e:nonlin)
-- | Distance from @a@ to the nearest integer multiple of @b@.  For
-- example, @dmod x pi@ measures how far the angle @x@ is from a multiple
-- of pi.  Ties are resolved with 'round' (half to even).
dmod :: RealFrac a => a -> a -> a
dmod a b = abs (a - fromInteger (round (a/b)) * b)
-- | Solve the linear equation @c + sum (coef * var) = 0@ for the variable
-- with the largest absolute coefficient.
--
-- Returns that variable, its coefficient, and the linear expression the
-- variable equals.  Precondition: the term list is non-empty, because
-- 'maximumBy' is partial.
splitMax :: (Ord b, Fractional b, Eq v) => LinExpr v b -> (v, b, LinExpr v b)
splitMax (LinExpr c t) =
  let (v, c2) = maximumBy (compare `on` (abs . snd)) t
  in (v, c2,
      -- c + c2*v + rest = 0  =>  v = -c/c2 - rest/c2
      LinExpr (-c/c2) $
      map (fmap (negate . (/c2))) $
      filter ((/= v) . fst) t)
-- | True when the variable has an entry in the linear definition map.
varDefined :: (Eq v, Hashable v) => Dependencies v n -> v -> Bool
varDefined (Dependencies _ dep _ _ _) v = M.member v dep
-- | Linear definitions that still refer to other variables, i.e. those
-- whose value is not yet a plain constant.
dependendVars :: (Eq n) => Dependencies v n -> [(v, LinExpr v n)]
dependendVars (Dependencies _ lin _ _ _) =
  [entry | entry@(_, LinExpr _ terms) <- M.toList lin, not (null terms)]
-- | Variables whose linear definition is a plain constant, with their values.
knownVars :: Dependencies v n -> [(v, n)]
knownVars (Dependencies _ lin _ _ _) =
  [(var, value) | (var, LinExpr value []) <- M.toList lin]
-- | Look up a variable's value.  Returns @Right value@ when it is known,
-- @Left vars@ with the variables its definition still depends on, or
-- @Left []@ when the variable has no linear definition.
getKnown :: (Eq v, Hashable v) => Dependencies v n -> v -> Either [v] n
getKnown (Dependencies _ lin _ _ _) var =
  maybe (Left []) fromDefinition (M.lookup var lin)
  where
    fromDefinition (LinExpr a []) = Right a
    fromDefinition (LinExpr _ terms) = Left (map fst terms)
-- | Rebuild the left-hand side @amp * sin(period + phase) + c@ of a
-- stored single-sine equation.
trigToExpr :: (Ord n, Ord v, Floating n) => TrigEq v n -> Expr v n
trigToExpr (period, amp, phase, c) =
  linExpr amp * sin (linExpr (LinExpr phase period)) + makeConstant c
-- | All equations not yet reduced to linear definitions, each meaning
-- @expr = 0@:
--
--   * stored single-sine equations;
--   * per-period definitions @v = e@, returned as @v - e@;
--   * the remaining nonlinear equations.
nonlinearEqs :: (Ord n, Ord v, Floating n) => Dependencies v n -> [Expr v n]
nonlinearEqs (Dependencies _ _ trig trig2 nl) =
  map trigToExpr trig ++
  map (\(v, e) -> makeVariable v - e)
      (concatMap M.toList (M.elems trig2)) ++
  nl
-- | Add two equations, @a = c@ and @b = d@, together.
--
-- If the first is redundant, only the second is added.  If the second
-- turns out to be redundant after the first, the state after the first
-- is kept.  Any other error is returned unchanged.
(=&=) :: (Hashable n, Hashable v, RealFrac (Phase n), Ord v, Floating n)
      => (Expr v n, Expr v n) -> (Expr v n, Expr v n)
      -> Dependencies v n -> Either (DepError n) (Dependencies v n)
(=&=) (a, b) (c, d) dep =
  case (a === c) dep of
    Right afterFirst -> keepIfRedundant afterFirst ((b === d) afterFirst)
    Left RedundantEq -> (b === d) dep
    failure -> failure
  where
    keepIfRedundant fallback (Left RedundantEq) = Right fallback
    keepIfRedundant _ outcome = outcome
-- | Apply equation-adding steps (e.g. built with '===') one after another
-- to a starting state, stopping at the first error.
solveEqs :: Dependencies v n -> [Dependencies v n -> Either (DepError n) (Dependencies v n)] -> Either (DepError n) (Dependencies v n)
solveEqs start steps = foldM (\deps step -> step deps) start steps
-- | Print either the solver error or the resulting dependencies.
showVars :: (Show n, Show v, Show a, Ord n, Ord v, Floating n) => Either (DepError a) (Dependencies v n) -> IO ()
showVars = either print print