module Matrix.QR.Householder (
leastSquares,
decompose, solve, det,
Reflection, reflectMatrix, reflectVector,
Upper, matrixFromUpper, solveUpper, detUpper,
) where
import Matrix.Matrix (mv_mult, m_trans, getRow, getColumn, inner, outer)
import Matrix.Vector (sub, scale, norm)
import DSP.Basic (toMaybe)
import qualified Data.List as List
import Data.Array
(Array, Ix, bounds, elems, range, rangeSize,
accum, accumArray, assocs, ixmap, listArray, (!), (//), )
-- | Householder QR decomposition of a matrix.
--
-- Each round applies 'step' to the remaining submatrix. 'step' reflects its
-- first column onto the leading coordinate. The round records the reflection
-- together with the leading row of the reflected matrix. It then continues
-- with that matrix minus its first row and column ('submatrix'). Rounds stop
-- once the remaining index range is empty, i.e. when either the rows or the
-- columns are used up. So there is one reflection per row or per column,
-- whichever count is smaller.
--
-- Returns the reflections in the order they were applied and the rows of
-- the upper factor, tagged with the bounds of the input matrix.
decompose :: (Ix i, Enum i, Ix j, Enum j, RealFloat a) =>
   Array (i,j) a
   -> ([Reflection i a], Upper i j a)
decompose a =
   case unzip (List.unfoldr peel a) of
      (qs, rows) -> (qs, Upper (bounds a) rows)
   where
      -- One elimination round on the remaining submatrix @rest@. It is
      -- guarded by 'toMaybe' on the submatrix being non-empty, which is
      -- what ends the unfold.
      peel rest =
         let bnds@((topRow, _), _) = bounds rest
         in  toMaybe (not (emptyRange bnds)) $
                let (q, reflected) = step rest
                in  ((q, getRow topRow reflected), submatrix reflected)
-- | 'True' exactly when the index range contains no index at all.
--
-- Only the first element of the lazily produced range is inspected.
emptyRange :: (Ix i) => (i,i) -> Bool
emptyRange bnds =
   case range bnds of
      [] -> True
      _  -> False
-- | One Householder elimination step on the leading column of a matrix.
--
-- Let @z@ be the first column and @e0@ the first unit vector. The
-- reflection is built from @z + s * norm z * e0@, where @s@ is @-1@ when the
-- leading entry of @z@ is negative and @1@ otherwise.
--
-- Matching the sign of the leading entry means two same-signed quantities
-- are added, so no cancellation occurs. Returns the reflection and the
-- matrix with that reflection applied.
step ::
   (Ix i, Ix j, RealFloat a) =>
   Array (i,j) a -> (Reflection i a, Array (i,j) a)
step a = (q, reflectMatrix q a)
   where
      (topRow, leftCol) = fst (bounds a)
      column = getColumn leftCol a
      signOf x = if x < 0 then -1 else 1
      q = reflection $
          accum (+) column [(topRow, signOf (column ! topRow) * norm column)]
-- | Drop the first row and the first column of a matrix.
--
-- The remaining elements keep their original indices. The result is empty
-- when the input has only one row or only one column.
submatrix :: (Ix i, Enum i, Ix j, Enum j) => Array (i,j) e -> Array (i,j) e
submatrix a = ixmap shrunk id a
   where
      ((topRow, leftCol), lowerRight) = bounds a
      shrunk = ((succ topRow, succ leftCol), lowerRight)
-- | Upper-triangular factor produced by 'decompose'.
--
-- The first field holds the bounds of the decomposed matrix. The list holds
-- one row per elimination step. The row from step @k@ (counting from 0)
-- starts at the @k@-th column index, because earlier columns have already
-- been dropped. Its first element is therefore the pivot.
--
-- There are as many rows as the smaller of the matrix's row count and
-- column count.
data Upper i j a = Upper ((i,j), (i,j)) [Array j a]
-- | Expand the upper factor into a dense matrix with the stored bounds.
--
-- The @k@-th stored row is placed at the @k@-th row index of the bounds,
-- with every element at its own column index. All positions not covered by
-- a stored row are zero.
matrixFromUpper :: (Ix i, Ix j, Num a) => Upper i j a -> Array (i,j) a
matrixFromUpper (Upper bnds@((firstRow, _), (lastRow, _)) rows) =
   accumArray (\_old new -> new) 0 bnds
      [ ((k, j), x)
      | (k, row) <- zip (range (firstRow, lastRow)) rows
      , (j, x) <- assocs row
      ]
-- | Householder reflection @x -> x - 2 v (v . x)@, stored as its vector @v@.
--
-- 'reflection' scales @v@ to unit length. A zero vector stays zero and then
-- acts as the identity. The index range of @v@ selects which rows of a
-- matrix, or which elements of a vector, the reflection acts on.
newtype Reflection i a = Reflection (Array i a)
-- | Turn a vector into a reflection by dividing it by its norm.
--
-- For a positive norm, @signum@ is 1 and the divisor is the norm itself.
-- For a zero vector, the divisor becomes @1 - 0 + 0 = 1@. The zero vector
-- is thus kept, giving the identity reflection instead of a division by
-- zero.
reflection :: (Ix i, Floating a) => Array i a -> Reflection i a
reflection v = Reflection (fmap (/ divisor) v)
   where
      len = norm v
      divisor = (1 - signum len) + len
-- | Apply a reflection to a matrix whose row range is the reflection's
-- index range. This function itself does not check that.
--
-- Computes @a - 2 * outer v (transpose a * v)@, i.e. @(I - 2 v v^T) a@. This
-- assumes the "Matrix.Matrix" helpers are the usual transpose,
-- matrix-vector product and outer product.
reflectMatrixFull ::
   (Ix i, Ix j, Num a) => Reflection i a -> Array (i,j) a -> Array (i,j) a
reflectMatrixFull (Reflection v) a = sub a (scale 2 (outer v columnDots))
   where
      -- presumably the inner product of v with each column of a
      columnDots = mv_mult (m_trans a) v
-- | Apply a reflection to the rows of a matrix covered by its index range.
--
-- * If the reflection spans exactly the matrix's row range, the whole
--   matrix is reflected.
-- * If it spans a sub-range, only those rows are reflected (across all
--   columns) and the other rows are kept unchanged.
-- * If the reflection range starts below the first row or ends beyond the
--   last row, the function fails with an error.
reflectMatrix ::
   (Ix i, Ix j, Num a) => Reflection i a -> Array (i,j) a -> Array (i,j) a
reflectMatrix q@(Reflection v) a
   | lo < top    = error "reflectMatrix: lower reflection dimension too small"
   | hi > bottom = error "reflectMatrix: upper reflection dimension too big"
   | lo == top && hi == bottom = reflectMatrixFull q a
   | otherwise =
        replaceSubArray a $ reflectMatrixFull q $
           subArray ((lo, leftCol), (hi, rightCol)) a
   where
      (lo, hi) = bounds v
      ((top, leftCol), (bottom, rightCol)) = bounds a
-- | Apply a reflection to a vector with the reflection's index range.
-- This function itself does not check the ranges.
--
-- Computes @a - 2 (v . a) v@.
reflectVectorFull :: (Ix i, Num a) => Reflection i a -> Array i a -> Array i a
reflectVectorFull (Reflection v) a = sub a (scale coefficient v)
   where
      coefficient = 2 * inner v a
-- | Apply a reflection to the elements of a vector covered by its index range.
--
-- * A reflection spanning the whole vector is applied directly.
-- * One spanning a sub-range reflects only that part and leaves the other
--   elements untouched.
-- * A reflection range starting below the vector's lower bound or ending
--   above its upper bound causes an error.
reflectVector :: (Ix i, Num a) => Reflection i a -> Array i a -> Array i a
reflectVector q@(Reflection v) a
   | lo < aLo = error "reflectVector: lower reflection dimension too small"
   | hi > aHi = error "reflectVector: upper reflection dimension too big"
   | lo == aLo && hi == aHi = reflectVectorFull q a
   | otherwise = replaceSubArray a (reflectVectorFull q (subArray window a))
   where
      window@(lo, hi) = bounds v
      (aLo, aHi) = bounds a
-- | Restrict an array to the given index range, keeping the original indices.
--
-- Indices of the range that lie outside the array's bounds make indexing
-- fail.
subArray :: (Ix i) => (i,i) -> Array i a -> Array i a
subArray bnds arr = listArray bnds [arr ! i | i <- range bnds]
-- | Overwrite the first array at every index of the second array with the
-- second array's element there. All other elements of the first array are
-- kept.
--
-- The second array's indices must lie within the first array's bounds.
replaceSubArray :: (Ix i) => Array i a -> Array i a -> Array i a
replaceSubArray base patch =
   base // [ (i, patch ! i) | i <- range (bounds patch) ]
-- | Solve @R x = b@ by back substitution, where @R@ is the upper factor.
--
-- @b@ must have exactly the row range of the stored bounds. Otherwise the
-- function fails with @"solveUpper: vertical bounds mismatch"@. Only as many
-- leading elements of @b@ as there are stored rows are used. For a tall
-- factor the trailing elements are ignored, as a least-squares solve needs.
--
-- The rows are processed from the last stored row upwards. Each row's
-- leading element is its pivot. Its remaining elements are combined with
-- the solution components already computed.
--
-- When there are fewer stored rows than columns (an underdetermined
-- system), some components have no row of their own. They are set to zero.
-- The substitution already treats them as zero, because @zipWith@ stops at
-- the shorter list. So, given nonzero pivots, the result satisfies every
-- stored row exactly. Previously these trailing elements of the result were
-- left undefined.
solveUpper ::
   (Ix i, Ix j, Fractional a) => Upper i j a -> Array i a -> Array j a
solveUpper (Upper ((m0,n0), (m1,n1)) rs0) b
   | bounds b /= (m0,m1) = error "solveUpper: vertical bounds mismatch"
   | otherwise =
        -- listArray consumes only as many elements as the column range
        -- holds, so the infinite zero padding is never fully demanded.
        listArray (n0,n1) (foldr backSubst [] (zip rs0 (elems b)) ++ repeat 0)
   where
      -- Prepend the component for this row to the components of the rows
      -- below it.
      backSubst (row, rhs) later = component : later
         where
            component =
               case elems row of
                  pivot : coeffs ->
                     (rhs - sum (zipWith (*) coeffs later)) / pivot
                  [] -> error "solveUpper: empty row in upper factor"
-- | Solve @A x = b@ from the result of 'decompose'.
--
-- The reflections are applied to @b@ in the same order in which 'decompose'
-- applied them to the matrix. The resulting triangular system is then handed
-- to 'solveUpper'.
--
-- A strict left fold forces each reflected vector to weak head normal form
-- before the next reflection is applied. The lazy 'foldl' instead built a
-- chain of unevaluated 'reflectVector' applications, one per reflection.
solve ::
   (Ix i, Ix j, Fractional a) =>
   ([Reflection i a], Upper i j a) -> Array i a -> Array j a
solve (qs, u) b = solveUpper u (List.foldl' (flip reflectVector) b qs)
-- | Solve @A x = b@ in the least-squares sense via Householder QR: 'decompose'
-- the matrix, then 'solve' with the result.
--
-- Intended for matrices with at least as many rows as columns.
leastSquares ::
   (Ix i, Enum i, Ix j, Enum j, RealFloat a) =>
   Array (i,j) a -> Array i a -> Array j a
leastSquares a b = solve (decompose a) b
-- | Product of the pivots, i.e. of the leading elements of the stored rows.
--
-- Returns 0 when the number of stored rows differs from the number of
-- columns in the stored bounds. For values produced by 'decompose', that
-- happens when the input has fewer rows than columns.
detUpper :: (Ix i, Ix j, Fractional a) => Upper i j a -> a
detUpper (Upper ((_, firstCol), (_, lastCol)) rows)
   | length rows /= rangeSize (firstCol, lastCol) = 0
   | otherwise = product [head (elems row) | row <- rows]
-- | Determinant via Householder QR.
--
-- Each reflection recorded by 'decompose' is counted as a factor of -1, so
-- the pivot product from 'detUpper' is negated when the number of
-- reflections is odd.
--
-- A reflection built from a nonzero vector has determinant -1. 'step'
-- produces a zero vector only for a zero column, and then the corresponding
-- pivot is zero.
det :: (Ix i, Enum i, Ix j, Enum j, RealFloat a) => Array (i,j) a -> a
det a = orientation * detUpper u
   where
      (qs, u) = decompose a
      orientation = if odd (length qs) then -1 else 1