module Data.Array.Accelerate.Convolution.Preprocessed (
Transform2,
karatsuba,
) where
import Data.Array.Accelerate.Convolution.Private (Transform2, indexPad, )
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import Data.Array.Accelerate.Utility.Lift.Exp (expr)
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate ((:.)((:.)), Any(Any), All(All), Slice, Shape, )
-- | Convolve @x@ and @y@ along their innermost dimension using Karatsuba's
-- divide-and-conquer scheme. This is a polynomial product whose
-- coefficients run along the last index.
--
-- @len@ is treated as the innermost extent of both inputs. It determines
-- the split point and the output size. The code never compares it against
-- the actual array shapes, so the caller must pass the true length.
-- For @len >= 1@ the innermost extent of the result is @2*len-1@. The outer
-- extent is taken from the shape of the recursive result.
--
-- Each level splits both inputs into a low half and a high half. The high
-- half is padded to the length of the low half. The three sub-products
-- (low*low, (low+high)*(low+high) and high*high) come from a single
-- recursive call. For that call, the halves are combined with
-- 'Sliced.stack3' and afterwards separated again by slicing indices 0, 1
-- and 2 of the dimension just outside the innermost one.
karatsuba ::
   (Shape sh, Slice sh, A.Num a) =>
   Int -> Transform2 (sh :. Int) a
karatsuba len x y =
   if len <= 1
     -- at most one coefficient per row: the convolution is a pointwise product
     then A.zipWith (*) x y
     else
        let -- Round up: the high half has len - len2 <= len2 elements, so
            -- padding it to len2 (assuming Sliced.pad fills with 0) keeps
            -- every coefficient even when len is odd.
            len2 = div (len+1) 2
            elen2 = A.constant len2
            xl = Sliced.take elen2 x
            yl = Sliced.take elen2 y
            xr = Sliced.pad 0 elen2 $ Sliced.drop elen2 x
            yr = Sliced.pad 0 elen2 $ Sliced.drop elen2 y
            -- one recursive call computes all three sub-products at once
            zmerged =
               karatsuba len2
                  (Sliced.stack3 xl (A.zipWith (+) xl xr) xr)
                  (Sliced.stack3 yl (A.zipWith (+) yl yr) yr)
            zl = A.slice zmerged $ A.lift $ Any :. (0::Int) :. All
            zm = A.slice zmerged $ A.lift $ Any :. (1::Int) :. All
            zr = A.slice zmerged $ A.lift $ Any :. (2::Int) :. All
            -- middle term: (xl+xr)*(yl+yr) - xl*yl - xr*yr = xl*yr + xr*yl
            zc = A.zipWith (-) zm $ A.zipWith (+) zl zr
            sh = A.indexTail $ A.shape zc
        in  A.generate (A.lift $ sh :. 2*len-1) $
            Exp.modify (expr:.expr) $
            \(ix:.k) ->
               -- Add the three partial products shifted by 0, len2 and
               -- 2*len2. indexPad is presumably 0 outside an array's
               -- extent -- TODO confirm in Convolution.Private.
               indexPad (ix:.k) zl +
               indexPad (ix:.k-elen2) zc +
               indexPad (ix:.k-elen2*2) zr