{-# LANGUAGE BangPatterns, MagicHash #-}
module Data.Array.Repa.Eval.Reduction
( foldS, foldP
, foldAllS, foldAllP)
where
import Data.Array.Repa.Eval.Gang
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as M
import GHC.Base ( quotInt, divInt )
import GHC.Exts
-- | Sequentially reduce consecutive runs of @n@ source elements, writing one
--   result into each slot of the destination vector.
--
--   Slot @i@, for every @i@ below @M.length vec@, receives the fold of the
--   source elements at linear indices @[i * n, (i + 1) * n)@, combined
--   left-to-right with @c@ and seeded with @r@. When @n@ is zero every slot
--   is simply set to @r@.
foldS   :: V.Unbox a
        => M.IOVector a         -- ^ Destination vector; every slot is overwritten.
        -> (Int# -> a)          -- ^ Fetch the source element at a linear index.
        -> (a -> a -> a)        -- ^ Combining function.
        -> a                    -- ^ Starting value used for each run.
        -> Int#                 -- ^ Number of source elements folded per slot.
        -> IO ()
{-# INLINE [1] foldS #-}
foldS !vec get c !r !n
 = iter 0# 0#
 where
        !(I# end) = M.length vec

        -- @sh@ is the destination slot being filled and @sz@ is the linear
        -- source index at which that slot's run begins.
        {-# INLINE iter #-}
        iter :: Int# -> Int# -> IO ()
        iter !sh !sz
         | 1# <- sh >=# end
         = return ()

         | otherwise
         = do   let !next = sz +# n
                -- No bounds check is needed: @sh < end = M.length vec@ here.
                M.unsafeWrite vec (I# sh) (reduceAny get c r sz next)
                iter (sh +# 1#) next
-- | Gang-parallel variant of the segmented reduction: slot @i@ of the
--   destination vector receives the fold of the source elements at linear
--   indices @[i * n, (i + 1) * n)@, combined left-to-right with @c@ and
--   seeded with @r@.
--
--   The destination slots are divided into contiguous ranges of at most
--   @step = ceiling (len / threads)@ slots, and the worker invoked with id
--   @tid@ fills slots @[split tid, split (tid + 1))@. Each slot is written
--   by exactly one worker.
--
--   NOTE(review): 'gangIO', 'theGang' and 'gangSize' come from
--   "Data.Array.Repa.Eval.Gang"; this presumably runs the action once per
--   gang thread with ids @0 .. gangSize theGang - 1@ -- confirm there.
foldP   :: V.Unbox a
        => M.IOVector a         -- ^ Destination vector; every slot is overwritten.
        -> (Int -> a)           -- ^ Fetch the source element at a linear index.
        -> (a -> a -> a)        -- ^ Combining function.
        -> a                    -- ^ Starting value used for each run.
        -> Int                  -- ^ Number of source elements folded per slot.
        -> IO ()
{-# INLINE [1] foldP #-}
foldP vec f c !r (I# n)
 = gangIO theGang
 $ \(I# tid) -> fill (split tid) (split (tid +# 1#))
 where
        !(I# threads) = gangSize theGang
        !(I# len)     = M.length vec

        -- Slots per worker, rounded up so that all of @len@ is covered.
        !step         = (len +# threads -# 1#) `quotInt#` threads

        -- First slot owned by worker @ix@, clamped to the vector length.
        {-# INLINE split #-}
        split :: Int# -> Int#
        split !ix
         = let  !ix' = ix *# step
           in   case len <# ix' of
                 0# -> ix'
                 _  -> len

        -- Fill the destination slots @[start, end)@.
        {-# INLINE fill #-}
        fill :: Int# -> Int# -> IO ()
        fill !start !end
         = iter start (start *# n)
         where
                -- @sh@ is the destination slot, @sz@ the linear source
                -- index at which that slot's run begins.
                {-# INLINE iter #-}
                iter :: Int# -> Int# -> IO ()
                iter !sh !sz
                 | 1# <- sh >=# end
                 = return ()

                 | otherwise
                 = do   let !next = sz +# n
                        M.unsafeWrite vec (I# sh)
                                (reduce f c r (I# sz) (I# next))
                        iter (sh +# 1#) next
-- | Sequentially fold the elements at linear indices @[0, len)@ into a
--   single value, combining left-to-right with @c@ starting from @r@.
--   Returns @r@ unchanged when @len@ is not positive.
foldAllS :: (Int# -> a)         -- ^ Fetch the element at a linear index.
         -> (a -> a -> a)       -- ^ Combining function.
         -> a                   -- ^ Starting value.
         -> Int#                -- ^ Number of elements to fold.
         -> a
{-# INLINE [1] foldAllS #-}
foldAllS f c !r !len
 = reduceAny (\i -> f i) c r 0# len
-- | Gang-parallel fold of the elements at linear indices @[0, len)@ into a
--   single value.
--
--   The index range is cut into contiguous chunks of at most
--   @step = ceiling (len / threads)@ elements. The worker invoked with id
--   @tid@ folds chunk @[split tid, split (tid + 1))@, seeding the fold with
--   the chunk's first element, and stores the partial result in slot @tid@
--   of a temporary vector. Workers whose chunk is empty write nothing.
--   Finally the partial results are folded with @c@ starting from @r@, so
--   @r@ takes part in the result exactly once.
--
--   Because the elements are regrouped into chunks, the result matches a
--   plain left fold only when @c@ is associative.
--
--   NOTE(review): 'gangIO', 'theGang' and 'gangSize' come from
--   "Data.Array.Repa.Eval.Gang"; this presumably runs the action once per
--   gang thread with ids @0 .. gangSize theGang - 1@ -- confirm there.
foldAllP :: V.Unbox a
         => (Int -> a)          -- ^ Fetch the element at a linear index.
         -> (a -> a -> a)       -- ^ Combining function.
         -> a                   -- ^ Starting value.
         -> Int                 -- ^ Number of elements to fold.
         -> IO a
{-# INLINE [1] foldAllP #-}
foldAllP f c !r !len
 | len == 0
 = return r

 | otherwise
 = do   mvec    <- M.unsafeNew chunks
        gangIO theGang
         $ \tid -> fill mvec tid (split tid) (split (tid + 1))
        vec     <- V.unsafeFreeze mvec
        return $! V.foldl' c r vec
 where
        !threads = gangSize theGang

        -- Elements per chunk, rounded up so that all of @len@ is covered.
        !step    = (len + threads - 1) `quotInt` threads

        -- Number of non-empty chunks; only these slots are ever written.
        chunks   = ((len + step - 1) `divInt` step) `min` threads

        -- First index of chunk @ix@, clamped to @len@.
        {-# INLINE split #-}
        split !ix = len `min` (ix * step)

        -- Fold the chunk @[start, end)@ into slot @tid@, if it is non-empty.
        {-# INLINE fill #-}
        fill !mvec !tid !start !end
         | start >= end
         = return ()

         | otherwise
         = M.unsafeWrite mvec tid
                (reduce f c (f start) (start + 1) end)
-- | Fold the elements at indices @[start, end)@ left-to-right with @c@,
--   starting from @r@. This is the boxed-index wrapper around 'reduceAny'.
{-# INLINE [0] reduce #-}
reduce  :: (Int -> a)           -- ^ Fetch the element at an index.
        -> (a -> a -> a)        -- ^ Combining function.
        -> a                    -- ^ Starting value.
        -> Int                  -- ^ First index (inclusive).
        -> Int                  -- ^ Last index (exclusive).
        -> a
reduce f c !r (I# start) (I# end)
 = reduceAny (\i -> f (I# i)) c r start end
-- | Fold the elements at indices @[start, end)@ left-to-right with @c@,
--   starting from @r@. The accumulator is forced at every step. Returns @r@
--   when @start >= end@.
{-# INLINE [0] reduceAny #-}
reduceAny :: (Int# -> a)        -- ^ Fetch the element at an index.
          -> (a -> a -> a)      -- ^ Combining function (accumulator on the left).
          -> a                  -- ^ Starting value.
          -> Int#               -- ^ First index (inclusive).
          -> Int#               -- ^ Last index (exclusive).
          -> a
reduceAny f c !r !start !end
 = iter start r
 where
        {-# INLINE iter #-}
        iter !i !z
         | 1# <- i >=# end      = z
         | otherwise            = iter (i +# 1#) (z `c` f i)
-- | Version of 'reduceAny' that works on unboxed @Int#@ values throughout:
--   folds the elements at indices @[start, end)@ left-to-right with @c@,
--   starting from @r@.
{-# INLINE [0] reduceInt #-}
reduceInt
        :: (Int# -> Int#)               -- ^ Fetch the element at an index.
        -> (Int# -> Int# -> Int#)       -- ^ Combining function.
        -> Int#                         -- ^ Starting value.
        -> Int#                         -- ^ First index (inclusive).
        -> Int#                         -- ^ Last index (exclusive).
        -> Int#
reduceInt f c !r !start !end
 = iter start r
 where
        {-# INLINE iter #-}
        iter :: Int# -> Int# -> Int#
        iter !i !z
         | 1# <- i >=# end      = z
         | otherwise            = iter (i +# 1#) (z `c` f i)
-- | Version of 'reduceAny' that works on unboxed @Float#@ values:
--   folds the elements at indices @[start, end)@ left-to-right with @c@,
--   starting from @r@.
{-# INLINE [0] reduceFloat #-}
reduceFloat
        :: (Int# -> Float#)                     -- ^ Fetch the element at an index.
        -> (Float# -> Float# -> Float#)         -- ^ Combining function.
        -> Float#                               -- ^ Starting value.
        -> Int#                                 -- ^ First index (inclusive).
        -> Int#                                 -- ^ Last index (exclusive).
        -> Float#
reduceFloat f c !r !start !end
 = iter start r
 where
        {-# INLINE iter #-}
        iter :: Int# -> Float# -> Float#
        iter !i !z
         | 1# <- i >=# end      = z
         | otherwise            = iter (i +# 1#) (z `c` f i)
-- | Version of 'reduceAny' that works on unboxed @Double#@ values:
--   folds the elements at indices @[start, end)@ left-to-right with @c@,
--   starting from @r@.
{-# INLINE [0] reduceDouble #-}
reduceDouble
        :: (Int# -> Double#)                    -- ^ Fetch the element at an index.
        -> (Double# -> Double# -> Double#)      -- ^ Combining function.
        -> Double#                              -- ^ Starting value.
        -> Int#                                 -- ^ First index (inclusive).
        -> Int#                                 -- ^ Last index (exclusive).
        -> Double#
reduceDouble f c !r !start !end
 = iter start r
 where
        {-# INLINE iter #-}
        iter :: Int# -> Double# -> Double#
        iter !i !z
         | 1# <- i >=# end      = z
         | otherwise            = iter (i +# 1#) (z `c` f i)
-- | Extract the raw @Int#@ from a boxed 'Int'.
{-# INLINE unboxInt #-}
unboxInt :: Int -> Int#
unboxInt (I# i) = i
-- | Extract the raw @Float#@ from a boxed 'Float'.
{-# INLINE unboxFloat #-}
unboxFloat :: Float -> Float#
unboxFloat (F# f) = f
-- | Extract the raw @Double#@ from a boxed 'Double'.
{-# INLINE unboxDouble #-}
unboxDouble :: Double -> Double#
unboxDouble (D# d) = d
-- Rewrite rules that replace a call to the polymorphic 'reduceAny' with the
-- matching unboxed loop when the element type is Int, Float or Double.
-- The element getter, the combining function and the starting value are
-- wrapped so that the loop itself only handles unboxed values; the final
-- result is re-boxed once at the end.
--
-- The continuation lines of each pragma must stay indented: a pragma does
-- not open a layout context of its own, so a line starting in column 0
-- would be read as the start of a new top-level declaration.
{-# RULES "reduceInt"
    forall (get :: Int# -> Int) f r start end
    . reduceAny get f r start end
    = I# (reduceInt
                (\i     -> unboxInt (get i))
                (\d1 d2 -> unboxInt (f (I# d1) (I# d2)))
                (unboxInt r)
                start
                end)
 #-}

{-# RULES "reduceFloat"
    forall (get :: Int# -> Float) f r start end
    . reduceAny get f r start end
    = F# (reduceFloat
                (\i     -> unboxFloat (get i))
                (\d1 d2 -> unboxFloat (f (F# d1) (F# d2)))
                (unboxFloat r)
                start
                end)
 #-}

{-# RULES "reduceDouble"
    forall (get :: Int# -> Double) f r start end
    . reduceAny get f r start end
    = D# (reduceDouble
                (\i     -> unboxDouble (get i))
                (\d1 d2 -> unboxDouble (f (D# d1) (D# d2)))
                (unboxDouble r)
                start
                end)
 #-}