module Matrix.LU (lu, lu_solve, improve, inverse, lu_det, solve, det) where
import qualified Matrix.Matrix as Matrix
import qualified Matrix.Vector as Vector
import qualified Data.List as List
import Data.Array
import Data.Ord (comparing)
-- | LU decomposition (Doolittle form, no pivoting) of a square matrix.
--
-- Both factors are packed into one array so that @A = L * U@:
--
-- * entries strictly below the diagonal hold L, whose unit diagonal is
--   implicit;
-- * entries on and above the diagonal hold U.
--
-- The matrix is assumed square, with the same lower index for rows and
-- columns. Any common lower bound is accepted, not just 1.
--
-- No pivoting is performed. A zero pivot @U(j,j)@ therefore makes the L
-- entries below it Infinity/NaN instead of raising an error.
lu :: Array (Int,Int) Double
   -> Array (Int,Int) Double
lu a = a'
  where
    bnds@((lo, _), _) = bounds a
    -- Lazily self-referential array: every entry is defined in terms of
    -- entries of the result with smaller k, so evaluation order sorts itself
    -- out on demand.
    a' = array bnds [ ((i, j), luij i j) | (i, j) <- range bnds ]
    luij i j
      | i > j     = (a ! (i, j) - dot i j (j - 1)) / a' ! (j, j)  -- L(i,j)
      | otherwise = a ! (i, j) - dot i j (i - 1)                  -- U(i,j)
    -- Partial inner product: sum over k = lo .. kmax of L(i,k) * U(k,j).
    dot i j kmax = sum [ a' ! (i, k) * a' ! (k, j) | k <- [lo .. kmax] ]
-- | Solve @A x = b@ given the packed LU factors of @A@, as produced by 'lu'.
--
-- The solve runs in two passes:
--
-- * forward substitution solves @L y = b@, where L has an implicit unit
--   diagonal;
-- * back substitution solves @U x = y@.
--
-- @b@ must be indexed by the same range as the rows of the factor matrix,
-- and the solution uses that range too. Any common lower bound is accepted,
-- not just 1. An empty system yields an empty solution.
lu_solve :: Array (Int,Int) Double -- ^ packed LU factors of A (from 'lu')
         -> Array Int Double       -- ^ right-hand side b
         -> Array Int Double       -- ^ solution x
lu_solve a b = x
  where
    ((lo, _), (hi, _)) = bounds a
    rows = (lo, hi)
    -- Both arrays are lazy and self-referential. Each element depends only
    -- on elements that come earlier in its substitution order. For the
    -- first/last element the inner sum is empty, which reproduces the
    -- plain b!lo and y!hi / U(hi,hi) cases.
    y = listArray rows [ forward i | i <- range rows ]
    x = listArray rows [ backward i | i <- range rows ]
    forward i = b ! i - sum [ a ! (i, j) * y ! j | j <- [lo .. i - 1] ]
    backward i =
      (y ! i - sum [ a ! (i, j) * x ! j | j <- [i + 1 .. hi] ]) / a ! (i, i)
-- | One step of iterative refinement of an approximate solution of @A x = b@.
--
-- The step works as follows:
--
-- 1. form the residual @r = A x - b@ with the original matrix;
-- 2. solve @A e = r@ with the existing LU factors;
-- 3. return @x - e@.
--
-- Vectors are indexed by the row range of @A@, and any common lower bound
-- is accepted, not just 1.
improve :: Array (Int,Int) Double -- ^ original matrix A
        -> Array (Int,Int) Double -- ^ packed LU factors of A (from 'lu')
        -> Array Int Double       -- ^ right-hand side b
        -> Array Int Double       -- ^ current approximate solution x
        -> Array Int Double       -- ^ refined solution
improve a a_lu b x = listArray rows [ x ! i - err ! i | i <- idxs ]
  where
    ((lo, _), (hi, _)) = bounds a
    rows = (lo, hi)
    idxs = range rows
    -- Residual uses the original A, not the factors.
    residual =
      listArray rows [ sum [ a ! (i, j) * x ! j | j <- idxs ] - b ! i | i <- idxs ]
    err = lu_solve a_lu residual
-- | Inverse of a square matrix, computed column by column.
--
-- Column @j@ of the result is the solution of @A x = e_j@, and all column
-- solves reuse a single 'lu' factorisation.
--
-- No pivoting is done by 'lu', so a zero pivot yields Infinity/NaN entries
-- rather than an error. The matrix is assumed square with the same lower
-- index for rows and columns, and the result has the same bounds as the
-- input.
inverse :: Array (Int,Int) Double
        -> Array (Int,Int) Double
inverse a0 =
  array (bounds a0)
    [ ((i, j), xij) | (j, xj) <- zip idxs columns, (i, xij) <- assocs xj ]
  where
    ((lo, _), (hi, _)) = bounds a0
    idxs = range (lo, hi)
    -- Factor once; every column solve shares it.
    factors = lu a0
    -- Standard basis vector e_j.
    unit j = listArray (lo, hi) [ if i == j then 1.0 else 0.0 | i <- idxs ]
    columns = [ lu_solve factors (unit j) | j <- idxs ]
-- | Determinant from packed LU factors, as produced by 'lu'.
--
-- L has an implicit unit diagonal and 'lu' makes no row exchanges, so the
-- determinant is the product of U's diagonal. Any common lower index for
-- rows and columns is accepted. An empty matrix gives 1, the empty product.
lu_det :: Array (Int,Int) Double
       -> Double
lu_det a = product [ a ! (i, i) | i <- range (lo, hi) ]
  where
    ((lo, _), (hi, _)) = bounds a
-- | Solve @A x = b@ in one call: factor @A@ with 'lu', then substitute with
-- 'lu_solve'.
--
-- When solving for several right-hand sides, factor once and call
-- 'lu_solve' directly instead.
solve :: Array (Int,Int) Double
      -> Array Int Double
      -> Array Int Double
solve a b = lu_solve factors b
  where
    factors = lu a
-- | Determinant via an unpivoted LU factorisation: 'lu' followed by
-- 'lu_det'.
--
-- This is not exported; the exported 'det' uses partial pivoting instead.
_det :: Array (Int,Int) Double
     -> Double
_det = lu_det . lu
-- | Determinant by recursive Gaussian elimination with partial pivoting.
--
-- Each step works as follows:
--
-- * pick the entry of largest magnitude in the first column as the pivot;
-- * subtract a scaled copy of the pivot row from every row, which zeroes
--   the first column and the pivot row;
-- * drop the pivot row and the first column, and recurse on the minor.
--
-- Moving the pivot row to the top costs @pivRow - r0@ adjacent row swaps,
-- which gives the sign factor. An empty matrix has determinant 1. A zero
-- pivot, meaning the whole first column is zero, gives 0.
--
-- NOTE(review): this relies on the helpers being element-wise operations
-- on arrays:
--
-- * @Matrix.getColumn@ / @Matrix.getRow@ extract a column / row;
-- * @Matrix.outer@ forms the outer product;
-- * @Vector.sub@ and @Vector.scale@ act element-wise.
--
-- These are defined elsewhere; confirm.
det :: Array (Int,Int) Double
    -> Double
det a
  | rangeSize bnds == 0 = 1
  | pivot == 0          = 0
  | otherwise           = sign * pivot * det minor
  where
    bnds@((r0, c0), (r1, c1)) = bounds a
    firstCol = Matrix.getColumn c0 a
    (pivRow, pivVal) =
      List.maximumBy (comparing (abs . snd)) (assocs firstCol)
    pivot = a ! (pivRow, c0)
    -- Lazily evaluated, so recip of a zero pivot is never used.
    eliminated =
      Vector.sub a
        (Matrix.outer firstCol
          (Vector.scale (recip pivVal) (Matrix.getRow pivRow a)))
    -- Skip the pivot row and the first column.
    skip (i, j) = (if i < pivRow then i else succ i, succ j)
    minor = ixmap ((r0, c0), (pred r1, pred c1)) skip eliminated
    sign = if even (rangeSize (r0, pivRow) - 1) then 1 else -1