module LLVM.Util.Arithmetic(
TValue,
(%==), (%/=), (%<), (%<=), (%>), (%>=),
(%&&), (%||),
(?), (??),
retrn, set,
ArithFunction, arithFunction,
ToArithFunction, toArithFunction, recursiveFunction,
CallIntrinsic,
) where
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (mapVector, mapVector2)
import LLVM.Util.Proxy (Proxy(Proxy))
import LLVM.Core
import qualified Type.Data.Num.Decimal.Number as Dec
import Control.Monad (liftM2)
-- | Shorthand for a 'CodeGenFunction' action that yields a 'Value' of type @a@.
-- Most combinators in this module take and return such actions, so that
-- ordinary Haskell expression syntax can be used to describe generated code.
type TValue r a = CodeGenFunction r (Value a)
infix 4 %==, %/=, %<, %<=, %>=, %>

-- | Comparison operators on code-generating values.
--
-- Each operator runs its left operand, then its right operand (see 'binop'),
-- and compares the two resulting values with 'LLVM.cmp' using the matching
-- predicate. The result is itself a code-generating value.
(%==), (%/=), (%<), (%<=), (%>), (%>=) ::
    (CmpRet a) =>
    TValue r a -> TValue r a -> TValue r (CmpResult a)
lhs %== rhs = binop (LLVM.cmp CmpEQ) lhs rhs
lhs %/= rhs = binop (LLVM.cmp CmpNE) lhs rhs
lhs %> rhs = binop (LLVM.cmp CmpGT) lhs rhs
lhs %>= rhs = binop (LLVM.cmp CmpGE) lhs rhs
lhs %< rhs = binop (LLVM.cmp CmpLT) lhs rhs
lhs %<= rhs = binop (LLVM.cmp CmpLE) lhs rhs
infixr 3 %&&
infixr 2 %||

-- | Logical conjunction built on the branching conditional '?'.
-- The right operand's code is placed in the branch taken when the left
-- operand is true; otherwise the result is the constant 'False'.
(%&&) :: TValue r Bool -> TValue r Bool -> TValue r Bool
(%&&) lhs rhs = lhs ? (rhs, constFalse)
  where
    constFalse = return (valueOf False)

-- | Logical disjunction built on the branching conditional '?'.
-- The right operand's code is placed in the branch taken when the left
-- operand is false; otherwise the result is the constant 'True'.
(%||) :: TValue r Bool -> TValue r Bool -> TValue r Bool
(%||) lhs rhs = lhs ? (constTrue, rhs)
  where
    constTrue = return (valueOf True)
infix 0 ?

-- | Branching conditional: @cond ? (onTrue, onFalse)@.
--
-- Emits a conditional branch on @cond@, generates the code of each
-- alternative in its own basic block, and merges the two results with a
-- 'phi' node in a common join block.
(?) :: (IsFirstClass a) => TValue r Bool -> (TValue r a, TValue r a) -> TValue r a
cond ? (onTrue, onFalse) = do
    -- Allocate all three blocks up front, then branch on the condition.
    thenBlock <- newBasicBlock
    elseBlock <- newBasicBlock
    joinBlock <- newBasicBlock
    condVal <- cond
    condBr condVal thenBlock elseBlock

    -- True alternative. The alternative may itself define further blocks
    -- (e.g. a nested '?'), so the phi predecessor is the block that is
    -- current *after* its code has been emitted, not necessarily thenBlock.
    defineBasicBlock thenBlock
    thenVal <- onTrue
    thenEnd <- getCurrentBasicBlock
    br joinBlock

    -- False alternative, handled the same way.
    defineBasicBlock elseBlock
    elseVal <- onFalse
    elseEnd <- getCurrentBasicBlock
    br joinBlock

    -- Join: pick the value according to the predecessor we arrived from.
    defineBasicBlock joinBlock
    phi [(thenVal, thenEnd), (elseVal, elseEnd)]
infix 0 ??

-- | Non-branching conditional: @cond ?? (onTrue, onFalse)@.
--
-- Unlike '?', the condition and /both/ alternatives are always evaluated
-- (in that order); the result is then chosen with 'select'.
(??) :: (IsFirstClass a, CmpRet a) => TValue r (CmpResult a) -> (TValue r a, TValue r a) -> TValue r a
cond ?? (onTrue, onFalse) =
    cond >>= \condVal ->
    onTrue >>= \trueVal ->
    onFalse >>= \falseVal ->
    select condVal trueVal falseVal
-- | Run the given code-generating value and pass its result to 'ret'.
retrn :: (Ret (Value a) r) => TValue r a -> CodeGenFunction r ()
retrn x = do
    result <- x
    ret result
-- | Run a code-generating value once and hand back an action that merely
-- yields the already computed 'Value'.
--
-- Useful for sharing: using the returned action several times does not
-- re-run the code of @x@.
set :: TValue r a -> (CodeGenFunction r (TValue r a))
set x = x >>= \computed -> return (return computed)
-- Method-less instances: no comparison method is defined here for
-- code-generating values. 'Ord' (and therefore 'Eq') is a superclass of
-- 'Real', which this module instantiates for 'TValue'. To compare generated
-- values use the operators '%==', '%<', etc. instead.
instance Eq (TValue r a)
instance Ord (TValue r a)
-- | Arithmetic on code-generating values.
--
-- The binary operators run both operands (left first, see 'binop') and emit
-- the corresponding LLVM instruction. Integer literals become constants via
-- 'valueOf'.
--
-- Note that a 'TValue' is an action: every mention of @x@ in 'abs' and
-- 'signum' re-runs the code of @x@. Wrap the operand with 'set' first if
-- that duplication matters.
instance (IsArithmetic a, CmpRet a, Num a, IsConst a) => Num (TValue r a) where
    (+) = binop add
    (-) = binop sub
    (*) = binop mul
    negate = (>>= neg)
    -- Negate the operand only when it is below zero.
    abs x = x %< 0 ?? (-x, x)
    -- -1 for negative, 1 for positive, 0 otherwise.
    signum x = x %< 0 ?? (-1, x %> 0 ?? (1, 0))
    fromInteger = return . valueOf . fromInteger
-- | Minimal 'Enum' support for code-generating values.
--
-- 'succ' and 'pred' add or subtract the constant one; 'toEnum' builds a
-- constant from an 'Int'. 'fromEnum' cannot be answered at code-generation
-- time and raises an error.
instance (IsArithmetic a, CmpRet a, Num a, IsConst a) => Enum (TValue r a) where
    succ x = x + 1
    pred x = x - 1
    fromEnum _ = error "CodeGenFunction Value: fromEnum"
    toEnum = fromIntegral
-- | 'Real' instance that exists only to satisfy the class hierarchy:
-- its single method always raises an error.
instance (IsArithmetic a, CmpRet a, Num a, IsConst a) => Real (TValue r a) where
    toRational = const (error "CodeGenFunction Value: toRational")
-- | Integer division on code-generating values.
--
-- Only 'quot', 'rem' and 'quotRem' are implemented here (via 'idiv' and
-- 'irem'); 'toInteger' always raises an error. 'div', 'mod' and 'divMod'
-- are not defined in this instance.
-- NOTE(review): the class defaults for 'div'/'mod' go through 'Eq', which
-- has a method-less instance in this module -- confirm they are never used.
instance (CmpRet a, Num a, IsConst a, IsInteger a) => Integral (TValue r a) where
    quot x y = binop idiv x y
    rem x y = binop irem x y
    -- Two independent code-generating values; each re-runs both operands.
    quotRem x y = (x `quot` y, x `rem` y)
    toInteger = const (error "CodeGenFunction Value: toInteger")
-- | Floating-point division on code-generating values; rational literals
-- are turned into constants with 'valueOf'.
instance (CmpRet a, Fractional a, IsConst a, IsFloating a) => Fractional (TValue r a) where
    x / y = binop fdiv x y
    fromRational q = return (valueOf (fromRational q))
-- | 'RealFrac' instance that exists only to satisfy the class hierarchy:
-- 'properFraction' always raises an error.
instance (CmpRet a, Fractional a, IsConst a, IsFloating a) => RealFrac (TValue r a) where
    properFraction = const (error "CodeGenFunction Value: properFraction")
-- | Transcendental functions on code-generating values.
--
-- 'sqrt', 'sin', 'cos', 'exp', 'log' and '**' are forwarded to the
-- intrinsics of the same name (@pow@ for '**') through 'callIntrinsic1' and
-- 'callIntrinsic2'. The inverse trigonometric functions raise an error.
-- The hyperbolic functions and their inverses are expressed through 'exp',
-- 'log' and 'sqrt'. Methods not listed here (e.g. 'tan', 'tanh') are left
-- to the class defaults.
--
-- A 'TValue' is an action, so every mention of @x@ on a right-hand side
-- re-runs the code of @x@.
instance (CmpRet a, CallIntrinsic a, Floating a, IsConst a, IsFloating a) => Floating (TValue r a) where
    pi = return $ valueOf pi
    sqrt = callIntrinsic1 "sqrt"
    sin = callIntrinsic1 "sin"
    cos = callIntrinsic1 "cos"
    (**) = callIntrinsic2 "pow"
    exp = callIntrinsic1 "exp"
    log = callIntrinsic1 "log"

    asin _ = error "LLVM missing intrinsic: asin"
    acos _ = error "LLVM missing intrinsic: acos"
    atan _ = error "LLVM missing intrinsic: atan"

    -- sinh x = (e^x - e^(-x)) / 2
    sinh x = (exp x - exp (-x)) / 2
    -- cosh x = (e^x + e^(-x)) / 2
    cosh x = (exp x + exp (-x)) / 2
    -- asinh x = ln (x + sqrt (x^2 + 1))
    asinh x = log (x + sqrt (x * x + 1))
    -- acosh x = ln (x + sqrt (x^2 - 1))
    acosh x = log (x + sqrt (x * x - 1))
    -- atanh x = (ln (1 + x) - ln (1 - x)) / 2
    atanh x = (log (1 + x) - log (1 - x)) / 2
-- | Partial 'RealFloat' support for code-generating values.
--
-- The purely type-dependent queries ('floatRadix', 'floatDigits',
-- 'floatRange', 'isIEEE') are answered from the underlying Haskell type @a@
-- without looking at the argument. 'exponent' is always 0 and 'scaleFloat'
-- only accepts a shift of 0. Everything that would have to inspect the
-- generated value raises an error.
instance (CmpRet a, CallIntrinsic a, RealFloat a, IsConst a, IsFloating a) => RealFloat (TValue r a) where
    -- Static properties, delegated to the element type.
    floatRadix = const (floatRadix (undefined :: a))
    floatDigits = const (floatDigits (undefined :: a))
    floatRange = const (floatRange (undefined :: a))
    isIEEE = const (isIEEE (undefined :: a))

    exponent = const 0
    scaleFloat n x
        | n == 0 = x
        | otherwise = error "CodeGenFunction Value: scaleFloat"

    -- Operations that would need the run-time value.
    decodeFloat = const (error "CodeGenFunction Value: decodeFloat")
    encodeFloat = \_ _ -> error "CodeGenFunction Value: encodeFloat"
    isNaN = const (error "CodeGenFunction Value: isNaN")
    isInfinite = const (error "CodeGenFunction Value: isInfinite")
    isDenormalized = const (error "CodeGenFunction Value: isDenormalized")
    isNegativeZero = const (error "CodeGenFunction Value: isNegativeZero")
-- | Lift a binary operation on plain 'Value's to one on code-generating
-- values: run the left operand, then the right operand, then apply @op@ to
-- the two results.
binop ::
    (Value a -> Value b -> TValue r c) ->
    TValue r a -> TValue r b -> TValue r c
binop op x y =
    x >>= \left ->
    y >>= \right ->
    op left right
-- | Hook applied to the result of intrinsic calls. At present it returns
-- its argument unchanged and emits no code.
addReadNone :: Value a -> CodeGenFunction r (Value a)
addReadNone = return
-- | Call a unary intrinsic on a primitive value.
--
-- The external function name is @llvm.\<fn\>.\<type\>@, where the type suffix
-- is obtained from 'intrinsicTypeName' for the argument type @a@. The call
-- result is passed through 'addReadNone'.
callIntrinsicP1 :: forall a b r . (IsFirstClass a, IsFirstClass b, IsPrimitive a) =>
    String -> Value a -> TValue r b
callIntrinsicP1 fn x =
    externFunction intrinsicName >>= \op ->
        runCall (applyCall (callFromFunction op) x) >>= addReadNone
  where
    intrinsicName =
        concat ["llvm.", fn, ".", intrinsicTypeName (Proxy :: Proxy a)]
-- | Call a binary intrinsic on primitive values.
--
-- The external function name is @llvm.\<fn\>.\<type\>@, where the type suffix
-- is obtained from 'intrinsicTypeName' for the /first/ argument type @a@.
-- The call result is passed through 'addReadNone'.
callIntrinsicP2 :: forall a b c r . (IsFirstClass a, IsFirstClass b, IsFirstClass c, IsPrimitive a) =>
    String -> Value a -> Value b -> TValue r c
callIntrinsicP2 fn x y =
    externFunction intrinsicName >>= \op ->
        runCall (applyCall (applyCall (callFromFunction op) x) y) >>= addReadNone
  where
    intrinsicName =
        concat ["llvm.", fn, ".", intrinsicTypeName (Proxy :: Proxy a)]
-- | Conversion from a function written over code-generating values
-- (@CodeGenFunction r a -> ... -> CodeGenFunction r z@) to a function over
-- plain arguments whose body ends in a 'ret'
-- (@a -> ... -> CodeGenFunction r ()@).
class ArithFunction r z a b | a -> b r z, b r z -> a where
    arithFunction' :: a -> b

-- | Base case: run the body and return its result with 'ret'.
instance
    (Ret a r) =>
    ArithFunction r a (CodeGenFunction r a) (CodeGenFunction r ()) where
    arithFunction' body = do
        result <- body
        ret result

-- | Inductive case: accept one plain argument, wrap it with 'return' so the
-- original function sees a code-generating value, and convert the rest.
instance
    (ArithFunction r z b0 b1) =>
    ArithFunction r z (CodeGenFunction r a -> b0) (a -> b1) where
    arithFunction' f arg = arithFunction' (f (return arg))
-- | Exported entry point for the 'ArithFunction' conversion; simply
-- forwards to the class method 'arithFunction''.
arithFunction :: ArithFunction r z a b => a -> b
arithFunction f = arithFunction' f
-- | Conversion from a (partially applied) call to a function over
-- code-generating values: every argument of type @a@ becomes an argument of
-- type @CodeGenFunction r (Value a)@ and an @IO b@ result becomes
-- @CodeGenFunction r (Value b)@.
class ToArithFunction r a b | a r -> b, b -> a r where
    toArithFunction' :: CodeGenFunction r (Call a) -> b

-- | Base case: all arguments have been supplied, so perform the call.
instance ToArithFunction r (IO b) (CodeGenFunction r (Value b)) where
    toArithFunction' cl = do
        saturated <- cl
        runCall saturated

-- | Inductive case: run the call built so far, then the next argument, and
-- attach that argument with 'applyCall'.
instance
    ToArithFunction r b0 b1 =>
    ToArithFunction r (a -> b0) (CodeGenFunction r (Value a) -> b1) where
    toArithFunction' cl x =
        toArithFunction' $ do
            partial <- cl
            arg <- x
            return (applyCall partial arg)
-- | Two-argument special case of 'toArithFunction', written out directly:
-- run both arguments (left first) and call @f@ on the results.
-- Not part of the module's export list.
_toArithFunction2 ::
    Function (a -> b -> IO c) -> TValue r a -> TValue r b -> TValue r c
_toArithFunction2 f = binop callWith
  where
    callWith x y = runCall (applyCall (applyCall (callFromFunction f) x) y)
-- | Turn an LLVM 'Function' into a Haskell function over code-generating
-- values, starting the 'ToArithFunction' conversion from an unapplied call
-- to @f@.
toArithFunction ::
    (ToArithFunction r f g) =>
    Function f -> g
toArithFunction = toArithFunction' . return . callFromFunction
-- | Define a function whose body may call itself.
--
-- A new function with 'ExternalLinkage' is created first; @af@ then receives
-- that function in its 'toArithFunction' form and produces the body, which
-- is converted with 'arithFunction' and installed with 'defineFunction'.
-- The newly defined function is returned.
recursiveFunction ::
    (IsFunction f, FunctionArgs f, code ~ FunctionCodeGen f,
     ArithFunction r1 z arith code,
     ToArithFunction r0 f g) =>
    (g -> arith) -> CodeGenModule (Function f)
recursiveFunction af =
    newFunction ExternalLinkage >>= \self ->
        defineFunction self (arithFunction (af (toArithFunction self)))
            >> return self
-- | Types for which named unary and binary intrinsics can be called, with
-- argument(s) and result all of the same type.
class CallIntrinsic a where
    -- | Call the named unary intrinsic on a value.
    callIntrinsic1' :: String -> Value a -> TValue r a
    -- | Call the named binary intrinsic on two values.
    callIntrinsic2' :: String -> Value a -> Value a -> TValue r a

-- | Scalar 'Float': forwarded to the primitive intrinsic calls.
instance CallIntrinsic Float where
    callIntrinsic1' name x = callIntrinsicP1 name x
    callIntrinsic2' name x y = callIntrinsicP2 name x y

-- | Scalar 'Double': forwarded to the primitive intrinsic calls.
instance CallIntrinsic Double where
    callIntrinsic1' name x = callIntrinsicP1 name x
    callIntrinsic2' name x y = callIntrinsicP2 name x y
-- | Compile-time flag: 'True' when the C preprocessor symbol @__MACOS__@ is
-- defined, 'False' otherwise. Consulted by the 'CallIntrinsic' instance for
-- vectors.
macOS :: Bool
#if defined(__MACOS__)
macOS = True
#else
macOS = False
#endif
-- | Intrinsics on vectors.
--
-- By default the scalar intrinsic is applied element-wise with 'mapVector'
-- (unary) or 'mapVector2' (binary). As a special case, when 'macOS' is set,
-- the vector has exactly 4 elements and the unary function is one of a
-- fixed list, a single external function named @v\<name\>f@ is called on the
-- whole vector instead.
-- NOTE(review): presumably these are the platform's vector math routines --
-- confirm which library provides them.
instance (Dec.Positive n, IsPrimitive a, CallIntrinsic a) => CallIntrinsic (Vector n a) where
    callIntrinsic1' s x
        | useWholeVectorCall =
            externFunction (concat ["v", s, "f"]) >>= \op ->
                call op x >>= addReadNone
        | otherwise = mapVector (callIntrinsic1' s) x
      where
        useWholeVectorCall =
            macOS && hasFourElements && s `elem` wholeVectorNames
        hasFourElements =
            Dec.integerFromSingleton (Dec.singleton :: Dec.Singleton n) == 4
        wholeVectorNames = ["sqrt", "log", "exp", "sin", "cos", "tan"]
    callIntrinsic2' s x y = mapVector2 (callIntrinsic2' s) x y
-- | Run the argument and apply the named unary intrinsic to its value.
callIntrinsic1 :: (CallIntrinsic a) => String -> TValue r a -> TValue r a
callIntrinsic1 s x = x >>= callIntrinsic1' s
-- | Run both arguments (left first, see 'binop') and apply the named binary
-- intrinsic to their values.
callIntrinsic2 :: (CallIntrinsic a) => String -> TValue r a -> TValue r a -> TValue r a
callIntrinsic2 s x y = binop (callIntrinsic2' s) x y