{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.LinearAlgebra.Matrix.Banded (
   Symmetric(..),
   flattenSymmetric,
   ) where

import Data.Array.Accelerate.LinearAlgebra (Matrix, matrixShape)

import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Utility.Lift.Exp (expr)
import Data.Array.Accelerate ((:.)((:.)), (!), (?))


-- | A 'Matrix' wrapped to mark it as the compact representation of a
-- symmetric matrix.
--
-- As read by 'flattenSymmetric', element @(k, j)@ of the wrapped array
-- supplies both entries @(k, k+j)@ and @(k+j, k)@ of the expanded square
-- matrix, so the second matrix dimension counts the distance from the
-- diagonal rather than an absolute column.
newtype Symmetric ix a = Symmetric (Matrix ix a)

-- | Expand a 'Symmetric' matrix into a plain square 'Matrix'.
--
-- With @matrixShape@ of the wrapped array being @sh :. rows :. width@,
-- the result is generated with extent @sh :. rows :. rows@.  The result
-- element at @(r, c)@ is the wrapped element at row @min r c@ and column
-- @max r c - min r c@ whenever that column is smaller than @width@, and
-- @0@ otherwise.
flattenSymmetric ::
   (A.Slice ix, A.Shape ix, A.Num a) =>
   Symmetric ix a -> Matrix ix a
flattenSymmetric (Symmetric band) =
   case matrixShape band of
      (outer :. size :. bandWidth) ->
         let -- the output is square: both matrix dimensions use the row count
             squareShape = A.lift (outer :. size :. size)
             entryAt =
                Exp.modify (expr:.expr:.expr) $ \(ix:.row:.col) ->
                   let -- (row, col) and (col, row) map to the same source cell
                       upper = min row col
                       offset = max row col - upper
                       stored = band ! A.lift (ix:.upper:.offset)
                   in  -- offsets beyond the stored width yield zero
                       bandWidth A.> offset ? (stored, 0)
         in  A.generate squareShape entryAt