-- | Sparse matrices. Each matrix stores its entry count, parallel vectors of
-- first (row) and second (column) coordinates, and an array of values; the
-- constructors in this module keep the entries sorted by 'Key'.
module Sparse.Matrix
  ( Mat(..)
  , Key(..)
  , Sparse.Matrix.fromList
  , Sparse.Matrix.singleton
  , transpose
  , ident
  , empty
  , size
  , null
  , Eq0(..)
  , addWith
  , multiplyWith
  , Arrayed(..)
  , _Mat, keys, values
  ) where
import Control.Applicative hiding (empty)
import Control.Arrow
import Control.DeepSeq
import Control.Lens
import Data.Bits
import Data.Complex
import Data.Function (on)
import qualified Data.Vector as V
import qualified Data.Vector.Algorithms.Insertion as Sort
import qualified Data.Vector.Algorithms.Merge as Merge
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Hybrid as H
import qualified Data.Vector.Hybrid.Internal as H
import qualified Data.Vector.Unboxed as U
import Data.Vector.Fusion.Stream (Stream)
import Data.Word
import Prelude hiding (head, last, null)
import Sparse.Matrix.Internal.Fusion as Fusion
import Sparse.Matrix.Internal.Key
import Sparse.Matrix.Internal.Array as I
import Sparse.Matrix.Internal.Heap as Heap hiding (head)
import Text.Read
-- | Element types with a notion of zero. Sparse operations use it to drop
-- entries whose result comes out as zero.
class (Arrayed a, Num a) => Eq0 a where
  -- | Is this value zero? By default it is compared with @0@ using '=='.
  isZero :: a -> Bool
#ifndef HLINT
  default isZero :: (Num a, Eq a) => a -> Bool
  isZero = (==) 0
#endif

  -- | Apply a binary function, answering 'Nothing' when the result 'isZero'
  -- and 'Just' the result otherwise.
  nonZero :: (x -> y -> a) -> x -> y -> Maybe a
  nonZero f p q
    | isZero c = Nothing
    | otherwise = Just c
    where
      c = f p q

  -- | Add two matrices with 'addWith0', dropping entries whose sum 'isZero'.
  addMats :: Mat a -> Mat a -> Mat a
  addMats = addWith0 (nonZero (+))

  -- | Turn an optional heap of partial products into an entry stream using
  -- 'Heap.streamHeapWith0' with @'nonZero' (+)@ as the combining function.
  addHeap :: Maybe (Heap a) -> Stream (Key, a)
  addHeap = Heap.streamHeapWith0 (nonZero (+))
-- Built-in numeric element types: all use the default 'isZero', which
-- compares the value with @0@ using '=='.
instance Eq0 Int
instance Eq0 Word
instance Eq0 Integer
instance Eq0 Float
instance Eq0 Double
-- | A complex number is zero exactly when both its real and its imaginary
-- component are zero.
instance (RealFloat a, Eq0 a) => Eq0 (Complex a) where
  isZero (re :+ im) = isZero re && isZero im
-- | A sparse matrix, stored as parallel strict arrays:
--
--   * the number of stored entries,
--   * the first (row) coordinate of each entry's 'Key',
--   * the second (column) coordinate of each entry's 'Key',
--   * the entry values.
--
-- 'fromList' and 'transpose' sort entries by 'Key', and lookups via 'ix'
-- binary-search on that order.
data Mat a = Mat !Int !(U.Vector Word) !(U.Vector Word) !(I.Array a)
-- Derived, field-by-field equality and ordering (count, coordinates, values).
deriving instance (Arrayed a, Eq (I.Array a)) => Eq (Mat a)
deriving instance (Arrayed a, Ord (I.Array a)) => Ord (Mat a)
-- | A matrix is shown as the hybrid vector of @(Key, value)@ pairs that
-- '_Mat' exposes.
instance (Arrayed a, Show a) => Show (Mat a) where
  showsPrec d = G.showsPrec d . view _Mat

-- | Reads a hybrid vector of @(Key, value)@ pairs with 'G.readPrec' and
-- converts it back to a matrix through '_Mat'.
instance (Arrayed a, Read a) => Read (Mat a) where
  readPrec = fmap (review _Mat) G.readPrec
-- | Fully evaluates both coordinate vectors and the value array. The entry
-- count is a strict 'Int' field, so it is already evaluated.
instance NFData (I.Array a) => NFData (Mat a) where
  rnf (Mat _ rows cols vals) = rows `deepseq` cols `deepseq` rnf vals
-- | Isomorphism between a matrix and a hybrid vector of @(Key, value)@
-- pairs. The key half ('V_Key') carries the entry count and both coordinate
-- vectors; the value half is the matrix's value array.
_Mat :: Arrayed a => Iso' (Mat a) (H.Vector U.Vector (Arr a) (Key, a))
_Mat = iso
  (\(Mat count rows cols vals) -> H.V (V_Key count rows cols) vals)
  (\(H.V (V_Key count rows cols) vals) -> Mat count rows cols vals)
-- | Lens onto a matrix's key vector (count plus both coordinate vectors).
-- The value array is carried over unchanged.
keys :: Lens' (Mat a) (U.Vector Key)
keys f (Mat count rows cols vals) = rebuild <$> f (V_Key count rows cols)
  where
    rebuild (V_Key count' rows' cols') = Mat count' rows' cols' vals
-- | Lens onto a matrix's value array; the value type may change. The count
-- and coordinates are kept as they are.
values :: Lens (Mat a) (Mat b) (I.Array a) (I.Array b)
values f (Mat count rows cols vals) = fmap (Mat count rows cols) (f vals)
-- Lens indexing: a @Mat a@ is indexed by 'Key' and stores values of type @a@
-- (used by the 'Ixed' instance).
type instance IxValue (Mat a) = a
type instance Index (Mat a) = Key
-- | Traverse a generic vector with an 'Applicative' action. The elements are
-- traversed through an intermediate list, and the results are rebuilt into
-- a vector of the original length.
eachV :: (Applicative f, G.Vector v a, G.Vector v b) => (a -> f b) -> v a -> f (v b)
eachV f v = fmap (G.fromListN n) (traverse f (G.toList v))
  where
    n = G.length v
-- | Traverse every stored value of a matrix, leaving the keys untouched.
-- 'traverse' on a @(key, value)@ pair visits only the value.
instance (Arrayed a, a ~ b) => Each (Mat a) (Mat b) a b where
  each f = _Mat (eachV (traverse f))
-- | @ix k@ focuses the value stored at key @k@, if there is one. When @k@ is
-- absent the matrix is returned unchanged and nothing is inserted. The
-- position is found by binary 'search' over the keys, so it relies on the
-- entries being sorted by 'Key'.
instance Arrayed a => Ixed (Mat a) where
  ix ij@(Key i j) f m@(Mat n xs ys vs) =
    case (xs U.!? l, ys U.!? l) of
      (Just i', Just j') | i == i' && j == j' ->
        (\v -> Mat n xs ys (vs G.// [(l, v)])) <$> f (vs G.! l)
      _ -> pure m
    where
      -- First position whose key is not below the requested key.
      l = search (\k -> Key (xs U.! k) (ys U.! k) >= ij) 0 n
-- | Matrices can themselves be matrix entries; their value arrays use the
-- boxed 'V.Vector'.
instance Arrayed a => Arrayed (Mat a) where
  type Arr (Mat a) = V.Vector
-- | A matrix counts as zero exactly when it stores no entries.
instance (Arrayed a, Eq0 a) => Eq0 (Mat a) where
  isZero m = size m == 0
-- | Build a matrix from @(key, value)@ entries given in any order.
--
-- The entries are sorted by 'Key'. The sort is stable, so entries with equal
-- keys keep their relative input order. Duplicate keys are not merged.
fromList :: Arrayed a => [(Key, a)] -> Mat a
-- Stable merge sort: it gives the same order as a stable insertion sort,
-- but runs in O(n log n) rather than O(n^2) on unordered input.
fromList xs = _Mat # H.modify (Merge.sortBy (compare `on` fst)) (H.fromList xs)
-- | Transpose a matrix.
--
-- 'swap' is applied to every key (presumably exchanging its two coordinates;
-- it is defined in the Key module), and the entries are then re-sorted by
-- the new keys.
transpose :: Arrayed a => Mat a -> Mat a
-- Swapping coordinates breaks the existing key order throughout the vector,
-- so use a stable O(n log n) merge sort. A stable insertion sort would give
-- the same result in O(n^2) time.
transpose xs = xs & _Mat %~ H.modify (Merge.sortBy (compare `on` fst)) . H.map (first swap)
-- | A matrix with exactly one entry: value @v@ at key @k@.
singleton :: Arrayed a => Key -> a -> Mat a
singleton k v = review _Mat (H.singleton (k, v))
-- | The @w@ by @w@ identity matrix: entries @(i, i)@ for @i@ in @[0, w)@,
-- each equal to @1@.
ident :: (Arrayed a, Num a) => Int -> Mat a
ident w = Mat w diag diag (G.replicate w 1)
  where
    -- Row and column coordinates coincide on the diagonal.
    diag = U.generate w fromIntegral
-- | The matrix with no stored entries.
empty :: Arrayed a => Mat a
empty = Mat 0 U.empty U.empty G.empty

-- | Number of stored entries.
size :: Mat a -> Int
size (Mat count _ _ _) = count

-- | Does the matrix store no entries at all?
null :: Mat a -> Bool
null = (== 0) . size
-- | Sparse-matrix arithmetic.
--
-- * 'abs', 'signum' and 'negate' are applied to every stored entry.
-- * Only @fromInteger 0@ (the empty matrix) is supported; any other literal
--   is an error.
-- * Addition is 'addMats'.
-- * Subtraction adds the negated right operand, dropping entries whose
--   result 'isZero'.
-- * Multiplication is 'multiplyWith' ('*'), combining partial products with
--   'addHeap'.
instance (Arrayed a, Eq0 a) => Num (Mat a) where
  abs = over each abs
  signum = over each signum
  negate = over each negate
  fromInteger 0 = empty
  fromInteger _ = error "Mat: fromInteger n"
  (+) = addMats
  -- 'addWith0' only receives a binary combining function and 'Arrayed' gives
  -- no 'Num', so it cannot negate an entry that appears only in the
  -- right-hand matrix. Merging with @nonZero (-)@ would leave such entries
  -- with the wrong sign. Adding @negate y@ gives every entry its correct
  -- sign, and zero results on shared keys are still dropped.
  x - y = addWith0 (nonZero (+)) x (negate y)
  (*) = multiplyWith (*) addHeap
-- | @search p l h@ binary-searches the half-open range @[l, h)@ for the
-- smallest index at which @p@ holds, returning @h@ when there is none.
--
-- @p@ must be monotone on the range: once it is true, it stays true for
-- every larger index. For an empty or inverted range (@l >= h@) the result
-- is @l@.
search :: (Int -> Bool) -> Int -> Int -> Int
search p = go
  where
    go l h
      -- '>=' rather than '==' so that an inverted range terminates instead
      -- of looping forever.
      | l >= h = l
      | p m = go l m
      | otherwise = go (m + 1) h
      where
        -- Midpoint written as l + (h - l) / 2 so that l + h cannot overflow.
        m = l + (h - l) `div` 2
-- | Split a matrix into two contiguous runs of entries, using the first
-- (row) coordinates.
--
-- In 'multiplyWith', @ai@ and @bi@ are the row coordinates of the matrix's
-- lowest and highest keys. The split point @k@ is the first index whose row
-- @x@ satisfies @xor x bi `lts` xor ai bi@, found by binary 'search'. 'lts'
-- is defined in the Key module and presumably compares most-significant set
-- bits (TODO confirm). Entries @[0, k)@ go to the first result and
-- @[k, n)@ to the second.
split1 :: Arrayed a => Word -> Word -> Mat a -> (Mat a, Mat a)
split1 ai bi (Mat n xs ys vs) = (m0, m1)
  where
    !aibi = xor ai bi
    !k = search (\l -> xor (xs U.! l) bi `lts` aibi) 0 n
    (xs0, xs1) = U.splitAt k xs
    (ys0, ys1) = U.splitAt k ys
    (vs0, vs1) = G.splitAt k vs
    !m0 = Mat k xs0 ys0 vs0
    -- The second half holds the remaining n - k entries.
    !m1 = Mat (n - k) xs1 ys1 vs1
-- | Split a matrix into two contiguous runs of entries, using the second
-- (column) coordinates.
--
-- In 'multiplyWith', @aj@ and @bj@ are the column coordinates of the
-- matrix's lowest and highest keys. The split point @k@ is the first index
-- whose column @y@ satisfies @xor y bj `lts` xor aj bj@, found by binary
-- 'search'. 'lts' is defined in the Key module and presumably compares
-- most-significant set bits (TODO confirm). Entries @[0, k)@ go to the
-- first result and @[k, n)@ to the second.
split2 :: Arrayed a => Word -> Word -> Mat a -> (Mat a, Mat a)
split2 aj bj (Mat n xs ys vs) = (m0, m1)
  where
    !ajbj = xor aj bj
    !k = search (\l -> xor (ys U.! l) bj `lts` ajbj) 0 n
    (xs0, xs1) = U.splitAt k xs
    (ys0, ys1) = U.splitAt k ys
    (vs0, vs1) = G.splitAt k vs
    !m0 = Mat k xs0 ys0 vs0
    -- The second half holds the remaining n - k entries.
    !m1 = Mat (n - k) xs1 ys1 vs1
-- | Merge the entry streams of two matrices with 'mergeStreamsWith'. @f@
-- combines the values of entries that share a 'Key' (presumably; the merge
-- is defined in the Fusion module).
addWith :: Arrayed a => (a -> a -> a) -> Mat a -> Mat a -> Mat a
addWith f xs ys = review _Mat (G.unstream merged)
  where
    merged = mergeStreamsWith f (asStream xs) (asStream ys)
    asStream m = G.stream (view _Mat m)
-- | Like 'addWith', but the combining function may return 'Nothing' to drop
-- the entry, e.g. @'nonZero' (+)@ to discard sums that are zero. The merge
-- is done by 'mergeStreamsWith0'.
addWith0 :: Arrayed a => (a -> a -> Maybe a) -> Mat a -> Mat a -> Mat a
addWith0 f xs ys =
  review _Mat . G.unstream $
    mergeStreamsWith0 f (G.stream (view _Mat xs)) (G.stream (view _Mat ys))
-- | Multiply two sparse matrices.
--
-- @times a b@ multiplies an entry @a@ of the left matrix at key @(i, j)@ by
-- an entry @b@ of the right matrix at key @(j, k)@, giving a partial product
-- at @(i, k)@. Partial products are gathered in an optional 'Heap'
-- ('Nothing' when there are none). @make@ turns that heap into the result's
-- entry stream; for example, 'addHeap' combines products with
-- @'nonZero' (+)@.
--
-- Instead of comparing every pair of entries, the operands are split
-- recursively with 'split1' / 'split2', and sub-problems whose inner
-- coordinates cannot meet are pruned early.
multiplyWith :: Arrayed a => (a -> a -> a) -> (Maybe (Heap a) -> Stream (Key, a)) -> Mat a -> Mat a -> Mat a
multiplyWith times make x0 y0 = case compare (size x0) 1 of
    -- Dispatch on each operand being empty, a single entry, or larger.
    LT -> empty
    EQ | size y0 == 0 -> empty
       | size y0 == 1 -> unhinted $ go11 (lo x0) (head x0) (lo y0) (head y0)
       | otherwise -> unhinted $ go12 (lo x0) (head x0) (lo y0) y0 (hi y0)
    GT -> case compare (size y0) 1 of
      LT -> empty
      EQ -> unhinted $ go21 (lo x0) x0 (hi x0) (lo y0) (head y0)
      GT -> unhinted $ go22 (lo x0) x0 (hi x0) (lo y0) y0 (hi y0)
  where
    -- Build the result matrix from the stream that @make@ produces.
    unhinted x = _Mat # G.unstream (make x)

    -- One entry times one entry: there is a product only when the inner
    -- coordinates @j@ and @j'@ agree.
    go11 (Key i j) a (Key j' k) b
      | j == j' = Just $ Heap.singleton (Key i k) (times a b)
      | otherwise = Nothing

    -- Recurse into the lower (L0) or upper (L1) half of a split left
    -- operand. Only the bound on the split side is recomputed; a one-entry
    -- half falls through to 'go12'.
    go22L0 xa x ya y yb
      | size x == 1 = go12 xa (head x) ya y yb
      | otherwise = go22 xa x (hi x) ya y yb
    go22L1 x xb ya y yb
      | size x == 1 = go12 xb (head x) ya y yb
      | otherwise = go22 (lo x) x xb ya y yb

    -- The same for the lower (R0) or upper (R1) half of a split right
    -- operand; a one-entry half falls through to 'go21'.
    go22R0 xa x xb ya y
      | size y == 1 = go21 xa x xb ya (head y)
      | otherwise = go22 xa x xb ya y (hi y)
    go22R1 xa x xb y yb
      | size y == 1 = go21 xa x xb yb (head y)
      | otherwise = go22 xa x xb (lo y) y yb

    -- Both operands have several entries. @xa@/@xb@ and @ya@/@yb@ are the
    -- first and last keys of each operand. The XORs in the @where@ clause
    -- mark the coordinate bits that vary within each operand.
    go22 xa@(Key xai xaj) x xb@(Key xbi xbj) ya@(Key yaj yak) y yb@(Key ybj ybk)
        -- Prune when the inner coordinates differ in a way that 'gts' ranks
        -- above all varying bits. 'gts'/'ges' are defined outside this
        -- module and presumably compare most-significant set bits, in which
        -- case no pair of entries can share an inner coordinate.
      | gts (xor xaj yaj) (xiyj .|. ykxj) = Nothing
        -- Otherwise split along the coordinate that 'ges' selects.
        -- Splitting an outer coordinate (x's rows via 'split1', y's columns
        -- via 'split2') gives halves joined with 'mfby'. Splitting the
        -- shared inner coordinate (y's rows, x's columns) gives halves
        -- joined with 'madd'.
      | ges xiyj ykxj
      = if ges xi yj then case split1 xai xbi x of (m0,m1) -> go22L0 xa m0 ya y yb `mfby` go22L1 m1 xb ya y yb
        else case split1 yaj ybj y of (m0,m1) -> go22R0 xa x xb ya m0 `madd` go22R1 xa x xb m1 yb
      | ges yk xj = case split2 yak ybk y of (m0,m1) -> go22R0 xa x xb ya m0 `mfby` go22R1 xa x xb m1 yb
      | otherwise = case split2 xaj xbj x of (m0,m1) -> go22L0 xa m0 ya y yb `madd` go22L1 m1 xb ya y yb
      where
        xi = xor xai xbi
        xj = xor xaj xbj
        yj = xor yaj ybj
        yk = xor yak ybk
        xiyj = xi .|. yj
        ykxj = yk .|. xj

    -- Halves of a split left operand when the right operand is a single
    -- entry; a one-entry half falls through to 'go11'.
    go21L0 xa x yb b
      | size x == 1 = go11 xa (head x) yb b
      | otherwise = go21 xa x (hi x) yb b
    go21L1 x xb yb b
      | size x == 1 = go11 xb (head x) yb b
      | otherwise = go21 (lo x) x xb yb b

    -- Several left entries times the single right entry @b@ at @yb@. Prune
    -- using x's column bits against @yb@'s row, otherwise split x by rows
    -- (joined with 'mfby') or by columns (joined with 'madd').
    go21 xa@(Key xai xaj) x xb@(Key xbi xbj) yb@(Key ybj _ybk) b
      | gts (xor xaj ybj) (xi.|.xj) = Nothing
      | ges xi xj = case split1 xai xbi x of (m0,m1) -> go21L0 xa m0 yb b `mfby` go21L1 m1 xb yb b
      | otherwise = case split2 xaj xbj x of (m0,m1) -> go21L0 xa m0 yb b `madd` go21L1 m1 xb yb b
      where
        xi = xor xai xbi
        xj = xor xaj xbj

    -- Halves of a split right operand when the left operand is a single
    -- entry; a one-entry half falls through to 'go11'.
    go12R0 xa a ya y
      | size y == 1 = go11 xa a ya (head y)
      | otherwise = go12 xa a ya y (hi y)
    go12R1 xa a y yb
      | size y == 1 = go11 xa a yb (head y)
      | otherwise = go12 xa a (lo y) y yb

    -- The single left entry @a@ at @xa@ times several right entries. Prune
    -- using @xa@'s column against y's row bits, otherwise split y by rows
    -- (joined with 'madd') or by columns (joined with 'mfby').
    go12 xa@(Key _xai xaj) a ya@(Key yaj yak) y yb@(Key ybj ybk)
      | gts (xor xaj yaj) (yj.|.yk) = Nothing
      | ges yj yk = case split1 yaj ybj y of (m0,m1) -> go12R0 xa a ya m0 `madd` go12R1 xa a m1 yb
      | otherwise = case split2 yak ybk y of (m0,m1) -> go12R0 xa a ya m0 `mfby` go12R1 xa a m1 yb
      where
        yj = xor yaj ybj
        yk = xor yak ybk

    -- Combine optional heaps, with 'Nothing' as the identity. 'madd' uses
    -- 'mix' (presumably for halves that can yield the same keys). 'mfby'
    -- uses 'fby' (presumably for halves whose keys are already ordered).
    madd Nothing xs = xs
    madd xs Nothing = xs
    madd (Just x) (Just y) = Just (mix x y)
    mfby Nothing xs = xs
    mfby xs Nothing = xs
    mfby (Just x) (Just y) = Just (fby x y)

    -- First key, last key and first value of a non-empty matrix. With
    -- sorted entries these are its lowest and highest keys.
    lo (Mat _ xs ys _) = Key (U.head xs) (U.head ys)
    hi (Mat _ xs ys _) = Key (U.last xs) (U.last ys)
    head (Mat _ _ _ vs) = G.head vs