{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif
module Linear.Matrix
( (!*!), (!+!), (!-!), (!*), (*!), (!!*), (*!!), (!!/)
, column
, adjoint
, M22, M23, M24, M32, M33, M34, M42, M43, M44
, m33_to_m44, m43_to_m44
, det22, det33, det44, inv22, inv33, inv44
, identity
, Trace(..)
, translation
, transpose
, fromQuaternion
, mkTransformation
, mkTransformationMat
, _m22, _m23, _m24
, _m32, _m33, _m34
, _m42, _m43, _m44
#if MIN_VERSION_base(4,8,0)
, lu
, luFinite
, forwardSub
, forwardSubFinite
, backwardSub
, backwardSubFinite
, luSolve
, luSolveFinite
, luInv
, luInvFinite
, luDet
, luDetFinite
#endif
) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Control.Lens hiding (index)
import Control.Lens.Internal.Context
import Data.Distributive
import Data.Foldable as Foldable
import Data.Functor.Rep
import Linear.Quaternion
import Linear.V2
import Linear.V3
import Linear.V4
import Linear.Vector
import Linear.Conjugate
import Linear.Trace
#if MIN_VERSION_base(4,8,0)
import GHC.TypeLits
import Linear.V
#endif
#ifdef HLINT
{-# ANN module "HLint: ignore Reduce duplication" #-}
#endif
-- | Lift a lens on a single row into a lens on the corresponding column
-- of a matrix.
--
-- Given a lens @l@ that focuses an @a@ inside each row @s@, the result
-- views the @f a@ made of that focus taken from every row of the outer
-- 'Representable' container. It writes an @f b@ back by updating each
-- row at the same position.
column :: Representable f => LensLike (Context a b) s t a b -> Lens (f s) (f t) (f a) (f b)
column l f es = o <$> f i where
  -- Run the lens at the 'Context' functor: for one row this yields the
  -- focused value (read with 'ipos') together with a way to replace it
  -- (used through 'ipeek').
  go = l (Context id)
  -- The column being viewed: the focus of every row, indexed like the rows.
  i = tabulate $ \ e -> ipos $ go (index es e)
  -- Rebuild every row, putting the matching entry of the new column into it.
  o eb = tabulate $ \ e -> ipeek (index eb e) (go (index es e))
infixl 7 !*!
-- | Matrix product.
--
-- Row @r@ of the result is the sum, over the positions @k@ of @r@, of the
-- scalar @r k@ times row @k@ of the right-hand matrix.
(!*!) :: (Functor m, Foldable t, Additive t, Additive n, Num a) => m (t a) -> t (n a) -> m (n a)
f !*! g = fmap combineRow f
  where
    -- Weight each row of g by the matching entry of this row, then add up.
    combineRow row = Foldable.foldl' (^+^) zero (liftI2 (*^) row g)
infixl 6 !+!
-- | Entry-wise matrix addition, combining the rows with 'liftU2' and
-- each pair of rows with '^+^'.
(!+!) :: (Additive m, Additive n, Num a) => m (n a) -> m (n a) -> m (n a)
(!+!) = liftU2 (^+^)
infixl 6 !-!
-- | Entry-wise matrix subtraction, combining the rows with 'liftU2' and
-- each pair of rows with '^-^'.
(!-!) :: (Additive m, Additive n, Num a) => m (n a) -> m (n a) -> m (n a)
(!-!) = liftU2 (^-^)
infixl 7 !*
-- | Multiply a matrix by a column vector: each entry of the result is the
-- sum of the entry-wise product of one row with the vector.
(!*) :: (Functor m, Foldable r, Additive r, Num a) => m (r a) -> r a -> m a
m !* v = fmap rowDot m
  where
    rowDot r = Foldable.sum (liftI2 (*) r v)
infixl 7 *!
-- | Multiply a row vector by a matrix: the rows of the matrix are scaled
-- by the matching entries of the vector and then summed.
(*!) :: (Num a, Foldable t, Additive f, Additive t) => t a -> t (f a) -> f a
v *! m = sumV (liftI2 (*^) v m)
infixl 7 *!!
-- | Scalar-matrix product: every row is scaled by the scalar.
(*!!) :: (Functor m, Functor r, Num a) => a -> m (r a) -> m (r a)
(*!!) s = fmap (s *^)
{-# INLINE (*!!) #-}
infixl 7 !!*
-- | Matrix-scalar product: the same as '*!!' with its arguments swapped.
(!!*) :: (Functor m, Functor r, Num a) => m (r a) -> a -> m (r a)
m !!* s = s *!! m
{-# INLINE (!!*) #-}
infixl 7 !!/
-- | Matrix-scalar division: every row is divided by the scalar using '^/'.
(!!/) :: (Functor m, Functor r, Fractional a) => m (r a) -> a -> m (r a)
m !!/ s = (^/ s) <$> m
{-# INLINE (!!/) #-}
-- | Conjugate transpose: the rows are conjugated entry-wise and the nesting
-- of the two containers is swapped with 'collect'.
adjoint :: (Functor m, Distributive n, Conjugate a) => m (n a) -> n (m a)
adjoint rows = collect (fmap conjugate) rows
{-# INLINE adjoint #-}
-- | A 2x2 matrix, stored as a 'V2' of rows.
type M22 a = V2 (V2 a)
-- | A 2x3 matrix: 2 rows of 3 entries.
type M23 a = V2 (V3 a)
-- | A 2x4 matrix: 2 rows of 4 entries.
type M24 a = V2 (V4 a)
-- | A 3x2 matrix: 3 rows of 2 entries.
type M32 a = V3 (V2 a)
-- | A 3x3 matrix, stored as a 'V3' of rows.
type M33 a = V3 (V3 a)
-- | A 3x4 matrix: 3 rows of 4 entries.
type M34 a = V3 (V4 a)
-- | A 4x2 matrix: 4 rows of 2 entries.
type M42 a = V4 (V2 a)
-- | A 4x3 matrix: 4 rows of 3 entries.
type M43 a = V4 (V3 a)
-- | A 4x4 matrix, stored as a 'V4' of rows.
type M44 a = V4 (V4 a)
-- | Build the 3x3 rotation matrix for a quaternion.
--
-- The diagonal entries are written as @1 - 2 * (...)@, which matches the
-- general formula only when the quaternion has unit norm.
fromQuaternion :: Num a => Quaternion a -> M33 a
fromQuaternion (Quaternion w (V3 x y z)) =
  let xx = x * x
      yy = y * y
      zz = z * z
      xy = x * y
      xz = x * z
      yz = y * z
      wx = x * w
      wy = y * w
      wz = z * w
  in V3 (V3 (1 - 2 * (yy + zz)) (2 * (xy - wz))      (2 * (xz + wy)))
        (V3 (2 * (xy + wz))      (1 - 2 * (xx + zz)) (2 * (yz - wx)))
        (V3 (2 * (xz - wy))      (2 * (yz + wx))      (1 - 2 * (xx + yy)))
{-# INLINE fromQuaternion #-}
-- | Build a 4x4 homogeneous transformation from a 3x3 linear part and a
-- translation.
--
-- Each row of the 3x3 matrix is extended with the matching translation
-- component, and @V4 0 0 0 1@ is appended as the last row.
mkTransformationMat :: Num a => M33 a -> V3 a -> M44 a
mkTransformationMat (V3 r1 r2 r3) (V3 tx ty tz) =
  V4 (appendW r1 tx) (appendW r2 ty) (appendW r3 tz) (V4 0 0 0 1)
  where
    appendW (V3 x y z) w = V4 x y z w
{-# INLINE mkTransformationMat #-}
-- | Build a 4x4 homogeneous transformation from a rotation quaternion
-- (converted with 'fromQuaternion') and a translation.
mkTransformation :: Num a => Quaternion a -> V3 a -> M44 a
mkTransformation q t = mkTransformationMat (fromQuaternion q) t
{-# INLINE mkTransformation #-}
-- | Extend a 4x3 matrix to 4x4 by adding a fourth column.
--
-- The new column is 0 in the first three rows and 1 in the last row.
m43_to_m44 :: Num a => M43 a -> M44 a
m43_to_m44 (V4 r1 r2 r3 r4) =
  V4 (vector r1) (vector r2) (vector r3) (point r4)
{-# ANN m43_to_m44 "HLint: ignore Use camelCase" #-}
-- | Embed a 3x3 matrix into the top-left corner of a 4x4 matrix.
--
-- Each row gets a trailing 0, and @V4 0 0 0 1@ is added as the last row.
m33_to_m44 :: Num a => M33 a -> M44 a
m33_to_m44 (V3 r1 r2 r3) =
  V4 (vector r1) (vector r2) (vector r3) (V4 0 0 0 1)
{-# ANN m33_to_m44 "HLint: ignore Use camelCase" #-}
-- | The identity matrix, built by placing a vector of ones on the diagonal
-- with 'scaled'.
identity :: (Num a, Traversable t, Applicative t) => t (t a)
identity = scaled ones
  where
    ones = pure 1
-- | Lens onto the translation part of a transformation matrix.
--
-- This is the fourth column ('_w' of every row), restricted to its first
-- three rows ('_xyz').
translation :: (Representable t, R3 t, R4 v) => Lens' (t (v a)) (V3 a)
translation = column _w . _xyz
-- | Lens onto the top-left 2x2 block of a matrix.
_m22 :: (Representable t, R2 t, R2 v) => Lens' (t (v a)) (M22 a)
_m22 = column _xy . _xy

-- | Lens onto the top-left 2x3 block (2 rows, 3 columns) of a matrix.
_m23 :: (Representable t, R2 t, R3 v) => Lens' (t (v a)) (M23 a)
_m23 = column _xyz . _xy

-- | Lens onto the top-left 2x4 block (2 rows, 4 columns) of a matrix.
_m24 :: (Representable t, R2 t, R4 v) => Lens' (t (v a)) (M24 a)
_m24 = column _xyzw . _xy

-- | Lens onto the top-left 3x2 block (3 rows, 2 columns) of a matrix.
_m32 :: (Representable t, R3 t, R2 v) => Lens' (t (v a)) (M32 a)
_m32 = column _xy . _xyz

-- | Lens onto the top-left 3x3 block of a matrix.
_m33 :: (Representable t, R3 t, R3 v) => Lens' (t (v a)) (M33 a)
_m33 = column _xyz . _xyz

-- | Lens onto the top-left 3x4 block (3 rows, 4 columns) of a matrix.
_m34 :: (Representable t, R3 t, R4 v) => Lens' (t (v a)) (M34 a)
_m34 = column _xyzw . _xyz

-- | Lens onto the top-left 4x2 block (4 rows, 2 columns) of a matrix.
_m42 :: (Representable t, R4 t, R2 v) => Lens' (t (v a)) (M42 a)
_m42 = column _xy . _xyzw

-- | Lens onto the top-left 4x3 block (4 rows, 3 columns) of a matrix.
_m43 :: (Representable t, R4 t, R3 v) => Lens' (t (v a)) (M43 a)
_m43 = column _xyz . _xyzw

-- | Lens onto the top-left 4x4 block of a matrix.
_m44 :: (Representable t, R4 t, R4 v) => Lens' (t (v a)) (M44 a)
_m44 = column _xyzw . _xyzw
-- | Determinant of a 2x2 matrix: main diagonal product minus
-- anti-diagonal product.
det22 :: Num a => M22 a -> a
det22 (V2 (V2 p q) (V2 r s)) = p * s - q * r
{-# INLINE det22 #-}
-- | Determinant of a 3x3 matrix, by cofactor expansion along the first
-- column.
det33 :: Num a => M33 a -> a
det33 (V3 (V3 a b c)
          (V3 d e f)
          (V3 g h i)) = a * minorA - d * minorD + g * minorG
  where
    minorA = e*i - f*h
    minorD = b*i - c*h
    minorG = b*f - c*e
{-# INLINE det33 #-}
-- | Determinant of a 4x4 matrix.
--
-- Computed by Laplace expansion over complementary 2x2 minors.
--
-- The @s@ terms are the 2x2 minors of the top two rows, and the @c@ terms
-- are those of the bottom two rows. Each @s@ minor is multiplied by the
-- @c@ minor on the two remaining columns, with the appropriate sign.
det44 :: Num a => M44 a -> a
det44 (V4 (V4 i00 i01 i02 i03)
          (V4 i10 i11 i12 i13)
          (V4 i20 i21 i22 i23)
          (V4 i30 i31 i32 i33)) =
  let
    -- 2x2 minors of rows 0 and 1 (columns noted on the right)
    s0 = i00 * i11 - i10 * i01 -- columns 0,1
    s1 = i00 * i12 - i10 * i02 -- columns 0,2
    s2 = i00 * i13 - i10 * i03 -- columns 0,3
    s3 = i01 * i12 - i11 * i02 -- columns 1,2
    s4 = i01 * i13 - i11 * i03 -- columns 1,3
    s5 = i02 * i13 - i12 * i03 -- columns 2,3
    -- 2x2 minors of rows 2 and 3 (columns noted on the right)
    c5 = i22 * i33 - i32 * i23 -- columns 2,3
    c4 = i21 * i33 - i31 * i23 -- columns 1,3
    c3 = i21 * i32 - i31 * i22 -- columns 1,2
    c2 = i20 * i33 - i30 * i23 -- columns 0,3
    c1 = i20 * i32 - i30 * i22 -- columns 0,2
    c0 = i20 * i31 - i30 * i21 -- columns 0,1
  -- s_k pairs with c_(5-k): the minor on the complementary column pair.
  in s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0
{-# INLINE det44 #-}
-- | Inverse of a 2x2 matrix: the adjugate scaled by the reciprocal of the
-- determinant.
--
-- No singularity check is made; a zero determinant is divided into 1 as
-- is, so the result follows the element type's division semantics.
inv22 :: Fractional a => M22 a -> M22 a
inv22 m@(V2 (V2 a b) (V2 c d)) =
  let invDet = 1 / det22 m
  in invDet *!! V2 (V2 d (negate b)) (V2 (negate c) a)
{-# INLINE inv22 #-}
-- | Inverse of a 3x3 matrix: the adjugate (transposed cofactor matrix)
-- scaled by the reciprocal of the determinant.
--
-- No singularity check is made; a zero determinant is divided into 1 as
-- is, so the result follows the element type's division semantics.
inv33 :: Fractional a => M33 a -> M33 a
inv33 m@(V3 (V3 a b c)
            (V3 d e f)
            (V3 g h i))
  = (1 / det) *!! V3 (V3 a' b' c')
                     (V3 d' e' f')
                     (V3 g' h' i')
  -- Each primed entry is one entry of the adjugate. The argument order
  -- passed to cofactor is chosen so that the 2x2 determinant already
  -- carries the cofactor sign, e.g. b' = c*h - b*i = -(b*i - c*h).
  where a' = cofactor (e,f,h,i)
        b' = cofactor (c,b,i,h)
        c' = cofactor (b,c,e,f)
        d' = cofactor (f,d,i,g)
        e' = cofactor (a,c,g,i)
        f' = cofactor (c,a,f,d)
        g' = cofactor (d,e,g,h)
        h' = cofactor (b,a,h,g)
        i' = cofactor (a,b,d,e)
        -- cofactor (q,r,s,t) = q*t - r*s
        cofactor (q,r,s,t) = det22 (V2 (V2 q r) (V2 s t))
        det = det33 m
{-# INLINE inv33 #-}
-- | Swap the outer and inner containers of a matrix (rows become columns),
-- via 'distribute'.
transpose :: (Distributive g, Functor f) => f (g a) -> g (f a)
transpose m = distribute m
{-# INLINE transpose #-}
-- | Inverse of a 4x4 matrix.
--
-- The 2x2 minors used are the same as in 'det44': @s0@ .. @s5@ come from
-- the top two rows and @c0@ .. @c5@ from the bottom two. Each entry of the
-- result is an adjugate entry built from these minors, and the whole matrix
-- is scaled by the reciprocal of the determinant.
--
-- No singularity check is made; a zero determinant is passed straight to
-- 'recip'.
inv44 :: Fractional a => M44 a -> M44 a
inv44 (V4 (V4 i00 i01 i02 i03)
          (V4 i10 i11 i12 i13)
          (V4 i20 i21 i22 i23)
          (V4 i30 i31 i32 i33)) =
  -- 2x2 minors of the top rows (s) and bottom rows (c); see det44.
  let s0 = i00 * i11 - i10 * i01
      s1 = i00 * i12 - i10 * i02
      s2 = i00 * i13 - i10 * i03
      s3 = i01 * i12 - i11 * i02
      s4 = i01 * i13 - i11 * i03
      s5 = i02 * i13 - i12 * i03
      c5 = i22 * i33 - i32 * i23
      c4 = i21 * i33 - i31 * i23
      c3 = i21 * i32 - i31 * i22
      c2 = i20 * i33 - i30 * i23
      c1 = i20 * i32 - i30 * i22
      c0 = i20 * i31 - i30 * i21
      -- Identical expansion to det44, reusing the minors above.
      det = s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0
      invDet = recip det
  in invDet *!! V4 (V4 (i11 * c5 - i12 * c4 + i13 * c3)
                       (-i01 * c5 + i02 * c4 - i03 * c3)
                       (i31 * s5 - i32 * s4 + i33 * s3)
                       (-i21 * s5 + i22 * s4 - i23 * s3))
                   (V4 (-i10 * c5 + i12 * c2 - i13 * c1)
                       (i00 * c5 - i02 * c2 + i03 * c1)
                       (-i30 * s5 + i32 * s2 - i33 * s1)
                       (i20 * s5 - i22 * s2 + i23 * s1))
                   (V4 (i10 * c4 - i11 * c2 + i13 * c0)
                       (-i00 * c4 + i01 * c2 - i03 * c0)
                       (i30 * s4 - i31 * s2 + i33 * s0)
                       (-i20 * s4 + i21 * s2 - i23 * s0))
                   (V4 (-i10 * c3 + i11 * c1 - i12 * c0)
                       (i00 * c3 - i01 * c1 + i02 * c0)
                       (-i30 * s3 + i31 * s1 - i32 * s0)
                       (i20 * s3 - i21 * s1 + i22 * s0))
{-# INLINE inv44 #-}
#if MIN_VERSION_base(4,8,0)
-- | LU decomposition of a square matrix, computed column by column in the
-- style of Crout's method.
--
-- Returns @(l, u)@, where:
--
--   * @l@ is lower triangular: it starts as 'zero' and only entries on or
--     below the diagonal are written.
--   * @u@ is upper triangular: it starts as 'identity' and only entries on
--     or above the diagonal are written.
--
-- Each diagonal entry of @u@ is computed as a pivot divided by itself, so
-- it is 1 whenever the pivot is nonzero.
--
-- No pivoting (row exchange) is performed. A zero diagonal entry of @l@
-- therefore leads to division by zero in the element type. All element
-- access uses the partial @^?!@ with 'ix'.
lu :: ( Num a
      , Fractional a
      , Foldable m
      , Traversable m
      , Applicative m
      , Additive m
      , Ixed (m a)
      , Ixed (m (m a))
      , i ~ Index (m a)
      , i ~ Index (m (m a))
      , Eq i
      , Integral i
      , a ~ IxValue (m a)
      , m a ~ IxValue (m (m a))
      , Num (m a)
      )
   => m (m a)
   -> (m (m a), m (m a))
lu a =
  let n = fromIntegral (length a)
      initU = identity
      initL = zero
      -- Entry (i, j) of L: a(i, j) minus the sum of l(i, k) * u(k, j)
      -- for k from 0 up to j - 1.
      buildLVal !i !j !l !u =
        let go !k !s
              | k == j = s
              | otherwise = go (k+1)
                               ( s
                               + ( (l ^?! ix i ^?! ix k)
                                 * (u ^?! ix k ^?! ix j)
                                 )
                               )
            s' = go 0 0
        in l & (ix i . ix j) .~ ((a ^?! ix i ^?! ix j) - s')
      -- Fill column j of L for rows i .. n - 1.
      buildL !i !j !l !u
        | i == n = l
        | otherwise = buildL (i+1) j (buildLVal i j l u) u
      -- Entry (j, i) of U: a(j, i) minus the sum of l(j, k) * u(k, i)
      -- for k from 0 up to j - 1, divided by the pivot l(j, j).
      buildUVal !i !j !l !u =
        let go !k !s
              | k == j = s
              | otherwise = go (k+1)
                               ( s
                               + ( (l ^?! ix j ^?! ix k)
                                 * (u ^?! ix k ^?! ix i)
                                 )
                               )
            s' = go 0 0
        in u & (ix j . ix i) .~ ( ((a ^?! ix j ^?! ix i) - s')
                                / (l ^?! ix j ^?! ix j)
                                )
      -- Fill row j of U for columns i .. n - 1.
      buildU !i !j !l !u
        | i == n = u
        | otherwise = buildU (i+1) j l (buildUVal i j l u)
      -- Step j: column j of L is computed first, because row j of U
      -- divides by the new pivot l(j, j) (note buildU receives l').
      buildLU !j !l !u
        | j == n = (l, u)
        | otherwise =
            let l' = buildL j j l u
                u' = buildU j j l' u
            in buildLU (j+1) l' u'
  in buildLU 0 initL initU
-- | 'lu' for any 'Finite' container.
--
-- The input is converted to nested 'V' vectors, decomposed, and both
-- factors are converted back.
luFinite :: ( Num a
            , Fractional a
            , Functor m
            , Finite m
            , n ~ Size m
            , KnownNat n
            , Num (m a)
            )
         => m (m a)
         -> (m (m a), m (m a))
luFinite a =
  let (l, u) = lu (fmap toV (toV a))
  in (fmap fromV (fromV l), fmap fromV (fromV u))
-- | Solve the lower-triangular system @a x = b@ by forward substitution.
--
-- Rows are solved in increasing order. Entry @r@ of the result is:
--
--   * @b@ at @r@,
--   * minus the dot product of row @r@ of @a@ with the entries of @x@
--     already solved,
--   * divided by the diagonal entry of row @r@.
--
-- Entries of @a@ above the diagonal are never read. A zero diagonal entry
-- leads to division by zero.
forwardSub :: ( Num a
              , Fractional a
              , Foldable m
              , Additive m
              , Ixed (m a)
              , Ixed (m (m a))
              , i ~ Index (m a)
              , i ~ Index (m (m a))
              , Eq i
              , Ord i
              , Integral i
              , a ~ IxValue (m a)
              , m a ~ IxValue (m (m a))
              )
           => m (m a)
           -> m a
           -> m a
forwardSub a b = solveFrom 0 zero
  where
    n = fromIntegral (length b)
    entry r c = a ^?! ix r ^?! ix c
    -- Sum of a(r, c) * x(c) from the starting column c up to r - 1.
    partialSum !r !c !acc !x
      | c == r = acc
      | otherwise = partialSum r (c + 1) (acc + entry r c * (x ^?! ix c)) x
    -- Solve rows r .. n - 1 in increasing order.
    solveFrom !r !x
      | r == n = x
      | otherwise =
          let xr = ((b ^?! ix r) - partialSum r 0 0 x) / entry r r
          in solveFrom (r + 1) (x & ix r .~ xr)
-- | 'forwardSub' for any 'Finite' container, via conversion to 'V'.
forwardSubFinite :: ( Num a
                    , Fractional a
                    , Foldable m
                    , n ~ Size m
                    , KnownNat n
                    , Additive m
                    , Finite m
                    )
                 => m (m a)
                 -> m a
                 -> m a
forwardSubFinite m v = fromV $ forwardSub (toV <$> toV m) (toV v)
-- | Solve the upper-triangular system @a x = b@ by back substitution.
--
-- Rows are solved from the last to the first. Entry @r@ of the result is:
--
--   * @b@ at @r@,
--   * minus the dot product of row @r@ of @a@ with the entries of @x@
--     already solved (those after @r@),
--   * divided by the diagonal entry of row @r@.
--
-- Entries of @a@ below the diagonal are never read. A zero diagonal entry
-- leads to division by zero.
--
-- The loop counts the rows still to be solved and stops at 0 before
-- decrementing, so it never forms the index -1. This keeps it correct for
-- unsigned 'Integral' index types as well as 'Int'.
backwardSub :: ( Num a
               , Fractional a
               , Foldable m
               , Additive m
               , Ixed (m a)
               , Ixed (m (m a))
               , i ~ Index (m a)
               , i ~ Index (m (m a))
               , Eq i
               , Ord i
               , Integral i
               , a ~ IxValue (m a)
               , m a ~ IxValue (m (m a))
               )
            => m (m a)
            -> m a
            -> m a
backwardSub a b = solveDown n zero
  where
    n = fromIntegral (length b)
    entry r c = a ^?! ix r ^?! ix c
    -- Sum of a(r, c) * x(c) from the starting column c up to n - 1.
    tailSum !r !c !acc !x
      | c == n = acc
      | otherwise = tailSum r (c + 1) (acc + entry r c * (x ^?! ix c)) x
    -- k is the number of rows still unsolved; row k - 1 is solved next.
    -- Testing k == 0 first means k - 1 is only computed when k >= 1.
    solveDown !k !x
      | k == 0 = x
      | otherwise =
          let r = k - 1
              xr = ((b ^?! ix r) - tailSum r (r + 1) 0 x) / entry r r
          in solveDown r (x & ix r .~ xr)
-- | 'backwardSub' for any 'Finite' container, via conversion to 'V'.
backwardSubFinite :: ( Num a
                     , Fractional a
                     , Foldable m
                     , n ~ Size m
                     , KnownNat n
                     , Additive m
                     , Finite m
                     )
                  => m (m a)
                  -> m a
                  -> m a
backwardSubFinite m v = fromV $ backwardSub (toV <$> toV m) (toV v)
-- | Solve @a x = b@ using the 'lu' decomposition of @a@.
--
-- The method is forward substitution with @l@, then back substitution with
-- @u@. It inherits the lack of pivoting in 'lu'.
luSolve :: ( Num a
           , Fractional a
           , Foldable m
           , Traversable m
           , Applicative m
           , Additive m
           , Ixed (m a)
           , Ixed (m (m a))
           , i ~ Index (m a)
           , i ~ Index (m (m a))
           , Eq i
           , Integral i
           , a ~ IxValue (m a)
           , m a ~ IxValue (m (m a))
           , Num (m a)
           )
        => m (m a)
        -> m a
        -> m a
luSolve a b = backwardSub u y
  where
    (l, u) = lu a
    y = forwardSub l b
-- | 'luSolve' for any 'Finite' container, via conversion to 'V'.
luSolveFinite :: ( Num a
                 , Fractional a
                 , Functor m
                 , Finite m
                 , n ~ Size m
                 , KnownNat n
                 , Num (m a)
                 )
              => m (m a)
              -> m a
              -> m a
luSolveFinite m v = fromV $ luSolve (toV <$> toV m) (toV v)
-- | Invert a matrix using its 'lu' decomposition.
--
-- The system is solved once for each unit vector. Each solution is stored
-- as a row, and the result is transposed so that the solutions become the
-- columns of the inverse. It inherits the lack of pivoting in 'lu'.
luInv :: ( Num a
         , Fractional a
         , Foldable m
         , Traversable m
         , Applicative m
         , Additive m
         , Distributive m
         , Ixed (m a)
         , Ixed (m (m a))
         , i ~ Index (m a)
         , i ~ Index (m (m a))
         , Eq i
         , Integral i
         , a ~ IxValue (m a)
         , m a ~ IxValue (m (m a))
         , Num (m a)
         )
      => m (m a)
      -> m (m a)
luInv a = transpose (fillRows 0 zero)
  where
    n = fromIntegral (length a)
    (l, u) = lu a
    -- Solve for the k-th unit vector: all zeros except a 1 at index k.
    solveUnit k = backwardSub u (forwardSub l (zero & ix k .~ 1))
    fillRows !k !acc
      | k == n = acc
      | otherwise = fillRows (k + 1) (acc & ix k .~ solveUnit k)
-- | 'luInv' for any 'Finite' container, via conversion to 'V'.
luInvFinite :: ( Num a
               , Fractional a
               , Functor m
               , Finite m
               , n ~ Size m
               , KnownNat n
               , Num (m a)
               )
            => m (m a)
            -> m (m a)
luInvFinite m = (fmap fromV . fromV) $ luInv (toV <$> toV m)
-- | Determinant computed from the 'lu' decomposition.
--
-- The result is the product of the diagonal of @l@ times the product of
-- the diagonal of @u@. It inherits the lack of pivoting in 'lu', so a
-- zero pivot yields a division-by-zero result rather than 0.
luDet :: ( Num a
         , Fractional a
         , Foldable m
         , Traversable m
         , Applicative m
         , Additive m
         , Trace m
         , Ixed (m a)
         , Ixed (m (m a))
         , i ~ Index (m a)
         , i ~ Index (m (m a))
         , Eq i
         , Integral i
         , a ~ IxValue (m a)
         , m a ~ IxValue (m (m a))
         , Num (m a)
         )
      => m (m a)
      -> a
luDet a =
  let (l, u) = lu a
      -- Strict left fold: the running product is forced at each step
      -- instead of building a chain of thunks as lazy foldl does.
      diagProduct = Foldable.foldl' (*) 1
  in diagProduct (diagonal l) * diagProduct (diagonal u)
-- | 'luDet' for any 'Finite' container, via conversion to 'V'.
luDetFinite :: ( Num a
               , Fractional a
               , Functor m
               , Finite m
               , n ~ Size m
               , KnownNat n
               , Num (m a)
               )
            => m (m a)
            -> a
luDetFinite m = luDet (toV <$> toV m)
#endif