#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Distributed.Data.Tuple
(
zipD, unzipD, fstD, sndD
, zip3D, unzip3D)
where
import Data.Array.Parallel.Unlifted.Distributed.Primitive.DT
import Data.Array.Parallel.Base
import Data.Array.Parallel.Pretty
import Control.Monad
-- | Qualify a local name with this module's path.
--
--   The result is used as the call-site tag passed to 'checkEq' and as the
--   error context handed to 'indexD', so the prefix must match the module
--   header (@Data.Array.Parallel.Unlifted.Distributed.Data.Tuple@). The prefix
--   previously named @...Distributed.Types.Tuple@, a module path that is not
--   this one.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Data.Tuple." ++ s
-- | Distributed pairs are represented as two independent distributed
--   components ('DProd' / 'MDProd'). Every method delegates to each half.
instance (DT a, DT b) => DT (a,b) where
  data Dist  (a,b)   = DProd  !(Dist a)    !(Dist b)
  data MDist (a,b) s = MDProd !(MDist a s) !(MDist b s)

  -- The halves are projected through a lazy 'where' binding, so the result
  -- pair is built without first inspecting the distributed value. The
  -- caller's context string gets "/indexD[Tuple2]" appended.
  indexD str d i = (indexD ctx xs i, indexD ctx ys i)
    where
      ctx      = str ++ "/indexD[Tuple2]"
      (xs, ys) = unzipD d

  -- Allocate the first component, then the second.
  newMD g = do
    xs <- newMD g
    ys <- newMD g
    return (MDProd xs ys)

  readMD (MDProd xs ys) i = do
    x <- readMD xs i
    y <- readMD ys i
    return (x, y)

  writeMD (MDProd xs ys) i (x, y) = writeMD xs i x >> writeMD ys i y

  unsafeFreezeMD (MDProd xs ys) = do
    frozenXs <- unsafeFreezeMD xs
    frozenYs <- unsafeFreezeMD ys
    return (DProd frozenXs frozenYs)

  -- Chain 'deepSeqD' over the first component, then the second.
  deepSeqD (x, y) z = deepSeqD x $ deepSeqD y z

  -- Both halves are expected to share a size; the first one is reported.
  sizeD (DProd xs _) = sizeD xs

  sizeMD (MDProd xs _) = sizeMD xs

  measureD (x, y) = concat ["Pair ", "(", measureD x, ") (", measureD y, ")"]
-- | Physical pretty-printing of a distributed pair: the constructor name on
--   one line, followed by both components stacked vertically and nested by
--   8 columns.
instance (PprPhysical (Dist a), PprPhysical (Dist b))
      => PprPhysical (Dist (a, b)) where
  pprp (DProd xs ys) = header $$ nest 8 body
    where
      header = text "DProd"
      body   = vcat [pprp xs, pprp ys]
-- | Pair two distributed values component-wise into a single distributed
--   pair.
--
--   Both arguments are forced first. Their 'sizeD' values are then compared
--   with 'checkEq' under the tag @here "zipDT"@. 'checkEq' is defined
--   elsewhere and presumably reports "Size mismatch" when the sizes differ.
zipD :: (DT a, DT b) => Dist a -> Dist b -> Dist (a,b)
zipD x y =
  x `seq` y `seq`
  checkEq (here "zipDT") "Size mismatch" (sizeD x) (sizeD y) (DProd x y)
-- | Split a distributed pair into its two distributed components.
--   This is a constant-time projection of the 'DProd' fields.
unzipD :: (DT a, DT b) => Dist (a,b) -> (Dist a, Dist b)
unzipD d = case d of
  DProd xs ys -> (xs, ys)
-- | The first component of a distributed pair.
fstD :: (DT a, DT b) => Dist (a,b) -> Dist a
fstD (DProd xs _) = xs
-- | The second component of a distributed pair.
sndD :: (DT a, DT b) => Dist (a,b) -> Dist b
sndD (DProd _ ys) = ys
-- | Distributed triples are represented as three independent distributed
--   components ('DProd3' / 'MDProd3'). Every method delegates to each
--   component in order.
instance (DT a, DT b, DT c) => DT (a,b,c) where
  data Dist  (a,b,c)   = DProd3  !(Dist a)    !(Dist b)    !(Dist c)
  data MDist (a,b,c) s = MDProd3 !(MDist a s) !(MDist b s) !(MDist c s)

  -- All three lookups share one error context: this module's path,
  -- then "indexD[Tuple3]/", then the caller's context string.
  indexD str (DProd3 xs ys zs) i =
      (indexD ctx xs i, indexD ctx ys i, indexD ctx zs i)
    where
      ctx = here ("indexD[Tuple3]/" ++ str)

  newMD g = do
    xs <- newMD g
    ys <- newMD g
    zs <- newMD g
    return (MDProd3 xs ys zs)

  readMD (MDProd3 xs ys zs) i = do
    x <- readMD xs i
    y <- readMD ys i
    z <- readMD zs i
    return (x, y, z)

  writeMD (MDProd3 xs ys zs) i (x, y, z) =
    writeMD xs i x >> writeMD ys i y >> writeMD zs i z

  unsafeFreezeMD (MDProd3 xs ys zs) = do
    frozenXs <- unsafeFreezeMD xs
    frozenYs <- unsafeFreezeMD ys
    frozenZs <- unsafeFreezeMD zs
    return (DProd3 frozenXs frozenYs frozenZs)

  -- Chain 'deepSeqD' over the first, second and third components.
  deepSeqD (x, y, z) k = (deepSeqD x . deepSeqD y . deepSeqD z) k

  -- All components are expected to share a size; the first one is reported.
  sizeD (DProd3 xs _ _) = sizeD xs

  sizeMD (MDProd3 xs _ _) = sizeMD xs

  measureD (x, y, z) =
    concat [ "Triple "
           , "(", measureD x, ") "
           , "(", measureD y, ") "
           , "(", measureD z, ")"
           ]
-- | Combine three distributed values component-wise into a distributed
--   triple.
--
--   All three arguments are forced first. The size of @y@ and then the size
--   of @z@ are each compared against the size of @x@ via 'checkEq', under
--   the tag @here "zip3DT"@. 'checkEq' is defined elsewhere and presumably
--   reports "Size mismatch" on inequality.
zip3D :: (DT a, DT b, DT c) => Dist a -> Dist b -> Dist c -> Dist (a,b,c)
zip3D x y z =
  x `seq` y `seq` z `seq`
  checkEq ctx "Size mismatch" (sizeD x) (sizeD y) $
  checkEq ctx "Size mismatch" (sizeD x) (sizeD z) $
  DProd3 x y z
  where
    ctx = here "zip3DT"
-- | Split a distributed triple into its three distributed components.
unzip3D :: (DT a, DT b, DT c) => Dist (a,b,c) -> (Dist a, Dist b, Dist c)
unzip3D d = case d of
  DProd3 xs ys zs -> (xs, ys, zs)