{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Pattern (
pattern Pattern,
pattern T2, pattern T3, pattern T4, pattern T5, pattern T6,
pattern T7, pattern T8, pattern T9, pattern T10, pattern T11,
pattern T12, pattern T13, pattern T14, pattern T15, pattern T16,
pattern Z_, pattern Ix, pattern (::.),
pattern I0, pattern I1, pattern I2, pattern I3, pattern I4,
pattern I5, pattern I6, pattern I7, pattern I8, pattern I9,
pattern V2, pattern V3, pattern V4, pattern V8, pattern V16,
) where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Type
import Data.Primitive.Vec
import Language.Haskell.TH hiding ( Exp, Match, tupP, tupE )
import Language.Haskell.TH.Extra
-- | A bidirectional view of an embedded value (e.g. an 'Exp' or an 'Acc') as
-- a structure of smaller embedded values. Matching goes through 'destruct'
-- and building goes through 'construct'. The shapes on offer are exactly the
-- 'IsPattern' instances in scope.
pattern Pattern :: forall b a context. IsPattern context a b => b -> context a
pattern Pattern fields <- (destruct @context -> fields)
  where
    Pattern fields = construct @context fields
-- | Conversion between an embedded value of type @context a@ and a surface
-- representation @b@ of its components (for instance, a Haskell tuple of
-- 'Exp' values standing for an 'Exp' of a tuple type).
class IsPattern context a b where
  -- | Assemble an embedded value from its components.
  construct :: b -> context a
  -- | Split an embedded value into its components.
  destruct  :: context a -> b
-- | A bidirectional view of an embedded vector value as a surface structure
-- of its elements (for instance, a tuple of 'Exp' values). Matching goes
-- through 'vunpack' and building goes through 'vpack'.
pattern Vector :: forall b a context. IsVector context a b => b -> context a
pattern Vector lanes <- (vunpack @context -> lanes)
  where
    Vector lanes = vpack @context lanes
-- | Conversion between an embedded vector value of type @context a@ and a
-- surface representation @b@ of its elements.
class IsVector context a b where
  -- | Split the embedded vector into its elements.
  vunpack :: context a -> b
  -- | Build the embedded vector from its elements.
  vpack   :: b -> context a
-- | The embedded rank-0 shape index, i.e. an 'Exp' holding 'Z'.
pattern Z_ :: Exp DIM0
pattern Z_ <- Pattern Z
  where
    Z_ = Pattern Z
{-# COMPLETE Z_ #-}
-- | The 'Exp' counterpart of the ':.' constructor: attaches one more
-- component to an embedded index. It is left-associative, so
-- @Z_ ::. i ::. j@ groups as @(Z_ ::. i) ::. j@.
pattern (::.) :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b)
pattern a ::. b <- Pattern (a :. b)
  where
    a ::. b = Pattern (a :. b)
infixl 3 ::.
{-# COMPLETE (::.) #-}
-- | An alphanumeric synonym for '::.', usable in backticks
-- (@Z_ `Ix` i `Ix` j@). It has the same left-associative fixity.
pattern Ix :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b)
pattern Ix a b = a ::. b
infixl 3 `Ix`
{-# COMPLETE Ix #-}
-- | 'Z' has no components. Matching always yields 'Z' and building always
-- yields the constant 'Z'; neither inspects its argument.
instance IsPattern Exp Z Z where
  destruct  = \_ -> Z
  construct = \_ -> constant Z
-- | An index @a :. b@ is embedded as a 'Pair'. Building pairs the two
-- embedded halves. Matching projects out the left ('PairIdxLeft') and right
-- ('PairIdxRight') halves.
instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where
  construct (Exp lhs :. Exp rhs) = Exp (SmartExp (Pair lhs rhs))
  destruct (Exp pair) =
    Exp (SmartExp (Prj PairIdxLeft pair)) :. Exp (SmartExp (Prj PairIdxRight pair))
-- Generate the 'IsPattern' instances for n-tuples of 'Exp' and of 'Acc'
-- (n = 0 .. 16) and the 'IsVector' instances for 'Exp' at vector widths
-- 2, 3, 4, 8 and 16.
runQ $ do
  let
      -- IsPattern Acc a (Acc x0, ..., Acc x(n-1)), where the representation
      -- of 'a' is the left-nested pairing of the components' representations.
      mkAccPattern :: Int -> Q [Dec]
      mkAccPattern n = do
        a <- newName "a"
        let
            -- Type variables for the tuple components.
            xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ]
            -- Surface type: (Acc x0, ..., Acc x(n-1)).
            b = tupT (map (\t -> [t| Acc $(varT t)|]) xs)
            -- Left-nested representation: ((((), ArraysR x0), ArraysR x1), ...).
            snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs
            -- 'a' must have exactly that representation, and every component
            -- must itself be 'Arrays'.
            context = tupT
              $ [t| Arrays $(varT a) |]
              : [t| ArraysR $(varT a) ~ $snoc |]
              : map (\t -> [t| Arrays $(varT t)|]) xs
            -- Projection of the component i places from the right (0 = last):
            -- i steps of PairIdxLeft, then one PairIdxRight.
            get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |]
            get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1)
        _x <- newName "_x"
        -- construct: fold the components onto Anil with Apair, x0 innermost.
        -- destruct: project each component; x0 sits n-1 places from the right.
        [d| instance $context => IsPattern Acc $(varT a) $b where
              construct $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) =
                Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs)
              destruct (Acc $(varP _x)) =
                $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0]))
          |]

      -- IsPattern Exp a (Exp x0, ..., Exp x(n-1)), where the representation
      -- of 'a' is the left-nested pairing of the components' representations.
      mkExpPattern :: Int -> Q [Dec]
      mkExpPattern n = do
        a <- newName "a"
        let
            -- Type variables for the tuple components.
            xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ]
            -- One sub-tag variable per component.
            ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ]
            -- Pattern matching a tag built from TagRpair over TagRunit,
            -- binding the sub-tags m0 .. m(n-1).
            tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms
            -- Surface type: (Exp x0, ..., Exp x(n-1)).
            b = tupT (map (\t -> [t| Exp $(varT t)|]) xs)
            -- Left-nested representation: ((((), EltR x0), EltR x1), ...).
            snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs
            -- 'a' must have exactly that representation, and every component
            -- must itself be 'Elt'.
            context = tupT
              $ [t| Elt $(varT a) |]
              : [t| EltR $(varT a) ~ $snoc |]
              : map (\t -> [t| Elt $(varT t)|]) xs
            -- Projection of the component i places from the right (0 = last):
            -- i steps of PairIdxLeft, then one PairIdxRight.
            get x 0 = [| SmartExp (Prj PairIdxRight $x) |]
            get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1)
        _x <- newName "_x"
        _y <- newName "_y"
        -- construct: strip any Match wrapper from each component, then fold
        -- the components onto Nil with Pair, x0 innermost.
        -- destruct: if the value is a Match whose tag is a TagRpair chain,
        -- re-wrap each projected component in Match with its own sub-tag;
        -- otherwise return the plain projections.
        -- NOTE(review): in the Match branch _y is bound but the projections are
        -- taken from _x (the Match node itself); confirm this is intended.
        [d| instance $context => IsPattern Exp $(varT a) $b where
              construct $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) =
                let _unmatch :: SmartExp a -> SmartExp a
                    _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y)
                    _unmatch x = x
                in
                Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs)
              destruct (Exp $(varP _x)) =
                case $(varE _x) of
                  SmartExp (Match $tags $(varP _y))
                    -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]])
                  _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]])
          |]

      -- IsVector Exp v (Exp a, ..., Exp a) for an n-element vector, where
      -- EltR v ~ Vec n a.
      mkVecPattern :: Int -> Q [Dec]
      mkVecPattern n = do
        a <- newName "a"
        v <- newName "v"
        let
            -- Surface type: an n-tuple of Exp a.
            tup = tupT (replicate n ([t| Exp $(varT a)|]))
            -- Vec n a, with n as a type-level literal.
            vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |]
            context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |]
            -- n applications of VecRsucc around VecRnil. The type application
            -- uses the prefix form @(..): GHC >= 9.0 lexes '@' whitespace-
            -- sensitively and only treats a prefix occurrence as a type
            -- application.
            vecR = foldr appE [| VecRnil (singleType @($(varT a))) |] (replicate n [| VecRsucc |])
            -- The n-tuple (a, ..., a) used as the intermediate element type.
            tR = tupT (replicate n (varT a))
        -- vpack: assemble the elements as an Exp of an n-tuple (via the tuple
        -- IsPattern instance), then wrap it in VecPack.
        -- vunpack: wrap in VecUnpack, then split the n-tuple apart.
        [d| instance $context => IsVector Exp $(varT v) $tup where
              vpack x = case construct x :: Exp $tR of
                          Exp x' -> Exp (SmartExp (VecPack $vecR x'))
              vunpack (Exp x) = destruct (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR)
          |]
  es <- mapM mkExpPattern [0..16]
  as <- mapM mkAccPattern [0..16]
  vs <- mapM mkVecPattern [2,3,4,8,16]
  return $ concat (es ++ as ++ vs)
-- Generate the exported pattern synonyms:
--
--   * T2 .. T16            tuples of 'Exp' or 'Acc' values, via 'Pattern'
--   * I0 .. I9             shape indices, built from 'Z_' and 'Ix'
--   * V2, V3, V4, V8, V16  vectors of 'Exp' values, via 'Vector'
runQ $ do
  let
      -- Pattern variables x0 .. x(n-1). They are reused as the type
      -- variables of the generated signatures.
      fieldNames n = [ mkName ('x' : show i) | i <- [0 .. n-1] ]

      -- Right-nested function type: wrap t0 -> wrap t1 -> ... -> result
      curried wrap tys result = foldr (\t r -> [t| $wrap $t -> $r |]) result tys

      -- pattern Tn :: IsPattern con (x0, ..) (con x0, ..)
      --            => con x0 -> .. -> con (x0, ..)
      -- with COMPLETE pragmas for both Acc and Exp.
      mkT :: Int -> Q [Dec]
      mkT n = sequence
        [ patSynSigD tName [t| IsPattern $con $tupleTy $fieldsTy => $sigTy |]
        , patSynD tName (prefixPatSyn vars) implBidir [p| Pattern $(tupP (map varP vars)) |]
        , pragCompleteD [tName] (Just ''Acc)
        , pragCompleteD [tName] (Just ''Exp)
        ]
        where
          vars     = fieldNames n
          tys      = map varT vars
          tName    = mkName ("T" ++ show n)
          con      = varT (mkName "con")
          tupleTy  = tupT tys
          fieldsTy = tupT (map (appT con) tys)
          sigTy    = curried con tys (appT con tupleTy)

      -- pattern In :: (Elt x0, ..) => Exp x0 -> .. -> Exp (Z :. x0 :. ..)
      -- matched as  Z_ `Ix` x0 `Ix` x1 ...
      mkI :: Int -> Q [Dec]
      mkI n = sequence
        [ patSynSigD iName [t| $constraints => $sigTy |]
        , patSynD iName (prefixPatSyn vars) implBidir body
        , pragCompleteD [iName] Nothing
        ]
        where
          vars        = fieldNames n
          tys         = map varT vars
          iName       = mkName ("I" ++ show n)
          constraints = tupT [ [t| Elt $t |] | t <- tys ]
          shapeTy     = foldl (\h t -> [t| $h :. $t |]) [t| Z |] tys
          sigTy       = curried [t| Exp |] tys [t| Exp $shapeTy |]
          body        = foldl (\ps p -> infixP ps (mkName "Ix") (varP p)) [p| Z_ |] vars

      -- pattern Vn :: IsVector con vec (con x0, ..)
      --            => con x0 -> .. -> con vec
      -- with a COMPLETE pragma for Exp.
      mkV :: Int -> Q [Dec]
      mkV n = sequence
        [ patSynSigD vName [t| IsVector $con $vecTy $fieldsTy => $sigTy |]
        , patSynD vName (prefixPatSyn vars) implBidir [p| Vector $(tupP (map varP vars)) |]
        , pragCompleteD [vName] (Just ''Exp)
        ]
        where
          vars     = fieldNames n
          tys      = map varT vars
          vName    = mkName ("V" ++ show n)
          con      = varT (mkName "con")
          vecTy    = varT (mkName "vec")
          fieldsTy = tupT (map (appT con) tys)
          sigTy    = curried con tys (appT con vecTy)

  tupleSyns  <- mapM mkT [2..16]
  indexSyns  <- mapM mkI [0..9]
  vectorSyns <- mapM mkV [2,3,4,8,16]
  return (concat (tupleSyns ++ indexSyns ++ vectorSyns))