module LLVM.Extra.ScalarOrVector (
Fraction (truncate, fraction),
signedFraction,
addToPhase,
incPhase,
truncateToInt,
floorToInt,
ceilingToInt,
roundToIntFast,
splitFractionToInt,
Scalar,
Replicate (replicate, replicateConst),
replicateOf,
Real (min, max, abs, signum),
PseudoModule (scale, scaleConst),
IntegerConstant(constFromInteger),
RationalConstant(constFromRational),
TranscendentalConstant(constPi),
) where
import LLVM.Extra.Vector (Element, Size, )
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal (D1, )
import qualified LLVM.Core as LLVM
import LLVM.Core
(Value, ConstValue, valueOf, constOf,
CmpRet, CmpResult, NumberOfElements,
Vector, WordN(WordN), IntN(IntN), FP128,
IsConst, IsInteger, IsFloating,
CodeGenFunction, )
import Control.Monad.HT ((<=<), )
import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int (Int8, Int16, Int32, Int64, )
import Prelude hiding (Real, replicate, min, max, abs, truncate, floor, round, )
-- | Floating point types whose values can be split into an integral and a
-- fractional part.
class (Real a, IsFloating a) => Fraction a where
   -- | Integral part of the argument.
   -- The scalar instances in this module round toward zero.
   truncate :: Value a -> CodeGenFunction r (Value a)
   -- | Fractional part of the argument.
   -- The scalar instances in this module compute it such that a negative
   -- intermediate result is shifted up by one.
   fraction :: Value a -> CodeGenFunction r (Value a)
-- | 'Fraction' for scalar @Float@ values.
-- Each method pairs a generic implementation with an X86 specific one
-- via 'mapAuto'.
instance Fraction Float where
   truncate = mapAuto viaInteger viaRoundss
      where
         -- Generic path: convert to Int32 and back again.
         viaInteger =
            LLVM.inttofp . flip asTypeOf (undefined :: Value Int32)
               <=< LLVM.fptoint
         -- X86 path: roundss with immediate 3
         -- (round toward zero according to the Intel manual).
         viaRoundss =
            Ext.with X86.roundss $ \round x -> round x (valueOf 3)

   fraction = mapAuto generic viaRoundss
      where
         -- Generic path: 'fractionGen', or 'fractionLogical' driven by cmpss
         -- when that extension can be used.
         generic x =
            Ext.run (fractionGen x) $
            Ext.with X86.cmpss $ \cmp ->
               fractionLogical
                  (\modus a b -> runScalar (uncurry (cmp modus)) (a, b))
                  x
         -- X86 path: subtract the result of roundss with immediate 1
         -- (round toward negative infinity according to the Intel manual).
         viaRoundss =
            Ext.with X86.roundss $ \round x -> do
               floored <- round x (valueOf 1)
               A.sub x floored
-- | 'Fraction' for scalar @Double@ values.
-- Each method pairs a generic implementation with an X86 specific one
-- via 'mapAuto'.
instance Fraction Double where
   -- The generic path converts to an integer and back again.
   -- It uses Int64 rather than Int32: a Double still has fractional bits
   -- far beyond 2^31, and a float-to-int conversion that overflows the
   -- target type has no defined result in LLVM.
   -- The X86 path is roundsd with immediate 3
   -- (round toward zero according to the Intel manual).
   truncate =
      mapAuto
         (LLVM.inttofp . flip asTypeOf (undefined :: Value Int64) <=< LLVM.fptoint)
         (Ext.with X86.roundsd $ \round x -> round x (valueOf 3))

   -- Generic path: 'fractionGen', or 'fractionLogical' driven by cmpsd
   -- when that extension can be used.
   -- X86 path: subtract the result of roundsd with immediate 1
   -- (round toward negative infinity according to the Intel manual).
   fraction =
      (\x ->
         fractionGen x
         `Ext.run`
         (Ext.with X86.cmpsd $ \cmp ->
            fractionLogical (\modus -> curry (runScalar (uncurry (cmp modus)))) x))
      `mapAuto`
      (Ext.with X86.roundsd $ \round x ->
         A.sub x =<< round x (valueOf 1))
-- | Vectors delegate both methods to "LLVM.Extra.Vector".
instance (TypeNum.Positive n, Vector.Real a, IsFloating a, IsConst a) =>
      Fraction (Vector n a) where
   truncate v = Vector.truncate v
   fraction v = Vector.fraction v
-- | Difference between the argument and its 'truncate'd value.
-- No sign correction is applied to the result.
signedFraction ::
   (Fraction a) =>
   Value a -> CodeGenFunction r (Value a)
signedFraction x = do
   whole <- truncate x
   A.sub x whole
-- | Fractional part built from 'signedFraction':
-- the signed fraction is kept if it compares @>= 0@ (ordered),
-- otherwise the signed fraction plus one is selected.
fractionGen ::
   (IntegerConstant v, Fraction v, CmpRet v) =>
   Value v -> CodeGenFunction r (Value v)
fractionGen x = do
   xf <- signedFraction x
   nonNegative <- A.fcmp LLVM.FPOGE xf zero
   shifted <- A.add xf (LLVM.value (constFromInteger 1))
   LLVM.select nonNegative xf shifted
-- | Variant of 'fractionGen' without a select instruction.
-- The caller supplies the comparison; its integer result for
-- @signedFraction x < 0@ is converted to floating point and subtracted
-- from the signed fraction.
-- NOTE(review): this presumably relies on the comparison returning an
-- all-ones mask (i.e. -1) for "true", so that the subtraction adds one --
-- confirm against the comparison that is passed in.
fractionLogical ::
   (Fraction a, LLVM.IsScalarOrVector a, NumberOfElements a ~ D1,
    IsInteger b, LLVM.IsScalarOrVector b, NumberOfElements b ~ D1) =>
   (LLVM.FPPredicate ->
    Value a -> Value a -> CodeGenFunction r (Value b)) ->
   Value a -> CodeGenFunction r (Value a)
fractionLogical cmp x = do
   xf <- signedFraction x
   negativeMask <- cmp LLVM.FPOLT xf zero
   correction <- LLVM.inttofp negativeMask
   A.sub xf correction
-- | Add the two arguments and reduce the sum with 'fraction'.
addToPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
addToPhase d p = do
   advanced <- A.add d p
   fraction advanced
-- | Add the two arguments and reduce the sum with 'signedFraction'.
-- In contrast to 'addToPhase' no sign correction takes place.
incPhase ::
   (Fraction a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
incPhase d p = do
   advanced <- A.add d p
   signedFraction advanced
-- | Convert a floating point value to an integer value
-- with the same number of elements using 'LLVM.fptoint'.
truncateToInt ::
   (IsFloating a, IsInteger i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
truncateToInt x = LLVM.fptoint x
-- | Round to an integer by moving the argument half a unit away from zero
-- and truncating the result.
-- Positive arguments get @+0.5@, all other arguments get @-0.5@,
-- so that ties are rounded away from zero.
-- No overflow handling is performed.
roundToIntFast ::
   (IsFloating a, RationalConstant a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
roundToIntFast x = do
   pos <- A.cmp LLVM.CmpGT x zero
   -- The offset must carry the sign of the argument;
   -- using +0.5 in both branches would round negative numbers incorrectly.
   offset <- LLVM.select pos (ratio 0.5) (ratio (-0.5))
   truncateToInt =<< A.add x offset
-- | Largest integer not above the argument.
-- The argument is truncated first; if the argument is smaller than the
-- truncated value, one is subtracted.
floorToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
floorToInt x = do
   truncated <- truncateToInt x
   truncatedFloat <- LLVM.inttofp truncated
   below <- A.cmp LLVM.CmpLT x truncatedFloat
   correction <- LLVM.select below (int 1) (int 0)
   A.sub truncated correction
-- | Split the argument into its 'floorToInt' value and the remaining
-- fractional part (argument minus floor).
splitFractionToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i, Value a)
splitFractionToInt x = do
   whole <- floorToInt x
   wholeFloat <- LLVM.inttofp whole
   rest <- A.sub x wholeFloat
   return (whole, rest)
-- | Smallest integer not below the argument.
-- The argument is truncated first; if the argument is greater than the
-- truncated value, one is added.
ceilingToInt ::
   (IsFloating a, CmpRet a,
    IsInteger i, IntegerConstant i, CmpRet i,
    CmpResult a ~ CmpResult i,
    NumberOfElements a ~ NumberOfElements i) =>
   Value a -> CodeGenFunction r (Value i)
ceilingToInt x = do
   truncated <- truncateToInt x
   truncatedFloat <- LLVM.inttofp truncated
   above <- A.cmp LLVM.CmpGT x truncatedFloat
   correction <- LLVM.select above (int 1) (int 0)
   A.add truncated correction
-- | The LLVM zero constant of any type, as a 'Value'.
zero :: (LLVM.IsType a) => Value a
zero = LLVM.value LLVM.zero
-- | Turn a Haskell 'Integer' into a constant 'Value'
-- by means of 'constFromInteger'.
int :: (IntegerConstant a) => Integer -> Value a
int n = LLVM.value (constFromInteger n)
-- | Turn a Haskell 'Rational' into a constant 'Value'
-- by means of 'constFromRational'.
ratio :: (RationalConstant a) => Rational -> Value a
ratio q = LLVM.value (constFromRational q)
-- | Maps a scalar or vector type to its element type.
-- Every scalar type listed here maps to itself,
-- a @Vector n a@ maps to its element type @a@.
type family Scalar vector :: *
-- Floating point and boolean scalars.
type instance Scalar Float = Float
type instance Scalar Double = Double
type instance Scalar FP128 = FP128
type instance Scalar Bool = Bool
-- Signed integer scalars.
type instance Scalar Int8 = Int8
type instance Scalar Int16 = Int16
type instance Scalar Int32 = Int32
type instance Scalar Int64 = Int64
-- Unsigned integer scalars.
type instance Scalar Word8 = Word8
type instance Scalar Word16 = Word16
type instance Scalar Word32 = Word32
type instance Scalar Word64 = Word64
-- Vectors: the scalar is the element type.
type instance Scalar (Vector n a) = a
-- | Types that can be constructed from a single value of their
-- 'Scalar' type.
class Replicate vector where
   -- | Build a run-time value from a run-time scalar.
   replicate :: Value (Scalar vector) -> CodeGenFunction r (Value vector)
   -- | Build a constant from a constant scalar.
   replicateConst :: ConstValue (Scalar vector) -> ConstValue vector
-- For scalar types the scalar is the type itself ('Scalar' maps each of
-- them to itself), thus replication returns the argument unchanged.
instance Replicate Float where replicate = return; replicateConst = id;
instance Replicate Double where replicate = return; replicateConst = id;
instance Replicate FP128 where replicate = return; replicateConst = id;
instance Replicate Bool where replicate = return; replicateConst = id;
instance Replicate Int8 where replicate = return; replicateConst = id;
instance Replicate Int16 where replicate = return; replicateConst = id;
instance Replicate Int32 where replicate = return; replicateConst = id;
instance Replicate Int64 where replicate = return; replicateConst = id;
instance Replicate Word8 where replicate = return; replicateConst = id;
instance Replicate Word16 where replicate = return; replicateConst = id;
instance Replicate Word32 where replicate = return; replicateConst = id;
instance Replicate Word64 where replicate = return; replicateConst = id;
-- | Fill every element of a vector with the same scalar.
instance (TypeNum.Positive n, LLVM.IsPrimitive a) => Replicate (Vector n a) where
   -- Put the scalar into a one-element vector and
   -- shuffle it with an all-zero mask.
   replicate x =
      singleton x >>= \v ->
         LLVM.shufflevector v (LLVM.value LLVM.undef) LLVM.zero
   -- A cyclic vector constant made from a one-element list.
   replicateConst x = LLVM.constCyclicVector (NonEmpty.Cons x [])
-- | Vector of size one containing the given value,
-- built by inserting the value at index 0 of an undefined vector.
singleton ::
   (LLVM.IsPrimitive a) =>
   Value a -> CodeGenFunction r (Value (Vector D1 a))
singleton x =
   let blank = LLVM.value LLVM.undef
       firstIndex = valueOf 0
   in  LLVM.insertelement blank x firstIndex
-- | Replicated constant 'Value' made from a plain Haskell scalar:
-- 'LLVM.constOf', then 'replicateConst', then 'LLVM.value'.
replicateOf ::
   (IsConst (Scalar v), Replicate v) =>
   Scalar v -> Value v
replicateOf x =
   LLVM.value (replicateConst (LLVM.constOf x))
-- | Arithmetic types that support minimum, maximum, absolute value
-- and sign.
class (LLVM.IsArithmetic a) => Real a where
   -- | Minimum of two values.
   min :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | Maximum of two values.
   max :: Value a -> Value a -> CodeGenFunction r (Value a)
   -- | Absolute value.
   abs :: Value a -> CodeGenFunction r (Value a)
   -- | Sign of the value.
   signum :: Value a -> CodeGenFunction r (Value a)
-- | @Float@: generic arithmetic paired with the X86 single precision
-- scalar operations via 'zipAutoWith' and 'mapAuto'.
instance Real Float where
   min x y = zipAutoWith A.min X86.minss x y
   max x y = zipAutoWith A.max X86.maxss x y
   abs x = mapAuto A.abs X86.absss x
   signum x = A.signum x
-- | @Double@: generic arithmetic paired with the X86 double precision
-- scalar operations via 'zipAutoWith' and 'mapAuto'.
instance Real Double where
   min x y = zipAutoWith A.min X86.minsd x y
   max x y = zipAutoWith A.max X86.maxsd x y
   abs x = mapAuto A.abs X86.abssd x
   signum x = A.signum x
-- | @FP128@: only the generic arithmetic is used.
instance Real FP128 where
   min = A.min
   max = A.max
   abs = A.abs
   -- The constants -1 and 1 are obtained by converting Int8 values,
   -- and passed to 'A.signumGen' as its first two arguments.
   -- The first one must really be -1,
   -- otherwise negative arguments would get the sign +1.
   signum x = do
      minusOne <- LLVM.inttofp $ LLVM.valueOf (-1 :: Int8)
      one <- LLVM.inttofp $ LLVM.valueOf (1 :: Int8)
      A.signumGen minusOne one x
infixl 1 `mapAuto`

-- | Run a vector operation on a scalar:
-- the scalar is inserted at index 0 of an undefined vector (tuple),
-- the operation is applied and index 0 of the result is extracted.
runScalar ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
runScalar op a = do
   packed <- Vector.insert (valueOf 0) a Class.undefTuple
   result <- op packed
   Vector.extract (valueOf 0) result
-- | Combine a scalar implementation @f@ with a vector operation @g@
-- that is wrapped as an extension.
-- Both are handed to 'Ext.run'; the vector operation is applied to the
-- scalar argument by means of 'runScalar'.
-- NOTE(review): presumably 'Ext.run' picks the extension when the target
-- supports it and the scalar implementation otherwise -- confirm against
-- "LLVM.Extra.Extension".
mapAuto ::
   (Vector.C v, Vector.C w, Size v ~ Size w) =>
   (Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (v -> CodeGenFunction r w) ->
   (Element v -> CodeGenFunction r (Element w))
mapAuto f g a =
   Ext.run (f a) (Ext.with g (\op -> runScalar op a))
-- | Two-argument counterpart of 'mapAuto':
-- both the scalar function and the wrapped vector operation are uncurried
-- and 'mapAuto' is run on the pair of arguments.
zipAutoWith ::
   (Vector.C u, Vector.C v, Vector.C w,
    Size u ~ Size v, Size v ~ Size w) =>
   (Element u -> Element v -> CodeGenFunction r (Element w)) ->
   Ext.T (u -> v -> CodeGenFunction r w) ->
   (Element u -> Element v -> CodeGenFunction r (Element w))
zipAutoWith f g x y =
   mapAuto (uncurry f) (fmap uncurry g) (x, y)
-- Signed integers use the generic arithmetic for all four methods.
instance Real Int8 where min = A.min; max = A.max; signum = A.signum; abs = A.abs;
instance Real Int16 where min = A.min; max = A.max; signum = A.signum; abs = A.abs;
instance Real Int32 where min = A.min; max = A.max; signum = A.signum; abs = A.abs;
instance Real Int64 where min = A.min; max = A.max; signum = A.signum; abs = A.abs;
-- Unsigned integers differ only in 'abs', which returns its argument unchanged.
instance Real Word8 where min = A.min; max = A.max; signum = A.signum; abs = return;
instance Real Word16 where min = A.min; max = A.max; signum = A.signum; abs = return;
instance Real Word32 where min = A.min; max = A.max; signum = A.signum; abs = return;
instance Real Word64 where min = A.min; max = A.max; signum = A.signum; abs = return;
-- | Signed integers of arbitrary bit width.
instance (TypeNum.Positive n) => Real (IntN n) where
   min = A.min
   max = A.max
   abs = A.abs
   -- 'A.signumGen' gets the constants -1 and 1 as its first two arguments.
   -- The first one must really be -1,
   -- otherwise negative arguments would get the sign +1.
   signum = A.signumGen (LLVM.valueOf $ IntN (-1)) (LLVM.valueOf $ IntN 1)
-- | Unsigned integers of arbitrary bit width.
instance (TypeNum.Positive n) => Real (WordN n) where
   min x y = A.min x y
   max x y = A.max x y
   -- Returned unchanged, like in the other unsigned instances.
   abs x = return x
   -- The first constant passed to 'A.signumGen' is undefined.
   -- NOTE(review): presumably this is the result for negative numbers,
   -- which cannot occur for unsigned values -- confirm in 'A.signumGen'.
   signum x =
      A.signumGen (LLVM.value LLVM.undef) (LLVM.valueOf (WordN 1)) x
-- | Vectors delegate all methods to "LLVM.Extra.Vector".
instance (TypeNum.Positive n, Vector.Real a) => Real (Vector n a) where
   min x y = Vector.min x y
   max x y = Vector.max x y
   abs x = Vector.abs x
   signum x = Vector.signum x
-- | Types that can be multiplied by a value of their 'Scalar' type.
class
   (LLVM.IsArithmetic (Scalar v), LLVM.IsArithmetic v) =>
      PseudoModule v where
   -- | Multiply a run-time value by a run-time scalar.
   scale ::
      (a ~ Scalar v) =>
      Value a -> Value v -> CodeGenFunction r (Value v)
   -- | Multiply a constant by a constant scalar.
   scaleConst ::
      (a ~ Scalar v) =>
      ConstValue a -> ConstValue v -> CodeGenFunction r (ConstValue v)
-- For scalar types the scalar is the type itself,
-- thus scaling is plain multiplication.
instance PseudoModule Word8 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Word16 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Word32 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Word64 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int8 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int16 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int32 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Int64 where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Float where scale = LLVM.mul; scaleConst = LLVM.mul
instance PseudoModule Double where scale = LLVM.mul; scaleConst = LLVM.mul
-- | Vectors are scaled by replicating the scalar to a vector
-- and multiplying both vectors.
instance (LLVM.IsArithmetic a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      PseudoModule (Vector n a) where
   scale a v = do
      factor <- replicate a
      A.mul factor v
   scaleConst a v =
      -- asTypeOf fixes the vector type of the replicated constant
      let factor = replicateConst a `asTypeOf` v
      in  LLVM.mul factor v
-- | Types that provide an LLVM constant for every Haskell 'Integer'.
class (LLVM.IsConst a) => IntegerConstant a where
   -- | Constant that represents the given integer.
   constFromInteger :: Integer -> ConstValue a
-- Primitive types: convert with Haskell's 'fromInteger'
-- and wrap the result with 'constOf'.
instance IntegerConstant Word8 where constFromInteger = constOf . fromInteger
instance IntegerConstant Word16 where constFromInteger = constOf . fromInteger
instance IntegerConstant Word32 where constFromInteger = constOf . fromInteger
instance IntegerConstant Word64 where constFromInteger = constOf . fromInteger
instance IntegerConstant Int8 where constFromInteger = constOf . fromInteger
instance IntegerConstant Int16 where constFromInteger = constOf . fromInteger
instance IntegerConstant Int32 where constFromInteger = constOf . fromInteger
instance IntegerConstant Int64 where constFromInteger = constOf . fromInteger
instance IntegerConstant Float where constFromInteger = constOf . fromInteger
instance IntegerConstant Double where constFromInteger = constOf . fromInteger
-- | Unsigned integers of arbitrary bit width:
-- the 'Integer' is wrapped in the 'WordN' constructor.
instance (TypeNum.Positive n) => IntegerConstant (WordN n) where
   constFromInteger n = constOf (WordN n)

-- | Signed integers of arbitrary bit width:
-- the 'Integer' is wrapped in the 'IntN' constructor.
instance (TypeNum.Positive n) => IntegerConstant (IntN n) where
   constFromInteger n = constOf (IntN n)
-- | Vectors: the constant of the element type, replicated to all elements.
instance (IntegerConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      IntegerConstant (Vector n a) where
   constFromInteger n = replicateConst (constFromInteger n)
-- | Types that provide an LLVM constant for every Haskell 'Rational'.
class (IntegerConstant a) => RationalConstant a where
   -- | Constant that represents the given rational number.
   constFromRational :: Rational -> ConstValue a

-- Floating point scalars: convert with Haskell's 'fromRational'
-- and wrap the result with 'constOf'.
instance RationalConstant Float where
   constFromRational q = constOf (fromRational q)

instance RationalConstant Double where
   constFromRational q = constOf (fromRational q)
-- | Vectors: the constant of the element type, replicated to all elements.
instance (RationalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      RationalConstant (Vector n a) where
   constFromRational q = replicateConst (constFromRational q)
-- | Types that provide an LLVM constant for pi.
class (RationalConstant a) => TranscendentalConstant a where
   -- | The constant pi.
   constPi :: ConstValue a

-- Floating point scalars: Haskell's 'pi' wrapped with 'constOf'.
instance TranscendentalConstant Float where
   constPi = constOf pi

instance TranscendentalConstant Double where
   constPi = constOf pi
-- | Vectors: pi of the element type, replicated to all elements.
instance (TranscendentalConstant a, LLVM.IsPrimitive a, TypeNum.Positive n) =>
      TranscendentalConstant (Vector n a) where
   constPi =
      let elementPi = constPi
      in  replicateConst elementPi