module LLVM.Extra.Vector (
Simple (shuffleMatch, extract), C (insert),
Element, Size,
Canonical, Construct,
size, sizeInTuple,
replicate, iterate, assemble,
shuffle,
rotateUp, rotateDown, reverse,
shiftUp, shiftDown,
shiftUpMultiZero, shiftDownMultiZero,
shuffleMatchTraversable,
shuffleMatchAccess,
shuffleMatchPlain1,
shuffleMatchPlain2,
insertTraversable,
extractTraversable,
extractAll,
Constant, constant,
insertChunk, modify,
map, mapChunks, zipChunksWith,
chop, concat, select,
signedFraction,
cumulate1, umul32to64,
Arithmetic
(sum, sumToPair, sumInterleavedToPair,
cumulate, dotProduct, mul),
Real
(min, max, abs, signum,
truncate, floor, fraction),
) where
import qualified LLVM.Extra.Extension.X86Auto as X86A
import qualified LLVM.Extra.ExtensionCheck.X86 as X86C
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Monad as M
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi(phis, addPhis), )
import LLVM.Core
(Value, ConstValue, valueOf, value, constOf, undef,
Vector, insertelement, extractelement,
IsConst, IsArithmetic, IsFloating,
IsPrimitive,
CodeGenFunction, )
import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D4, (:+:), )
import qualified Control.Applicative as App
import Control.Monad.HT ((<=<), )
import Control.Monad (liftM2, liftM3, foldM, )
import Control.Applicative (liftA2, )
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.NonEmpty as NonEmpty
import qualified Data.Empty as Empty
import qualified Data.List.HT as ListHT
import qualified Data.List as List
import Data.NonEmpty ((!:), )
import Data.Tuple.HT (uncurry3, )
import Data.Int (Int8, Int16, Int32, Int64, )
import Data.Word (Word8, Word16, Word32, Word64, )
import Prelude hiding
(Real, truncate, floor, round,
map, zipWith, iterate, replicate, reverse, concat, sum, )
-- | Vector-like types that additionally allow replacing a single element.
class Simple v => C v where
   -- | @insert k a v@ returns @v@ with the element at index @k@ replaced by @a@.
   insert :: Value Word32 -> Element v -> v -> CodeGenFunction r v
-- | Types that behave like a vector of 'Size' elements of type 'Element'.
-- This module provides instances for plain LLVM vectors, pairs and triples
-- of equally sized vectors, and 'Constant' vectors.
class
   (TypeNum.Positive (Size v), Phi v, Class.Undefined v) =>
      Simple v where
   -- | Type of a single element.
   type Element v :: *
   -- | Type-level number of elements.
   type Size v :: *

   -- | Permute the elements according to a constant vector of indices.
   shuffleMatch ::
      ConstValue (Vector (Size v) Word32) -> v -> CodeGenFunction r v

   -- | Read the element at a dynamic index.
   extract :: Value Word32 -> v -> CodeGenFunction r (Element v)
-- A plain LLVM vector: elements are scalar 'Value's and the size is the
-- vector's type-level length.
instance
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
      Simple (Value (Vector n a)) where
   type Element (Value (Vector n a)) = Value a
   type Size (Value (Vector n a)) = n
   shuffleMatch indices vec = shuffleMatchPlain1 vec indices
   extract idx vec = extractelement vec idx

instance
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
      C (Value (Vector n a)) where
   insert idx elt vec = insertelement vec elt idx
-- A pair of equally sized vectors acts as one vector of pairs; every
-- operation is applied to the first component and then to the second.
instance
   (Simple v0, Simple v1, Size v0 ~ Size v1) =>
      Simple (v0, v1) where
   type Element (v0, v1) = (Element v0, Element v1)
   type Size (v0, v1) = Size v0
   shuffleMatch is (v0,v1) = do
      w0 <- shuffleMatch is v0
      w1 <- shuffleMatch is v1
      return (w0,w1)
   extract k (v0,v1) = do
      e0 <- extract k v0
      e1 <- extract k v1
      return (e0,e1)

instance
   (C v0, C v1, Size v0 ~ Size v1) =>
      C (v0, v1) where
   insert k (a0,a1) (v0,v1) = do
      w0 <- insert k a0 v0
      w1 <- insert k a1 v1
      return (w0,w1)
-- A triple of equally sized vectors acts as one vector of triples; every
-- operation is applied to the components from left to right.
instance
   (Simple v0, Simple v1, Simple v2, Size v0 ~ Size v1, Size v1 ~ Size v2) =>
      Simple (v0, v1, v2) where
   type Element (v0, v1, v2) = (Element v0, Element v1, Element v2)
   type Size (v0, v1, v2) = Size v0
   shuffleMatch is (v0,v1,v2) = do
      w0 <- shuffleMatch is v0
      w1 <- shuffleMatch is v1
      w2 <- shuffleMatch is v2
      return (w0,w1,w2)
   extract k (v0,v1,v2) = do
      e0 <- extract k v0
      e1 <- extract k v1
      e2 <- extract k v2
      return (e0,e1,e2)

instance
   (C v0, C v1, C v2, Size v0 ~ Size v1, Size v1 ~ Size v2) =>
      C (v0, v1, v2) where
   insert k (a0,a1,a2) (v0,v1,v2) = do
      w0 <- insert k a0 v0
      w1 <- insert k a1 v1
      w2 <- insert k a2 v2
      return (w0,w1,w2)
-- | A vector of @n@ elements that all hold the same value.
-- Only a single copy of the value is stored: shuffling returns the vector
-- unchanged and extraction ignores the index.
newtype Constant n a = Constant a

-- | Wrap a value as a constant vector of type-level size @n@.
constant :: (TypeNum.Positive n) => a -> Constant n a
constant a = Constant a

instance Functor (Constant n) where
   fmap f (Constant a) = Constant $ f a

instance App.Applicative (Constant n) where
   pure a = Constant a
   Constant f <*> Constant a = Constant $ f a

instance Fold.Foldable (Constant n) where
   foldMap = Trav.foldMapDefault

instance Trav.Traversable (Constant n) where
   sequenceA (Constant fa) = fmap Constant fa

instance (Phi a) => Phi (Constant n a) where
   phis = Class.phisTraversable
   addPhis = Class.addPhisFoldable

instance (Class.Undefined a) => Class.Undefined (Constant n a) where
   undefTuple = Class.undefTuplePointed

instance
   (TypeNum.Positive n, Phi a, Class.Undefined a) =>
      Simple (Constant n a) where
   type Element (Constant n a) = a
   type Size (Constant n a) = n
   shuffleMatch _indices c = return c
   extract _k (Constant a) = return a
-- | Associates an element type @a@ and a size @n@ with the vector type
-- 'Construct' @n a@ that holds @n@ such elements. The superclass
-- equalities guarantee that the constructed type has exactly size @n@
-- and element type @a@.
class
   (n ~ Size (Construct n a), a ~ Element (Construct n a),
    C (Construct n a)) =>
      Canonical n a where
   type Construct n a :: *

instance
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
      Canonical n (Value a) where
   type Construct n (Value a) = Value (Vector n a)

instance
   (Canonical n a0, Canonical n a1) =>
      Canonical n (a0, a1) where
   type Construct n (a0, a1) = (Construct n a0, Construct n a1)

instance
   (Canonical n a0, Canonical n a1, Canonical n a2) =>
      Canonical n (a0, a1, a2) where
   type Construct n (a0, a1, a2) =
      (Construct n a0, Construct n a1, Construct n a2)
-- | Number of elements of a plain LLVM vector.
-- The count comes from the type alone; the vector value is never inspected.
size ::
   (TypeNum.Positive n) =>
   Value (Vector n a) -> Int
size = sizeOf TypeNum.singleton
   where
      sizeOf ::
         (TypeNum.Positive n) =>
         TypeNum.Singleton n -> Value (Vector n a) -> Int
      sizeOf n _ = TypeNum.integralFromSingleton n
-- | Fill every element of a vector with the same value.
replicate ::
   (C v) =>
   Element v -> CodeGenFunction r v
replicate x = replicateCore TypeNum.singleton x

-- | 'replicate' with the vector size passed as a singleton.
-- Passing the singleton pins the type-level size to that of @v@.
replicateCore ::
   (C v) =>
   TypeNum.Singleton (Size v) -> Element v -> CodeGenFunction r v
replicateCore n x =
   assemble $ List.replicate (TypeNum.integralFromSingleton n) x
-- | Build a vector from a list of elements.
-- The @k@-th list element is inserted at index @k@ of an initially
-- undefined vector.
--
-- The function uses only as many list elements as the vector has slots.
-- Surplus elements are ignored rather than inserted at out-of-range
-- indices, because LLVM's @insertelement@ gives no defined result for
-- those. If the list is shorter than the vector, the remaining slots stay
-- undefined.
assemble ::
   (C v) =>
   [Element v] -> CodeGenFunction r v
assemble =
   let build ::
          (C v) =>
          TypeNum.Singleton (Size v) -> [Element v] -> CodeGenFunction r v
       build n =
          foldM (\v (k,x) -> insert (valueOf k) x v) Class.undefTuple .
          -- indices 0 .. size-1 only; 'zip' drops any surplus elements
          List.zip (take (TypeNum.integralFromSingleton n) [0..])
   in  build TypeNum.singleton
-- | @insertChunk k c v@ copies the elements of chunk @c@ into @v@,
-- starting at index @k@ of @v@.
--
-- Only destination indices below the size of @v@ are written. A chunk
-- that reaches past the end of @v@ (for example the padded last chunk in
-- 'mapChunks') therefore never produces an out-of-range @insertelement@.
insertChunk ::
   (C c, C v, Element c ~ Element v) =>
   Int -> c ->
   v -> CodeGenFunction r v
insertChunk k x v =
   M.chain
      (List.zipWith copy
         (take (sizeInTuple x) [0..])
         -- zipWith stops at the shorter list, i.e. at the end of v
         [fromIntegral k .. fromIntegral (sizeInTuple v) - 1])
      v
   where
      copy i j w = do
         e <- extract (valueOf i) x
         insert (valueOf j) e w
-- | Fill a vector with @x, f x, f (f x), ...@, starting at index 0.
iterate ::
   (C v) =>
   (Element v -> CodeGenFunction r (Element v)) ->
   Element v -> CodeGenFunction r v
iterate f x = do
   ~(_, v) <- iterateCore f x Class.undefTuple
   return v

-- | Write the iterates of @f@ starting from @x0@ into all slots of @v0@.
-- Returns the next (unused) iterate together with the filled vector.
iterateCore ::
   (C v) =>
   (Element v -> CodeGenFunction r (Element v)) ->
   Element v -> v ->
   CodeGenFunction r (Element v, v)
iterateCore f x0 v0 =
   foldM step (x0,v0) (take (sizeInTuple v0) [0..])
   where
      -- f is run before the insertion, as in the accumulation order
      step (x,v) k = liftM2 (,) (f x) (insert (valueOf k) x v)
-- | General element permutation between possibly different vector types.
-- Result element @k@ is the element of @x@ at index @i!!k@.
shuffle ::
   (C v, C w, Element v ~ Element w) =>
   v ->
   ConstValue (Vector (Size w) Word32) ->
   CodeGenFunction r w
shuffle x i = do
   let iv = value i
   picked <-
      mapM
         (\k -> extractelement iv (valueOf k) >>= flip extract x)
         (take (size iv) [0..])
   assemble picked
-- | Number of elements of any 'Simple' vector, taken from its type only.
-- The argument is never evaluated.
sizeInTuple :: Simple v => v -> Int
sizeInTuple = count TypeNum.singleton
   where
      count :: Simple v => TypeNum.Singleton (Size v) -> v -> Int
      count n _ = TypeNum.integralFromSingleton n
-- | Constant vector that repeats the given non-empty list of Haskell
-- values cyclically until all @n@ slots are filled.
constCyclicVector ::
   (IsConst a, TypeNum.Positive n) =>
   NonEmpty.T [] a -> ConstValue (Vector n a)
constCyclicVector xs = LLVM.constCyclicVector (fmap constOf xs)
-- | Rotate all elements one position towards higher indices.
-- The last element wraps around to index 0.
--
-- The shuffle indices are @[n-1, 0, 1, ..., n-2]@ for a vector of size @n@.
rotateUp ::
   (Simple v) =>
   v -> CodeGenFunction r v
rotateUp x =
   shuffleMatch
      (constCyclicVector $
       (fromIntegral (sizeInTuple x) - 1) !: [0..]) x
-- | Rotate all elements one position towards lower indices.
-- Element 0 wraps around to the last index.
--
-- The shuffle indices are @[1, 2, ..., n-1, 0]@ for a vector of size @n@.
rotateDown ::
   (Simple v) =>
   v -> CodeGenFunction r v
rotateDown x =
   shuffleMatch
      (constCyclicVector $
       NonEmpty.snoc (List.take (sizeInTuple x - 1) [1..]) 0) x
-- | Reverse the order of the elements of a vector.
reverse ::
   (Simple v) =>
   v -> CodeGenFunction r v
reverse x =
   shuffleMatch indices x
   where
      -- [n-1, n-2, ..., 0]
      indices =
         constCyclicVector $
         maybe (error "vector size must be positive") NonEmpty.reverse $
         NonEmpty.fetch $
         List.take (sizeInTuple x) [0..]
-- | Shift all elements one position towards higher indices and feed @x0@
-- in at index 0.
--
-- Returns the element that was shifted out (the old last element)
-- together with the shifted vector.
shiftUp ::
   (C v) =>
   Element v -> v -> CodeGenFunction r (Element v, v)
shiftUp x0 x = do
   -- [undef, 0, 1, ..., n-2]
   y <-
      shuffleMatch
         (LLVM.constCyclicVector $ undef !: List.map constOf [0..]) x
   liftM2 (,)
      (extract (LLVM.valueOf (fromIntegral (sizeInTuple x) - 1)) x)
      (insert (value LLVM.zero) x0 y)
-- | Shift all elements one position towards lower indices and feed @x0@
-- in at the last index.
--
-- Returns the element that was shifted out (the old element 0) together
-- with the shifted vector.
shiftDown ::
   (C v) =>
   Element v -> v -> CodeGenFunction r (Element v, v)
shiftDown x0 x = do
   -- [1, 2, ..., n-1, undef]
   y <-
      shuffleMatch
         (LLVM.constCyclicVector $
          NonEmpty.snoc
             (List.map constOf $ List.take (sizeInTuple x - 1) [1..])
             undef) x
   liftM2 (,)
      (extract (value LLVM.zero) x)
      (insert (LLVM.valueOf (fromIntegral (sizeInTuple x) - 1)) x0 y)
-- | Shift all elements @n@ positions towards higher indices, filling the
-- vacated low slots with zero. Elements shifted past the top are dropped.
shiftUpMultiZero ::
   (C v, Class.Zero (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftUpMultiZero n v = do
   xs <- extractAll v
   assemble $ take (sizeInTuple v) $ List.replicate n Class.zeroTuple ++ xs

-- | Shift all elements @n@ positions towards lower indices, filling the
-- vacated high slots with zero. Elements shifted below index 0 are dropped.
shiftDownMultiZero ::
   (C v, Class.Zero (Element v)) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftDownMultiZero n v = do
   xs <- extractAll v
   assemble $ take (sizeInTuple v) $
      List.drop n xs ++ List.repeat Class.zeroTuple
-- | Apply the same constant shuffle to every vector in a traversable container.
shuffleMatchTraversable ::
   (Simple v, Trav.Traversable f) =>
   ConstValue (Vector (Size v) Word32) -> f v -> CodeGenFunction r (f v)
shuffleMatchTraversable is = Trav.mapM (shuffleMatch is)
-- | Implement 'shuffleMatch' through element-wise extract and insert.
-- This works for any 'C' instance, not only native LLVM vectors.
shuffleMatchAccess ::
   (C v) =>
   ConstValue (Vector (Size v) Word32) -> v -> CodeGenFunction r v
shuffleMatchAccess is v = do
   let iv = value is
   picked <-
      mapM
         (\k -> extract (valueOf k) iv >>= flip extract v)
         (take (size iv) [0..])
   assemble picked
-- | Permute the elements of a single plain vector.
-- The second operand of @shufflevector@ is undefined.
shuffleMatchPlain1 ::
   (TypeNum.Positive n, IsPrimitive a) =>
   Value (Vector n a) ->
   ConstValue (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n a))
shuffleMatchPlain1 x is = shuffleMatchPlain2 x (value undef) is

-- | Select elements from two plain vectors by constant indices
-- (LLVM @shufflevector@).
shuffleMatchPlain2 ::
   (TypeNum.Positive n, IsPrimitive a) =>
   Value (Vector n a) ->
   Value (Vector n a) ->
   ConstValue (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n a))
shuffleMatchPlain2 x y is = LLVM.shufflevector x y is
-- | Insert at the same index in every vector of an applicative container.
-- Each vector is paired with the corresponding element via 'liftA2'.
insertTraversable ::
   (C v, Trav.Traversable f, App.Applicative f) =>
   Value Word32 -> f (Element v) -> f v -> CodeGenFunction r (f v)
insertTraversable n a v = Trav.sequence inserts
   where
      inserts = liftA2 (insert n) a v

-- | Extract the element at the same index from every vector in a container.
extractTraversable ::
   (Simple v, Trav.Traversable f) =>
   Value Word32 -> f v -> CodeGenFunction r (f (Element v))
extractTraversable n = Trav.mapM (extract n)
-- | All elements of a vector, in index order.
extractAll ::
   (Simple v) =>
   v -> LLVM.CodeGenFunction r [Element v]
extractAll x = sequence (extractList x)

-- | One extraction action per element, in index order (not yet run).
extractList ::
   (Simple v) =>
   v -> [LLVM.CodeGenFunction r (Element v)]
extractList x =
   [extract (LLVM.valueOf k) x | k <- take (sizeInTuple x) [0..]]
-- | Replace the element at index @k@ by the result of applying @f@ to it.
modify ::
   (C v) =>
   Value Word32 ->
   (Element v -> CodeGenFunction r (Element v)) ->
   (v -> CodeGenFunction r v)
modify k f v = do
   old <- extract k v
   new <- f old
   insert k new v
-- | Apply a monadic element function to every element, in index order.
-- '_mapByFold' is an alternative implementation that interleaves the
-- extraction, function call and insertion per element.
map, _mapByFold ::
   (C v, C w, Size v ~ Size w) =>
   (Element v -> CodeGenFunction r (Element w)) ->
   (v -> CodeGenFunction r w)
map f v = do
   xs <- extractAll v
   ys <- mapM f xs
   assemble ys

_mapByFold f a =
   foldM step Class.undefTuple (take (sizeInTuple a) [0..])
   where
      step b n = do
         e <- extract (valueOf n) a
         e' <- f e
         insert (valueOf n) e' b
-- | Apply a chunk function to a vector, one chunk at a time.
--
-- 'chop' cuts the input into chunks of @f@'s argument type. The @k@-th
-- result chunk is written into the output, starting at index
-- @k * chunkSize@, via 'insertChunk'. The output starts out undefined.
mapChunks ::
   (C ca, C cb, Size ca ~ Size cb,
    C va, C vb, Size va ~ Size vb,
    Element ca ~ Element va, Element cb ~ Element vb) =>
   (ca -> CodeGenFunction r cb) ->
   (va -> CodeGenFunction r vb)
mapChunks f a =
   foldM step Class.undefTuple (List.zip (chop a) [0..])
   where
      step acc (getChunk, k) = do
         inChunk <- getChunk
         outChunk <- f inChunk
         insertChunk (k * sizeInTuple inChunk) outChunk acc
-- | Chunk-wise combination of two vectors with a binary chunk function.
-- The pair of vectors is treated as a vector of pairs.
zipChunksWith ::
   (C ca, C cb, C cc, Size ca ~ Size cb, Size cb ~ Size cc,
    C va, C vb, C vc, Size va ~ Size vb, Size vb ~ Size vc,
    Element ca ~ Element va, Element cb ~ Element vb, Element cc ~ Element vc) =>
   (ca -> cb -> CodeGenFunction r cc) ->
   (va -> vb -> CodeGenFunction r vc)
zipChunksWith f a b =
   mapChunks (\ ~(chunkA, chunkB) -> f chunkA chunkB) (a, b)
-- | Map over a vector chunk-wise, combining two chunk sizes.
--
-- The elements of @a@ are sliced into pieces of the size of @g@'s
-- argument type (the large chunks). Every piece except the last goes
-- through @g@. The last, possibly short, piece goes through @f@ (in
-- slices of @f@'s argument size) if it is not longer than one @f@-chunk,
-- and through @g@ otherwise. All result elements are concatenated and
-- assembled into the output vector.
mapChunks2 ::
   (C ca, C cb, Size ca ~ Size cb,
    C la, C lb, Size la ~ Size lb,
    C va, C vb, Size va ~ Size vb,
    Element ca ~ Element va, Element la ~ Element va,
    Element cb ~ Element vb, Element lb ~ Element vb) =>
   (ca -> CodeGenFunction r cb) ->
   (la -> CodeGenFunction r lb) ->
   (va -> CodeGenFunction r vb)
mapChunks2 f g a = do
   -- element count of a chunk function's argument type, from the type only
   let chunkSize :: C ca => (ca -> cgf) -> TypeNum.Singleton (Size ca) -> Int
       chunkSize _ = TypeNum.integralFromSingleton
   xs <- extractAll a
   case ListHT.viewR $
           ListHT.sliceVertical (chunkSize g TypeNum.singleton) xs of
      -- no pieces at all
      Nothing -> assemble []
      Just (cs,c) -> do
         -- all pieces but the last go through the large-chunk function
         ds <- mapM (extractAll <=< g <=< assemble) cs
         -- the last piece: small-chunk function if it fits one small chunk
         d <-
            if List.length c <= chunkSize f TypeNum.singleton
              then fmap List.concat $
                   mapM (extractAll <=< f <=< assemble) $
                   ListHT.sliceVertical (chunkSize f TypeNum.singleton) c
              else extractAll =<< g =<< assemble c
         -- NOTE(review): a short last piece is padded to a whole chunk by
         -- 'assemble' and 'extractAll' returns the whole chunk, so this
         -- list can be longer than the output vector; confirm that the
         -- surplus elements are harmless here.
         assemble $ List.concat ds ++ d
-- | Binary counterpart of 'mapChunks2'.
-- It combines two vectors chunk-wise, using @f@ for small chunks and @g@
-- for large chunks.
zipChunks2With ::
   (C ca, C cb, C cc, Size ca ~ Size cb, Size cb ~ Size cc,
    C la, C lb, C lc, Size la ~ Size lb, Size lb ~ Size lc,
    C va, C vb, C vc, Size va ~ Size vb, Size vb ~ Size vc,
    Element ca ~ Element va, Element la ~ Element va,
    Element cb ~ Element vb, Element lb ~ Element vb,
    Element cc ~ Element vc, Element lc ~ Element vc) =>
   (ca -> cb -> CodeGenFunction r cc) ->
   (la -> lb -> CodeGenFunction r lc) ->
   (va -> vb -> CodeGenFunction r vc)
zipChunks2With f g a b =
   mapChunks2
      (\ ~(smallA, smallB) -> f smallA smallB)
      (\ ~(largeA, largeB) -> g largeA largeB)
      (a, b)
infixl 1 `withRound`

-- | Round a vector chunk-wise with an extension-provided rounding function.
--
-- @withRound generic roundSmall roundLarge post mode x@ applies the chunk
-- function from @roundSmall@ with the rounding-mode immediate @mode@ to
-- all chunks of @x@ and passes the result to @post@. @generic@ is the
-- fallback given to 'Ext.run'. The large-chunk variant is currently unused.
withRound ::
   (IsPrimitive a, IsPrimitive b,
    TypeNum.Positive k, TypeNum.Positive m, TypeNum.Positive n) =>
   CodeGenFunction r x ->
   Ext.T (Value (Vector m a) -> Value Word32 -> CodeGenFunction r (Value (Vector m b))) ->
   Ext.T (Value (Vector k a) -> Value Word32 -> CodeGenFunction r (Value (Vector k b))) ->
   (Value (Vector n b) -> CodeGenFunction r x) ->
   Word32 ->
   Value (Vector n a) -> CodeGenFunction r x
withRound generic roundSmallExt _roundLargeExt post mode x =
   Ext.run generic $
   Ext.with roundSmallExt $ \roundChunk ->
      mapChunks (\chunk -> roundChunk chunk (valueOf mode)) x >>= post
-- | Dot product over the first @n@ elements of two vectors.
dotProductPartial ::
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsArithmetic a) =>
   Int ->
   Value (Vector n a) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value a)
dotProductPartial n x y = do
   products <- A.mul x y
   sumPartial n products

-- | Sum of the first @n@ elements of a vector, added left to right.
-- Requires @n >= 1@ ('foldl1' on the list of extractions).
sumPartial ::
   (TypeNum.Positive n, LLVM.IsPrimitive a, LLVM.IsArithmetic a) =>
   Int ->
   Value (Vector n a) ->
   CodeGenFunction r (Value a)
sumPartial n x =
   foldl1 (M.liftR2 A.add)
      [LLVM.extractelement x (valueOf k) | k <- take n [0..]]
-- | Split a vector into chunks of type @c@, in index order.
-- The last chunk may be only partly filled; its remaining slots stay
-- undefined.
chop ::
   (C c, C v, Element c ~ Element v) =>
   v -> [CodeGenFunction r c]
chop x = chopCore TypeNum.singleton x

-- | 'chop' with the chunk size passed as a singleton.
chopCore ::
   (C c, C v, Element c ~ Element v) =>
   TypeNum.Singleton (Size c) -> v -> [CodeGenFunction r c]
chopCore m x =
   List.map (\actions -> sequence actions >>= assemble) $
   ListHT.sliceVertical (TypeNum.integralFromSingleton m) (extractList x)
-- | Concatenate a list of chunks into one vector.
-- Chunk @i@ is written starting at index @i * chunkSize@.
-- The chunk size comes from the chunk type: 'sizeInTuple' never evaluates
-- its argument, so @head xs@ is not forced and an empty list yields an
-- undefined vector.
concat ::
   (C c, C v, Element c ~ Element v) =>
   [c] -> CodeGenFunction r v
concat xs =
   foldM
      (\v0 (js,c) ->
         foldM
            (\v (i,j) -> extract (valueOf i) c >>= \x -> insert (valueOf j) x v)
            v0
            (List.zip [0..] js))
      Class.undefTuple
      (List.zip targetIndices xs)
   where
      -- destination indices per chunk: [0..s-1], [s..2s-1], ...
      targetIndices = ListHT.sliceVertical (sizeInTuple (head xs)) [0..]
-- | Elements 0 and 1 of a vector.
getLowestPair ::
   (TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value a, Value a)
getLowestPair x = do
   x0 <- extractelement x (valueOf 0)
   x1 <- extractelement x (valueOf 1)
   return (x0, x1)
-- | Add the upper half of a vector of size @2*m@ onto its lower half,
-- giving a vector of size @m@.
_reduceAddInterleaved ::
   (IsArithmetic a, IsPrimitive a,
    TypeNum.Positive n, TypeNum.Positive m, (m :+: m) ~ n) =>
   TypeNum.Singleton m ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector m a))
_reduceAddInterleaved tm v = do
   let m = TypeNum.integralFromSingleton tm
   lowerHalf <- shuffle v (constCyclicVector $ NonEmptyC.iterate succ 0)
   upperHalf <- shuffle v (constCyclicVector $ NonEmptyC.iterate succ m)
   A.add lowerHalf upperHalf
-- | Sum of all elements via repeated halving ('reduceSumInterleaved' to
-- width 1).
sumGeneric ::
   (IsArithmetic a, IsPrimitive a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value a)
sumGeneric x = do
   reduced <- reduceSumInterleaved 1 x
   extractelement reduced (valueOf 0)

-- | Sums of the lower and the upper half of a vector.
-- The halves are first interleaved (@[0, n/2, 1, n/2+1, ...]@) and then
-- summed by 'sumInterleavedToPair'.
sumToPairGeneric ::
   (Arithmetic a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value a, Value a)
sumToPairGeneric v =
   sumInterleavedToPair =<< shuffleMatchPlain1 v interleaving
   where
      n2 = div (size v) 2
      interleaving =
         maybe (error "vector size must be positive") LLVM.constCyclicVector $
         NonEmpty.fetch $
         List.map (constOf . fromIntegral) $
         concatMap (\k -> [k, k+n2]) [0..]
-- | Repeatedly add the upper half of the active lanes onto their lower
-- half, until only the lowest @m@ lanes carry partial sums.
--
-- With @m = 1@, lane 0 ends up with the sum of all elements. With @m = 2@,
-- lanes 0 and 1 hold the sums of the even- and the odd-indexed elements.
-- Lanes at index @m@ and above have been added to @undef@ lanes and
-- should be treated as unspecified.
--
-- NOTE(review): the width is halved with @div n 2@, and the recursion stops
-- only when it equals @m@ exactly. This appears to assume a vector size of
-- @m@ times a power of two. For other sizes, elements are silently dropped
-- (size 3 with @m = 1@ ignores element 2). Code generation never
-- terminates if the width skips past @m@ (size 1 or 3 with @m = 2@).
reduceSumInterleaved ::
   (IsArithmetic a, IsPrimitive a, TypeNum.Positive n) =>
   Int ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
reduceSumInterleaved m x0 =
   let go ::
          (IsArithmetic a, IsPrimitive a, TypeNum.Positive n) =>
          Int ->
          Value (Vector n a) ->
          CodeGenFunction r (Value (Vector n a))
       -- n: number of lanes that still carry partial sums
       go n x =
          if m==n
            then return x
            else
               let n2 = div n 2
               in  go n2
                      =<< A.add x
                      -- move lanes n2 .. 2*n2-1 down to 0 .. n2-1;
                      -- all other lanes of the shuffle are undefined
                      =<< shuffleMatchPlain1 x
                             (LLVM.constCyclicVector $
                              NonEmpty.appendLeft
                                 (List.map constOf $
                                  take n2 [fromIntegral n2 ..])
                                 (NonEmptyC.repeat undef))
   in  go (size x0) x0
-- | @cumulate a x@: running sums of @x@ started at @a@, excluding the
-- current element, plus the final total @a + sum x@ as first component.
-- '_cumulateSimple' is the element-by-element reference implementation;
-- 'cumulateGeneric' uses the logarithmic 'cumulate1' via 'cumulateFrom1'.
cumulateGeneric, _cumulateSimple ::
   (IsArithmetic a, IsPrimitive a, TypeNum.Positive n) =>
   Value a -> Value (Vector n a) ->
   CodeGenFunction r (Value a, Value (Vector n a))
_cumulateSimple a x =
   foldM step (a, Class.undefTuple) (take (sizeInTuple x) [0..])
   where
      step (acc, y) k = do
         acc' <- A.add acc =<< extract (valueOf k) x
         y' <- insert (valueOf k) acc y
         return (acc', y')

cumulateGeneric = cumulateFrom1 cumulate1
-- | Turn an inclusive prefix-sum function @cum@ into a cumulation with a
-- start value.
--
-- @cumulateFrom1 cum a x@ shifts @a@ into index 0 of @x@. The inclusive
-- prefix sums of the shifted vector are then the running sums of @x@
-- started at @a@, excluding the current element. The first result
-- component is the total: the last running sum plus the element that was
-- shifted out at the top.
cumulateFrom1 ::
   (IsArithmetic a, IsPrimitive a, TypeNum.Positive n) =>
   (Value (Vector n a) ->
    CodeGenFunction r (Value (Vector n a))) ->
   Value a -> Value (Vector n a) ->
   CodeGenFunction r (Value a, Value (Vector n a))
cumulateFrom1 cum a x0 = do
   (b,x1) <- shiftUp a x0
   y <- cum x1
   z <- A.add b =<< extract (valueOf (fromIntegral (sizeInTuple x0) - 1)) y
   return (z,y)
-- | Inclusive prefix sums in logarithmically many steps.
-- Each step adds the vector shifted up by 1, 2, 4, ... positions
-- (zero-filled) onto itself.
cumulate1 ::
   (IsArithmetic a, IsPrimitive a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
cumulate1 x =
   let distances = takeWhile (< sizeInTuple x) $ List.iterate (2*) 1
   in  foldM (\acc k -> shiftUpMultiZero k acc >>= A.add acc) x distances
-- | 'LLVM.inttofp' restricted to vectors of integers converted to vectors
-- of floating point numbers.
inttofp ::
   (TypeNum.Positive n,
    IsPrimitive a, IsPrimitive b,
    LLVM.IsInteger a, IsFloating b) =>
   Value (Vector n a) -> CodeGenFunction r (Value (Vector n b))
inttofp x = LLVM.inttofp x
-- | Signum from a mask-producing comparison @gt@.
-- The mask is subtracted as @(0 > x) - (x > 0)@, which yields -1, 0 or 1
-- provided that @gt@ produces -1 (all bits set) for true lanes.
signumLogical ::
   (TypeNum.Positive n,
    IsPrimitive a, IsPrimitive b, IsArithmetic b) =>
   (Value (Vector n a) ->
    Value (Vector n a) ->
    CodeGenFunction r (Value (Vector n b))) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n b))
signumLogical gt x = do
   isNeg <- gt zero x
   isPos <- gt x zero
   A.sub isNeg isPos
   where
      zero = LLVM.value LLVM.zero
-- | Element-wise signum for signed integer vectors: -1, 0 or 1.
--
-- 'LLVM.sadapt' sign-extends the boolean comparison results, so a true
-- lane becomes -1 (all bits set) and a false lane becomes 0. Therefore
-- @negative - positive@ is -1 for negative lanes, 1 for positive lanes and
-- 0 for zero. This matches 'signumLogical' and 'signumFloatGeneric'.
signumIntGeneric ::
   (TypeNum.Positive n,
    IsPrimitive a, LLVM.IsInteger a,
    LLVM.CmpRet a, LLVM.CmpResult a ~ b,
    IsPrimitive b, LLVM.IsInteger b) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
signumIntGeneric x = do
   let zero = LLVM.value LLVM.zero
   negative <- LLVM.sadapt =<< A.cmp LLVM.CmpLT x zero
   positive <- LLVM.sadapt =<< A.cmp LLVM.CmpGT x zero
   A.sub negative positive
-- | Element-wise signum for unsigned vectors: 0 or 1.
-- The comparison @x > 0@ is zero-extended with 'LLVM.zadapt'.
signumWordGeneric ::
   (TypeNum.Positive n,
    IsPrimitive a, LLVM.IsInteger a,
    LLVM.CmpRet a, LLVM.CmpResult a ~ b,
    IsPrimitive b, LLVM.IsInteger b) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
signumWordGeneric x = do
   isPositive <- A.cmp LLVM.CmpGT x (LLVM.value LLVM.zero)
   LLVM.zadapt isPositive
-- | Element-wise signum for floating point vectors: -1, 0 or 1.
-- 'LLVM.sitofp' converts a true comparison lane to -1, so the result is
-- @(x < 0) - (x > 0)@ in that encoding.
signumFloatGeneric ::
   (TypeNum.Positive n,
    IsPrimitive a, IsArithmetic a, IsFloating a,
    LLVM.CmpRet a, LLVM.CmpResult a ~ b,
    IsPrimitive b, LLVM.IsInteger b) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
signumFloatGeneric x = do
   negMask <- A.cmp LLVM.CmpLT x zero
   negative <- LLVM.sitofp negMask
   posMask <- A.cmp LLVM.CmpGT x zero
   positive <- LLVM.sitofp posMask
   A.sub negative positive
   where
      zero = LLVM.value LLVM.zero
-- | @x - truncate x@: fractional part carrying the sign of @x@.
signedFraction ::
   (IsFloating a, IsConst a, Real a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
signedFraction x = do
   whole <- truncate x
   A.sub x whole

-- | Generic 'floor' via 'floorLogical' with LLVM's boolean comparison.
floorGeneric ::
   (IsFloating a, IsConst a, Real a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
floorGeneric x = floorLogical A.fcmp x

-- | Generic 'fraction' via 'fractionLogical' with LLVM's boolean comparison.
fractionGeneric ::
   (IsFloating a, IsConst a, Real a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
fractionGeneric x = fractionLogical A.fcmp x
-- | Element types with an integer 'Mask' type of the same bit size.
-- 'selectLogical' uses the mask to blend bitwise.
class
   (LLVM.IsSized a, LLVM.IsSized (Mask a),
    LLVM.SizeOf a ~ LLVM.SizeOf (Mask a),
    LLVM.IsPrimitive a, LLVM.IsPrimitive (Mask a),
    LLVM.IsInteger (Mask a)) =>
      Maskable a where
   type Mask a :: *

instance Maskable Int8 where
   type Mask Int8 = Int8
instance Maskable Int16 where
   type Mask Int16 = Int16
instance Maskable Int32 where
   type Mask Int32 = Int32
instance Maskable Int64 where
   type Mask Int64 = Int64

instance Maskable Word8 where
   type Mask Word8 = Int8
instance Maskable Word16 where
   type Mask Word16 = Int16
instance Maskable Word32 where
   type Mask Word32 = Int32
instance Maskable Word64 where
   type Mask Word64 = Int64

instance Maskable Float where
   type Mask Float = Int32
instance Maskable Double where
   type Mask Double = Int64
-- | Sign-extend a boolean vector to the 'Mask' type that belongs to the
-- element type of the first argument (used only for its type).
makeMask ::
   (Maskable a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   Value (Vector n Bool) ->
   CodeGenFunction r (Value (Vector n (Mask a)))
makeMask _typeWitness bools = LLVM.sadapt bools
-- | Element-wise minimum and maximum, via comparison and bitwise selection.
minGeneric, maxGeneric ::
   (IsConst a, Real a, Maskable a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
minGeneric x y =
   A.cmp LLVM.CmpLT x y >>= makeMask x >>= \mask -> selectLogical mask x y
maxGeneric x y =
   A.cmp LLVM.CmpGT x y >>= makeMask x >>= \mask -> selectLogical mask x y

-- | Element-wise absolute value as @max x (-x)@.
absGeneric ::
   (IsConst a, Real a, Maskable a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
absGeneric x = LLVM.neg x >>= maxGeneric x
-- | Absolute value that uses extension-provided chunk functions when
-- available. Otherwise it falls back to 'absGeneric'.
absAuto ::
   (TypeNum.Positive n, TypeNum.Positive m, TypeNum.Positive k,
    IsConst a, Real a, Maskable a) =>
   Ext.T (Value (Vector m a) -> CodeGenFunction r (Value (Vector m a))) ->
   Ext.T (Value (Vector k a) -> CodeGenFunction r (Value (Vector k a))) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
absAuto byChunk byLargeChunk x =
   absGeneric x
   `Ext.run`
   (Ext.with byChunk $ \absChunk -> mapChunks absChunk x)
   `Ext.run`
   (Ext.with2 byChunk byLargeChunk $ \absSmall absLarge ->
      mapChunks2 absSmall absLarge x)
-- | Element-wise selection: @x@ where @b@ is true, @y@ otherwise.
select ::
   (LLVM.IsFirstClass a, IsPrimitive a, TypeNum.Positive n,
    LLVM.CmpRet a, LLVM.CmpResult a ~ Bool) =>
   Value (Vector n Bool) ->
   Value (Vector n a) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
select b x y =
   map (\ ~(bi, xi, yi) -> LLVM.select bi xi yi) (b, x, y)
-- | 'floor' via 'select': keep @truncate x@ where it is not above @x@,
-- otherwise subtract one.
_floorSelect ::
   (Num a, IsFloating a, IsConst a, Real a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
_floorSelect x = do
   xr <- truncate x
   notAbove <- A.fcmp LLVM.FPOLE xr x
   one <- replicate (valueOf 1)
   xrMinusOne <- A.sub xr one
   select notAbove xr xrMinusOne

-- | 'fraction' via 'select': keep the signed fraction where it is
-- non-negative, otherwise add one.
_fractionSelect ::
   (Num a, IsFloating a, IsConst a, Real a, TypeNum.Positive n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
_fractionSelect x = do
   xf <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE xf (value LLVM.zero)
   one <- replicate (valueOf 1)
   xfPlusOne <- A.add xf one
   select nonNegative xf xfPlusOne
-- | Bitwise selection with an integer mask of the same bit size:
-- @(mask AND x) OR (NOT mask AND y)@, computed on the bit patterns.
-- A lane gives exactly @x@ or @y@ only when its mask is all ones or all zeros.
selectLogical ::
   (LLVM.IsFirstClass a, IsPrimitive a,
    LLVM.IsInteger i, IsPrimitive i,
    LLVM.IsSized a, LLVM.IsSized i,
    LLVM.SizeOf a ~ LLVM.SizeOf i,
    TypeNum.Positive n) =>
   Value (Vector n i) ->
   Value (Vector n a) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
selectLogical mask x y = do
   invMask <- LLVM.inv mask
   xBits <- LLVM.bitcastElements x
   xKept <- A.and mask xBits
   yBits <- LLVM.bitcastElements y
   yKept <- A.and invMask yBits
   merged <- A.or xKept yKept
   LLVM.bitcastElements merged
-- | 'floor' from 'truncate' and a mask-producing comparison @cmp@.
-- Where truncation lies above @x@ (negative non-integers), the mask,
-- converted with 'LLVM.sitofp', contributes -1. This assumes @cmp@
-- yields -1 (all bits set) for true lanes.
floorLogical ::
   (IsFloating a, IsConst a, Real a,
    IsPrimitive i, LLVM.IsInteger i, TypeNum.Positive n) =>
   (LLVM.FPPredicate ->
    Value (Vector n a) ->
    Value (Vector n a) ->
    CodeGenFunction r (Value (Vector n i))) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
floorLogical cmp x = do
   xr <- truncate x
   tooLarge <- cmp LLVM.FPOGT xr x
   correction <- LLVM.sitofp tooLarge
   A.add xr correction

-- | 'fraction' (@x - floor x@) from 'signedFraction' and a mask-producing
-- comparison. Negative fractions get one added, by subtracting the -1 mask.
fractionLogical ::
   (IsFloating a, IsConst a, Real a,
    IsPrimitive i, LLVM.IsInteger i, TypeNum.Positive n) =>
   (LLVM.FPPredicate ->
    Value (Vector n a) ->
    Value (Vector n a) ->
    CodeGenFunction r (Value (Vector n i))) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
fractionLogical cmp x = do
   xf <- signedFraction x
   isNegative <- cmp LLVM.FPOLT xf (value LLVM.zero)
   correction <- LLVM.sitofp isNegative
   A.sub xf correction
-- | Binary element-wise operation that prefers extension-provided chunk
-- functions (small, or small and large) and falls back to @f@.
order ::
   (TypeNum.Positive n, TypeNum.Positive m, TypeNum.Positive k,
    LLVM.IsFirstClass a, IsPrimitive a) =>
   (Value (Vector n a) -> Value (Vector n a) -> CodeGenFunction r (Value (Vector n a))) ->
   Ext.T (Value (Vector m a) -> Value (Vector m a) -> CodeGenFunction r (Value (Vector m a))) ->
   Ext.T (Value (Vector k a) -> Value (Vector k a) -> CodeGenFunction r (Value (Vector k a))) ->
   (Value (Vector n a) -> Value (Vector n a) -> CodeGenFunction r (Value (Vector n a)))
order f byChunk byLargeChunk x y =
   f x y
   `Ext.run`
   (Ext.with byChunk $ \opChunk -> zipChunksWith opChunk x y)
   `Ext.run`
   (Ext.with2 byChunk byLargeChunk $ \opSmall opLarge ->
      zipChunks2With opSmall opLarge x y)
-- | Element types with vector reductions and multiplication.
-- Every method has a generic default. Instances may override methods with
-- extension-specific code.
class (IsArithmetic a, IsPrimitive a) => Arithmetic a where
   -- | Sum of all elements.
   sum ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value a)
   sum x = sumGeneric x

   -- | Sums of the lower and of the upper half of the vector.
   sumToPair ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value a, Value a)
   sumToPair x = sumToPairGeneric x

   -- | Sums of the even-indexed and of the odd-indexed elements.
   sumInterleavedToPair ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value a, Value a)
   sumInterleavedToPair v =
      reduceSumInterleaved 2 v >>= getLowestPair

   -- | @cumulate a x@: running sums of @x@ started at @a@, excluding the
   -- current element, together with the total @a + sum x@.
   cumulate ::
      (TypeNum.Positive n) =>
      Value a -> Value (Vector n a) ->
      CodeGenFunction r (Value a, Value (Vector n a))
   cumulate a x = cumulateGeneric a x

   -- | Sum of the element-wise products.
   dotProduct ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      Value (Vector n a) ->
      CodeGenFunction r (Value a)
   dotProduct x = dotProductPartial (size x) x

   -- | Element-wise multiplication.
   mul ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))
   mul x y = A.mul x y
-- Single-precision floats: reductions use the X86 horizontal-add
-- intrinsic (haddps) and the dot product uses dpps via 'Ext.with'; the
-- generic code is the fallback given to 'Ext.run' / 'Ext.runWhen'.
instance Arithmetic Float where
   -- Add all chunks lane-wise, then reduce the remaining chunk with two
   -- horizontal additions. Guarded by @size x >= 4@.
   sum x =
      Ext.runWhen (size x >= 4) (sumGeneric x) $
      Ext.with X86A.haddps $ \haddp ->
         do chunkSum <-
               foldl1 (M.liftR2 A.add) $ chop x
            y <- haddp chunkSum (value undef)
            z <- haddp y (value undef)
            extractelement z (valueOf 0)
   -- Combine chunks pairwise with horizontal additions until one chunk
   -- remains; its lowest two lanes form the result. The number of chunks
   -- must be a power of two (see the error messages).
   -- NOTE(review): the fallback for @size x < 4@ returns elements 0 and 1
   -- unreduced; this equals 'sumToPairGeneric' only for size 2.
   sumToPair x =
      Ext.runWhen (size x >= 4) (getLowestPair x) $
      Ext.with X86A.haddps $ \haddp ->
         let
            reduce [] = []
            reduce [_] = error "vector must have size power of two"
            reduce (x0:x1:xs) =
               M.liftR2 haddp x0 x1 : reduce xs
            go [] = error "vector must not be empty"
            go [c] =
               getLowestPair
                  =<< flip haddp (value undef)
                  =<< c
            go cs = go (reduce cs)
         in go $ chop x
   -- Per chunk, dpps with immediate 0xF1 yields a partial dot product in
   -- lane 0 (presumably: all four lanes multiplied, sum stored in lane 0);
   -- the partial results are then added up.
   dotProduct x y =
      Ext.run (sum =<< A.mul x y) $
      Ext.with X86A.dpps $ \dpp ->
         foldl1 (M.liftR2 A.add) $
         List.zipWith
            (\mx my -> do
               cx <- mx
               cy <- my
               flip extractelement (valueOf 0)
                  =<< dpp cx cy (valueOf 0xF1))
            (chop x)
            (chop y)
-- These element types use the generic default implementations of all
-- 'Arithmetic' methods.
instance Arithmetic Double
instance Arithmetic Int8
instance Arithmetic Int16
instance Arithmetic Int32
instance Arithmetic Int64
instance Arithmetic Word8
instance Arithmetic Word16
instance Arithmetic Word64
-- 32-bit unsigned multiplication. The pmuludq variant multiplies the even
-- lanes and the odd lanes (moved to even positions) into 64-bit products
-- and recombines their low 32-bit halves. The final alternative, guarded
-- by the SSE4.1 check, uses plain 'A.mul'. TODO confirm how 'Ext.run'
-- prioritizes the alternatives.
instance Arithmetic Word32 where
   mul x y =
      A.mul x y
      `Ext.run`
      (Ext.with X86A.pmuludq128 $ \pmul ->
         zipChunksWith
            (\cx cy -> do
               -- lanes 0 and 2 (presumably the only lanes pmul reads)
               evenX <- shuffleMatchPlain1 cx
                           (constVector4 (constOf 0, undef, constOf 2, undef))
               evenY <- shuffleMatchPlain1 cy
                           (constVector4 (constOf 0, undef, constOf 2, undef))
               evenZ64 <- pmul evenX evenY
               evenZ <- LLVM.bitcast evenZ64
               -- lanes 1 and 3 moved to positions 0 and 2
               oddX <- shuffleMatchPlain1 cx
                           (constVector4 (constOf 1, undef, constOf 3, undef))
               oddY <- shuffleMatchPlain1 cy
                           (constVector4 (constOf 1, undef, constOf 3, undef))
               oddZ64 <- pmul oddX oddY
               oddZ <- LLVM.bitcast oddZ64
               -- take lanes 0 and 2 of each bitcast product vector, i.e.
               -- the low 32-bit halves assuming little-endian lane order
               shuffleMatchPlain2 evenZ oddZ
                  (constVector4 (constOf 0, constOf 4, constOf 2, constOf 6)))
            x y)
      `Ext.run`
      Ext.wrap X86C.sse41 (A.mul x y)
-- | Element-wise full 64-bit products of 32-bit unsigned vectors.
--
-- Generic version: zero-extend both operands, then multiply. The pmuludq
-- variant multiplies even lanes and odd lanes (moved to even positions)
-- separately and interleaves the four 64-bit products per chunk back into
-- index order.
umul32to64 ::
   (TypeNum.Positive n) =>
   Value (Vector n Word32) ->
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word64))
umul32to64 x y =
   (do x64 <- map LLVM.zext x
       y64 <- map LLVM.zext y
       A.mul x64 y64)
   `Ext.run`
   (Ext.with X86A.pmuludq128 $ \pmul ->
      zipChunksWith
         (\cx cy -> do
            evenX <- shuffleMatchPlain1 cx
                        (constVector4 (constOf 0, undef, constOf 2, undef))
            evenY <- shuffleMatchPlain1 cy
                        (constVector4 (constOf 0, undef, constOf 2, undef))
            -- products of lanes 0 and 2
            evenZ <- pmul evenX evenY
            oddX <- shuffleMatchPlain1 cx
                        (constVector4 (constOf 1, undef, constOf 3, undef))
            oddY <- shuffleMatchPlain1 cy
                        (constVector4 (constOf 1, undef, constOf 3, undef))
            -- products of lanes 1 and 3
            oddZ <- pmul oddX oddY
            -- reorder to products of lanes 0, 1, 2, 3; the annotation fixes
            -- the chunk result type, and its free r is implicitly quantified
            assemble =<< (sequence $
               extract (valueOf 0) evenZ :
               extract (valueOf 0) oddZ :
               extract (valueOf 1) evenZ :
               extract (valueOf 1) oddZ :
               []) :: CodeGenFunction r (Value (Vector D4 Word64)))
         x y)
-- | Constant vector with exactly four elements.
constVector4 ::
   (IsConst a) =>
   (ConstValue a, ConstValue a, ConstValue a, ConstValue a) ->
   ConstValue (Vector D4 a)
constVector4 (a0,a1,a2,a3) =
   LLVM.constVector (a0 !: a1 !: a2 !: a3 !: Empty.Cons)
-- | Element types with ordering, sign and rounding operations on vectors.
class
   (Arithmetic a, LLVM.CmpRet a, LLVM.CmpResult a ~ Bool, IsConst a) =>
      Real a where
   -- | Element-wise minimum and maximum.
   min, max ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))

   -- | Element-wise absolute value.
   abs ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))

   -- | Element-wise sign.
   signum ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))

   -- | Rounding towards zero, rounding down, and @x - floor x@.
   truncate, floor, fraction ::
      (TypeNum.Positive n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))
-- Single-precision floats: X86 intrinsics where available, generic code
-- otherwise. The rounding-mode immediates given to 'withRound' are 3 for
-- truncation and 1 for rounding down (presumably the SSE4.1 roundps
-- encoding: 1 = towards negative infinity, 3 = towards zero).
instance Real Float where
   min = order minGeneric X86A.minps X86A.minps256
   max = order maxGeneric X86A.maxps X86A.maxps256
   abs x = Ext.run (absGeneric x) (Ext.with X86.absps ($x))
   -- cmpps yields integer masks, converted back to floats by 'inttofp'
   signum x =
      signumFloatGeneric x
      `Ext.run`
      (Ext.with X86.cmpps $ \cmp ->
         inttofp =<< mapChunks (signumLogical (cmp LLVM.FPOGT)) x)
   -- NOTE(review): the generic fallback converts through Int32, so it is
   -- only exact for values representable as Int32.
   truncate x =
      withRound
         ((LLVM.inttofp .
           (id :: Value (Vector n Int32) -> Value (Vector n Int32))
           <=< LLVM.fptoint) x)
         X86A.roundps X86A.roundps256 return 3 x
   floor x =
      withRound
         (floorGeneric x
          `Ext.run`
          (Ext.with X86.cmpps $ \cmp ->
             mapChunks (floorLogical cmp) x)
         )
         X86A.roundps X86A.roundps256 return 1 x
   -- rounded down, then post-processed to x - floor x
   fraction x =
      withRound
         (fractionGeneric x
          `Ext.run`
          (Ext.with X86.cmpps $ \cmp ->
             mapChunks (fractionLogical cmp) x)
         )
         X86A.roundps X86A.roundps256 (A.sub x) 1 x
-- Double-precision floats: X86 intrinsics where available, generic code
-- otherwise. Rounding immediates as for Float: 3 truncates, 1 rounds down.
instance Real Double where
   min = order minGeneric X86A.minpd X86A.minpd256
   max = order maxGeneric X86A.maxpd X86A.maxpd256
   abs x = Ext.run (absGeneric x) (Ext.with X86.abspd ($x))
   signum x =
      signumFloatGeneric x
      `Ext.run`
      (Ext.with2 X86.cmppd X86A.cvtdq2pd $ \cmp tofp ->
         mapChunks (signumLogical
            (\a b -> do
               -- NOTE(review): the 64-bit comparison masks are bitcast to
               -- 32-bit lanes, and lanes 0 and 2 (one half of each mask)
               -- are converted with cvtdq2pd; this presumably relies on each
               -- mask being all ones or all zeros.
               c <- LLVM.bitcast =<< cmp LLVM.FPOGT a b
               c0 <- extract (valueOf 0) (c :: Value (Vector D4 Int32))
               c1 <- extract (valueOf 2) c
               tofp =<< assemble [c0,c1])) x)
   -- NOTE(review): the generic fallback converts through Int64, so it is
   -- only exact for values representable as Int64.
   truncate x =
      withRound
         ((LLVM.inttofp .
           (id :: Value (Vector n Int64) -> Value (Vector n Int64))
           <=< LLVM.fptoint) x)
         X86A.roundpd X86A.roundpd256 return 3 x
   floor x =
      withRound
         (floorGeneric x
          `Ext.run`
          (Ext.with X86.cmppd $ \cmp ->
             mapChunks (floorLogical cmp) x))
         X86A.roundpd X86A.roundpd256 return 1 x
   -- rounded down, then post-processed to x - floor x
   fraction x =
      withRound
         (fractionGeneric x
          `Ext.run`
          (Ext.with X86.cmppd $ \cmp ->
             mapChunks (fractionLogical cmp) x))
         X86A.roundpd X86A.roundpd256 (A.sub x) 1 x
-- Signed integers. Rounding is the identity and the fraction is zero.
-- min/max/abs use the pmins*/pmaxs*/pabs* intrinsics where available;
-- Int64 uses the generic versions.
instance Real Int8 where
   min = order minGeneric X86A.pminsb128 X86A.pminsb256
   max = order maxGeneric X86A.pmaxsb128 X86A.pmaxsb256
   abs = absAuto X86A.pabsb128 X86A.pabsb256
   signum = signumIntGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)

instance Real Int16 where
   min = order minGeneric X86A.pminsw128 X86A.pminsw256
   max = order maxGeneric X86A.pmaxsw128 X86A.pmaxsw256
   abs = absAuto X86A.pabsw128 X86A.pabsw256
   signum = signumIntGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)

instance Real Int32 where
   min = order minGeneric X86A.pminsd128 X86A.pminsd256
   max = order maxGeneric X86A.pmaxsd128 X86A.pmaxsd256
   abs = absAuto X86A.pabsd128 X86A.pabsd256
   signum = signumIntGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)

instance Real Int64 where
   min = minGeneric
   max = maxGeneric
   abs = absGeneric
   signum = signumIntGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)
-- Unsigned integers. abs and rounding are the identity and the fraction is
-- zero. min/max use the pminu*/pmaxu* intrinsics where available; Word64
-- uses the generic versions.
instance Real Word8 where
   min = order minGeneric X86A.pminub128 X86A.pminub256
   max = order maxGeneric X86A.pmaxub128 X86A.pmaxub256
   abs x = return x
   signum = signumWordGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)

instance Real Word16 where
   min = order minGeneric X86A.pminuw128 X86A.pminuw256
   max = order maxGeneric X86A.pmaxuw128 X86A.pmaxuw256
   abs x = return x
   signum = signumWordGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)

instance Real Word32 where
   min = order minGeneric X86A.pminud128 X86A.pminud256
   max = order maxGeneric X86A.pmaxud128 X86A.pmaxud256
   abs x = return x
   signum = signumWordGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)

instance Real Word64 where
   min = minGeneric
   max = maxGeneric
   abs x = return x
   signum = signumWordGeneric
   truncate x = return x
   floor x = return x
   fraction _ = return (value LLVM.zero)