module Data.Array.Accelerate.LinearAlgebra.Matrix.Banded (
Symmetric(..),
flattenSymmetric,
) where
import Data.Array.Accelerate.LinearAlgebra (Matrix, matrixShape)
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Utility.Lift.Exp (expr)
import Data.Array.Accelerate ((:.)((:.)), (!), (?))
-- | A symmetric banded matrix in compressed storage.
--
-- The wrapped matrix has shape @sh :. rows :. width@. @rows@ is the
-- dimension of the square matrix it represents, and @width@ is the number
-- of stored diagonals, with the main diagonal included. 'flattenSymmetric'
-- reads stored element @(k, d)@ as the dense entry on row @k@ at offset @d@
-- from the diagonal. NOTE(review): confirm this layout against any other
-- producers of this type outside this module.
newtype Symmetric ix a =
   Symmetric (Matrix ix a)
-- | Expand a symmetric banded matrix into a dense square matrix.
--
-- The compressed input has shape @sh :. rows :. width@. The result has
-- shape @sh :. rows :. rows@, so the leading (batch) dimensions @sh@ are
-- kept.
--
-- Dense element @(k0, j0)@ is taken from stored row @min k0 j0@ at
-- diagonal offset @|k0 - j0|@. This makes the result symmetric. Entries
-- whose offset is @width@ or more lie outside the band and become @0@.
flattenSymmetric ::
   (A.Slice ix, A.Shape ix, A.Num a) =>
   Symmetric ix a -> Matrix ix a
flattenSymmetric (Symmetric m) =
   case matrixShape m of
      (sh :. rows :. width) ->
         A.generate (A.lift $ sh :. rows :. rows) $
         Exp.modify (expr:.expr:.expr) $ \(ix:.k0:.j0) ->
            let -- Mirror lower-triangle positions onto the upper triangle,
                -- which is the only part that is stored.
                k = min k0 j0
                -- Distance from the main diagonal (|k0 - j0|). This is the
                -- column index into the compressed storage.
                j = max k0 j0 - k
            in  width A.> j ? (m ! A.lift (ix:.k:.j), 0)