module LLVM.Extra.Arithmetic (
Additive (zero, add, sub, neg), one, inc, dec,
PseudoRing (mul), square,
Scalar,
PseudoModule (scale),
Field (fdiv),
IntegerConstant(fromInteger'),
RationalConstant(fromRational'),
idiv, irem,
FloatingComparison(fcmp), Comparison(cmp),
CmpResult, LLVM.CmpPredicate(..),
Logic (and, or, xor, inv),
Real (min, max, abs, signum),
Fraction (truncate, fraction),
signedFraction, addToPhase, incPhase,
advanceArrayElementPtr,
decreaseArrayElementPtr,
Algebraic (sqrt),
Transcendental (pi, sin, cos, exp, log, pow),
) where
import LLVM.Extra.ArithmeticPrivate
(inc, dec, advanceArrayElementPtr, decreaseArrayElementPtr, )
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Core
(CodeGenFunction, value, Value, ConstValue,
IsType, IsInteger, IsFloating, IsArithmetic, IsFirstClass, )
import Control.Monad (liftM2, liftM3, )
import Prelude hiding
(Real, and, or, sqrt, sin, cos, exp, log, abs, min, max, truncate, )
-- | Values that can be added, subtracted and negated.
--
-- Except for 'zero', every operation yields its result inside
-- 'CodeGenFunction'.
class (Class.Zero a) => Additive a where
   -- | The additive identity; a plain value, no code-generation action.
   zero :: a
   -- | Sum and difference of two operands, respectively.
   add, sub :: a -> a -> CodeGenFunction r a
   -- | Additive inverse of one operand.
   neg :: a -> CodeGenFunction r a
-- | Run-time values: every method delegates to the corresponding LLVM
-- operation.
instance (IsArithmetic a) => Additive (Value a) where
   zero = value LLVM.zero
   add x y = LLVM.add x y
   sub x y = LLVM.sub x y
   neg x = LLVM.neg x

-- | Constant values: 'add' and 'sub' delegate to LLVM; 'neg' is
-- expressed as a subtraction from zero.
--
-- NOTE(review): for floating-point element types, @0 - x@ differs from
-- a true negation when @x@ is @+0@ (it yields @+0@, not @-0@). Confirm
-- that callers do not depend on the sign of zero.
instance (IsArithmetic a) => Additive (ConstValue a) where
   zero = LLVM.zero
   add x y = LLVM.add x y
   sub x y = LLVM.sub x y
   neg x = LLVM.sub LLVM.zero x
-- | Componentwise arithmetic on pairs.
--
-- 'liftM2' runs the action for the first component before the action
-- for the second, so code is generated left to right.
instance (Additive a, Additive b) => Additive (a,b) where
   zero = (zero, zero)
   add (l0,l1) (r0,r1) = liftM2 (,) (add l0 r0) (add l1 r1)
   sub (l0,l1) (r0,r1) = liftM2 (,) (sub l0 r0) (sub l1 r1)
   neg (l0,l1) = liftM2 (,) (neg l0) (neg l1)

-- | Componentwise arithmetic on triples, generated left to right.
instance (Additive a, Additive b, Additive c) => Additive (a,b,c) where
   zero = (zero, zero, zero)
   add (l0,l1,l2) (r0,r1,r2) = liftM3 (,,) (add l0 r0) (add l1 r1) (add l2 r2)
   sub (l0,l1,l2) (r0,r1,r2) = liftM3 (,,) (sub l0 r0) (sub l1 r1) (sub l2 r2)
   neg (l0,l1,l2) = liftM3 (,,) (neg l0) (neg l1) (neg l2)
-- | 'Additive' types that also support multiplication.
--
-- The class itself provides no multiplicative identity.
class (Additive a) => PseudoRing a where
   -- | Product of two operands.
   mul :: a -> a -> CodeGenFunction r a

-- | Delegates to 'LLVM.mul'.
instance (IsArithmetic v) => PseudoRing (Value v) where
   mul x y = LLVM.mul x y

-- | Delegates to 'LLVM.mul'.
instance (IsArithmetic v) => PseudoRing (ConstValue v) where
   mul x y = LLVM.mul x y
-- | The scalar type that goes with a value type. It is the type of the
-- factor accepted by 'scale'.
type family Scalar v :: *

-- | A run-time value's scalar is the run-time value of its element type.
type instance Scalar (Value a) = Value (SoV.Scalar a)
-- | A constant's scalar is the constant of its element type.
type instance Scalar (ConstValue a) = ConstValue (SoV.Scalar a)
-- | 'Additive' types that can be multiplied by a factor of their
-- 'Scalar' type.
class (PseudoRing (Scalar v), Additive v) => PseudoModule v where
   -- | @scale s x@ multiplies @x@ by the scalar @s@.
   scale :: Scalar v -> v -> CodeGenFunction r v

-- | Delegates to 'SoV.scale'.
instance (SoV.PseudoModule v) => PseudoModule (Value v) where
   scale s x = SoV.scale s x

-- | Delegates to 'SoV.scaleConst'.
instance (SoV.PseudoModule v) => PseudoModule (ConstValue v) where
   scale s x = SoV.scaleConst s x
-- | Types that can be built from a Haskell 'Integer'.
class IntegerConstant a where
   fromInteger' :: Integer -> a

-- | Builds the constant with 'SoV.constFromInteger'.
instance SoV.IntegerConstant a => IntegerConstant (ConstValue a) where
   fromInteger' n = SoV.constFromInteger n

-- | Builds the constant, then lifts it to a 'Value' with 'value'.
instance SoV.IntegerConstant a => IntegerConstant (Value a) where
   fromInteger' n = value (SoV.constFromInteger n)

-- | The integer constant @1@, built with 'fromInteger''.
one :: (IntegerConstant a) => a
one = fromInteger' 1
-- | Adds 'one' to the argument.
--
-- Only 'add' is used, so 'Additive' is enough. 'PseudoRing' (which
-- implies 'Additive') is not required, so every previous use still
-- type-checks.
_inc ::
   (Additive a, IntegerConstant a) =>
   a -> CodeGenFunction r a
_inc x = add x one

-- | Subtracts 'one' from the argument.
--
-- Only 'sub' is used, so 'Additive' is enough; see '_inc'.
_dec ::
   (Additive a, IntegerConstant a) =>
   a -> CodeGenFunction r a
_dec x = sub x one
-- | Multiplies a value by itself with 'mul'.
square ::
   (PseudoRing a) =>
   a -> CodeGenFunction r a
square x = x `mul` x
-- | 'PseudoRing' types that also support division.
class (PseudoRing a) => Field a where
   -- | Quotient of two operands.
   fdiv :: a -> a -> CodeGenFunction r a

-- | Floating-point run-time values; delegates to 'LLVM.fdiv'.
instance (LLVM.IsFloating v) => Field (Value v) where
   fdiv x y = LLVM.fdiv x y

-- | Floating-point constants; delegates to 'LLVM.fdiv'.
instance (LLVM.IsFloating v) => Field (ConstValue v) where
   fdiv x y = LLVM.fdiv x y
-- | Types that can be built from a Haskell 'Rational'.
class (IntegerConstant a) => RationalConstant a where
   fromRational' :: Rational -> a

-- | Builds the constant with 'SoV.constFromRational'.
instance SoV.RationalConstant a => RationalConstant (ConstValue a) where
   fromRational' q = SoV.constFromRational q

-- | Builds the constant, then lifts it to a 'Value' with 'value'.
instance SoV.RationalConstant a => RationalConstant (Value a) where
   fromRational' q = value (SoV.constFromRational q)
-- | Integer division of two run-time values. Delegates to 'LLVM.idiv';
-- signed vs. unsigned handling is whatever 'LLVM.idiv' chooses.
idiv ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
idiv x y = LLVM.idiv x y

-- | Integer remainder of two run-time values. Delegates to 'LLVM.irem';
-- signed vs. unsigned handling is whatever 'LLVM.irem' chooses.
irem ::
   (IsInteger a) =>
   Value a -> Value a -> CodeGenFunction r (Value a)
irem x y = LLVM.irem x y
-- | Ordered additive types: minimum, maximum, absolute value and sign.
class (Additive a) => Real a where
   min, max :: a -> a -> CodeGenFunction r a
   abs, signum :: a -> CodeGenFunction r a

-- | Every method delegates to its counterpart in "LLVM.Extra.ScalarOrVector".
instance (SoV.Real a) => Real (Value a) where
   min x y = SoV.min x y
   max x y = SoV.max x y
   abs x = SoV.abs x
   signum x = SoV.signum x
-- | 'Real' types that can be split into an integral part and a
-- fractional part.
class (Real a) => Fraction a where
   truncate, fraction :: a -> CodeGenFunction r a

-- | Both methods delegate to "LLVM.Extra.ScalarOrVector".
instance (SoV.Fraction a) => Fraction (Value a) where
   truncate x = SoV.truncate x
   fraction x = SoV.fraction x
-- | @x - truncate x@: the part of @x@ that 'truncate' removes.
signedFraction ::
   (Fraction a) =>
   a -> CodeGenFunction r a
signedFraction x = do
   whole <- truncate x
   sub x whole
-- | Adds an increment @d@ to a phase @p@ and returns the 'fraction' of
-- the sum.
addToPhase ::
   (Fraction a) =>
   a -> a -> CodeGenFunction r a
addToPhase d p = do
   total <- add d p
   fraction total
-- | Adds an increment @d@ to a phase @p@ and returns the
-- 'signedFraction' of the sum.
--
-- NOTE(review): presumably meant for non-negative @d@ and @p@, where
-- truncation and flooring agree and the result matches 'addToPhase'.
-- Confirm this precondition with callers.
incPhase ::
   (Fraction a) =>
   a -> a -> CodeGenFunction r a
incPhase d p = add d p >>= signedFraction
-- | Types whose values can be compared under an LLVM 'LLVM.CmpPredicate'.
class Comparison a where
   -- | The type produced by comparing two values of type @a@.
   type CmpResult a :: *
   cmp :: LLVM.CmpPredicate -> a -> a -> CodeGenFunction r (CmpResult a)

-- | Delegates to 'LLVM.cmp'.
instance (LLVM.CmpRet a) => Comparison (Value a) where
   type CmpResult (Value a) = Value (LLVM.CmpResult a)
   cmp p x y = LLVM.cmp p x y

-- | Delegates to 'LLVM.cmp'.
instance (LLVM.CmpRet a) => Comparison (ConstValue a) where
   type CmpResult (ConstValue a) = ConstValue (LLVM.CmpResult a)
   cmp p x y = LLVM.cmp p x y
-- | Comparisons that take a floating-point 'LLVM.FPPredicate'. The
-- result has the same 'CmpResult' type as 'cmp'.
class (Comparison a) => FloatingComparison a where
   fcmp :: LLVM.FPPredicate -> a -> a -> CodeGenFunction r (CmpResult a)

-- | Delegates to 'LLVM.fcmp'.
instance (IsFloating a, LLVM.CmpRet a) => FloatingComparison (Value a) where
   fcmp p x y = LLVM.fcmp p x y

-- | Delegates to 'LLVM.fcmp'.
instance (IsFloating a, LLVM.CmpRet a) => FloatingComparison (ConstValue a) where
   fcmp p x y = LLVM.fcmp p x y
-- | Logic operations: conjunction, disjunction, exclusive or, and
-- inversion.
class Logic a where
   and, or, xor :: a -> a -> CodeGenFunction r a
   inv :: a -> CodeGenFunction r a

-- | Integer run-time values; each method delegates to its LLVM
-- counterpart.
instance (LLVM.IsInteger a) => Logic (Value a) where
   and x y = LLVM.and x y
   or x y = LLVM.or x y
   xor x y = LLVM.xor x y
   inv x = LLVM.inv x

-- | Integer constants; each method delegates to its LLVM counterpart.
instance (LLVM.IsInteger a) => Logic (ConstValue a) where
   and x y = LLVM.and x y
   or x y = LLVM.or x y
   xor x y = LLVM.xor x y
   inv x = LLVM.inv x
-- | The type-name suffix that 'LLVM.intrinsicTypeName' gives for the
-- type of the argument.
--
-- Only the argument's type matters; the argument itself is never
-- evaluated.
valueTypeName ::
   (IsType a) =>
   Value a -> String
valueTypeName x =
   LLVM.intrinsicTypeName (proxyOf x)
   where
      -- Turns a value into a proxy for its element type, ignoring the value.
      proxyOf :: Value b -> LP.Proxy b
      proxyOf _ = LP.Proxy
-- | Full name of the LLVM intrinsic @fn@ specialised to the type of @x@,
-- in the form @llvm.\<fn\>.\<type suffix\>@.
intrinsicName ::
   (IsType a) =>
   String -> Value a -> String
intrinsicName fn x =
   concat ["llvm.", fn, ".", valueTypeName x]

-- | Declares the unary intrinsic @fn@ for the argument's type, calls it
-- on @x@, and passes the result through 'addReadNone'.
callIntrinsic1 ::
   (IsFirstClass a) =>
   String -> Value a -> CodeGenFunction r (Value a)
callIntrinsic1 fn x = do
   intrinsic <- LLVM.externFunction (intrinsicName fn x)
   addReadNone =<< LLVM.call intrinsic x

-- | Declares the binary intrinsic @fn@ for the operands' type, calls it
-- on @x@ and @y@, and passes the result through 'addReadNone'.
callIntrinsic2 ::
   (IsFirstClass a) =>
   String -> Value a -> Value a -> CodeGenFunction r (Value a)
callIntrinsic2 fn x y = do
   intrinsic <- LLVM.externFunction (intrinsicName fn x)
   addReadNone =<< LLVM.call intrinsic x y
-- | Post-processing step applied to the result of every intrinsic call.
--
-- It currently returns its argument unchanged and emits no code.
--
-- NOTE(review): the name suggests it was once meant to mark the call
-- result as @readnone@. No attribute is attached here; confirm this
-- pass-through is intended.
addReadNone :: Value a -> CodeGenFunction r (Value a)
addReadNone = return
-- | 'Field' types with a square root.
class Field a => Algebraic a where
   sqrt :: a -> CodeGenFunction r a

-- | Calls the @llvm.sqrt@ intrinsic specialised to the value's type.
instance (IsFloating a) => Algebraic (Value a) where
   sqrt x = callIntrinsic1 "sqrt" x
-- | 'Algebraic' types with the constant pi and elementary transcendental
-- functions.
class Algebraic a => Transcendental a where
   pi :: CodeGenFunction r a
   sin :: a -> CodeGenFunction r a
   cos :: a -> CodeGenFunction r a
   exp :: a -> CodeGenFunction r a
   log :: a -> CodeGenFunction r a
   pow :: a -> a -> CodeGenFunction r a

-- | 'pi' is the constant 'SoV.constPi' lifted to a 'Value'. Every
-- function calls the LLVM intrinsic of the same name, specialised to the
-- value's type.
instance (IsFloating a, SoV.TranscendentalConstant a) => Transcendental (Value a) where
   pi = return (value SoV.constPi)
   sin x = callIntrinsic1 "sin" x
   cos x = callIntrinsic1 "cos" x
   exp x = callIntrinsic1 "exp" x
   log x = callIntrinsic1 "log" x
   pow x y = callIntrinsic2 "pow" x y