module LLVM.Extra.FastMath where
import qualified LLVM.Extra.Multi.Value.Private as MV
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Core as LLVM
import LLVM.Util.Proxy (Proxy(Proxy))
import Foreign.Storable (Storable)
import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))
-- | Flag tag selecting 'LLVM.setHasNoNaNs'.
data NoNaNs = NoNaNs
   deriving (Eq, Show)

-- | Flag tag selecting 'LLVM.setHasNoInfs'.
data NoInfs = NoInfs
   deriving (Eq, Show)

-- | Flag tag selecting 'LLVM.setHasNoSignedZeros'.
data NoSignedZeros = NoSignedZeros
   deriving (Eq, Show)

-- | Flag tag selecting 'LLVM.setHasAllowReciprocal'.
data AllowReciprocal = AllowReciprocal
   deriving (Eq, Show)

-- | Flag tag selecting 'LLVM.setFastMath'.
data Fast = Fast
   deriving (Eq, Show)
-- | Type-level selection of one or more LLVM fast-math flags.
--
-- 'setFlags' applies the selected flag setter(s) to a floating-point
-- value.  The 'Bool' is forwarded unchanged to the underlying LLVM setter;
-- every caller in this module passes 'True'.
class Flags flags where
   setFlags ::
      (LLVM.IsFloating v) =>
      Proxy flags ->
      Bool ->
      LLVM.Value v ->
      LLVM.CodeGenFunction r ()
-- | Each atomic flag tag dispatches to its matching LLVM setter.
instance Flags NoNaNs where
   setFlags Proxy enable value = LLVM.setHasNoNaNs enable value

instance Flags NoInfs where
   setFlags Proxy enable value = LLVM.setHasNoInfs enable value

instance Flags NoSignedZeros where
   setFlags Proxy enable value = LLVM.setHasNoSignedZeros enable value

instance Flags AllowReciprocal where
   setFlags Proxy enable value = LLVM.setHasAllowReciprocal enable value

instance Flags Fast where
   setFlags Proxy enable value = LLVM.setFastMath enable value
-- | A pair of flag tags sets the first component's flags, then the second's.
instance (Flags f0, Flags f1) => Flags (f0, f1) where
   setFlags pair enable value = do
      setFlags (fmap fst pair) enable value
      setFlags (fmap snd pair) enable value
-- | Larger tuples are reduced to a pair of the head component and the
-- remaining tuple, so components are applied from left to right.
instance (Flags f0, Flags f1, Flags f2) => Flags (f0, f1, f2) where
   setFlags proxy = setSplitFlags nest proxy
     where nest (f0, f1, f2) = (f0, (f1, f2))

instance (Flags f0, Flags f1, Flags f2, Flags f3) => Flags (f0, f1, f2, f3) where
   setFlags proxy = setSplitFlags nest proxy
     where nest (f0, f1, f2, f3) = (f0, (f1, f2, f3))

instance
   (Flags f0, Flags f1, Flags f2, Flags f3, Flags f4) =>
      Flags (f0, f1, f2, f3, f4) where
   setFlags proxy = setSplitFlags nest proxy
     where nest (f0, f1, f2, f3, f4) = (f0, (f1, f2, f3, f4))
-- | Set flags by first re-shaping the flag type with a (type-level only)
-- function, then using the 'Flags' instance of the re-shaped type.
-- The function is never applied to a value; it only maps the 'Proxy'.
setSplitFlags ::
   (Flags split, LLVM.IsFloating a) =>
   (flags -> split) ->
   Proxy flags -> Bool -> LLVM.Value a -> LLVM.CodeGenFunction r ()
setSplitFlags reshape proxy enable value =
   setFlags (reshape <$> proxy) enable value
-- | A number of type @a@ tagged with the phantom fast-math flag type
-- @flags@.  The 'MV' class instances for 'Number' in this module attach
-- those flags to the instructions they generate.  The numeric classes
-- derived here act on plain Haskell values and attach nothing.
newtype Number flags a =
   Number {deconsNumber :: a}
   deriving (Show, Eq, Ord, Storable, Num, Fractional, Floating)
-- | Unwrap a 'Number'.  The first argument is never inspected; it only
-- fixes the @flags@ type.
getNumber :: flags -> Number flags a -> a
getNumber _ = deconsNumber
-- | 'Number' shares its representation with the wrapped type, so every
-- method delegates to the underlying instance and merely re-tags the result.
instance MultiValue a => MV.C (Number flags a) where
   type Repr f (Number flags a) = MV.Repr f a
   cons (Number x) = mvNumber (MV.cons x)
   undef = mvNumber MV.undef
   zero = mvNumber MV.zero
   phis bb x = mvNumber <$> MV.phis bb (mvDenumber x)
   addPhis bb x y = MV.addPhis bb (mvDenumber x) (mvDenumber y)
-- | Re-tag a multi-value of @a@ as a 'Number'.  The representation is
-- moved across unchanged; this is possible because @Repr (Number flags a)@
-- is defined as @Repr a@.
mvNumber :: MV.T a -> MV.T (Number flags a)
mvNumber value =
   case value of
      MV.Cons repr -> MV.Cons repr

-- | Inverse of 'mvNumber': drop the 'Number' tag.
mvDenumber :: MV.T (Number flags a) -> MV.T a
mvDenumber value =
   case value of
      MV.Cons repr -> MV.Cons repr
-- | Multi-value types whose generated LLVM values can receive the
-- fast-math flags selected by @flags@.
class MV.C a => MultiValue a where
   setMultiValueFlags ::
      (Flags flags) =>
      Proxy flags ->
      Bool ->
      MV.T (Number flags a) ->
      LLVM.CodeGenFunction r ()
-- | Scalar floating-point types: unwrap the single LLVM value and
-- set the flags on it directly.
instance MultiValue Float where
   setMultiValueFlags proxy enable (MV.Cons value) =
      setFlags proxy enable value

instance MultiValue Double where
   setMultiValueFlags proxy enable (MV.Cons value) =
      setFlags proxy enable value
-- | Shorthand for an endomorphism, used in signatures of wrappers that
-- return something of the same type they receive.
type Id t = t -> t
-- | Run a code generator that yields a flagged 'Number' multi-value, then
-- enable its fast-math flags (the flag type comes from the result type).
-- The value itself is returned unchanged.
attachMultiValueFlags ::
   (Flags flags, MultiValue a) =>
   Id (LLVM.CodeGenFunction r (MV.T (Number flags a)))
attachMultiValueFlags act =
   act >>= \mv -> setMultiValueFlags Proxy True mv >> return mv
-- | Lift a unary multi-value operation to 'Number', attaching the
-- fast-math flags to its result.
liftNumberM ::
   (m ~ LLVM.CodeGenFunction r, Flags flags, MultiValue b) =>
   (MV.T a -> m (MV.T b)) ->
   MV.T (Number flags a) -> m (MV.T (Number flags b))
liftNumberM op x =
   attachMultiValueFlags $ Monad.lift mvNumber $ op $ mvDenumber x
-- | Lift a binary multi-value operation to 'Number', attaching the
-- fast-math flags to its result.
liftNumberM2 ::
   (m ~ LLVM.CodeGenFunction r, Flags flags, MultiValue c) =>
   (MV.T a -> MV.T b -> m (MV.T c)) ->
   MV.T (Number flags a) -> MV.T (Number flags b) -> m (MV.T (Number flags c))
liftNumberM2 op x y =
   attachMultiValueFlags (Monad.lift mvNumber result)
  where
   result = op (mvDenumber x) (mvDenumber y)
-- | Composing a 'Number' composes the wrapped value and re-tags the result.
instance (Flags flags, MV.Compose a) => MV.Compose (Number flags a) where
   type Composed (Number flags a) = Number flags (MV.Composed a)
   compose (Number x) = mvNumber (MV.compose x)
-- | Decomposing a 'Number' decomposes the untagged multi-value with the
-- wrapped pattern and wraps the outcome in 'Number' again.
instance (Flags flags, MV.Decompose pa) => MV.Decompose (Number flags pa) where
   decompose (Number p) x = Number (MV.decompose p (mvDenumber x))

-- The decomposed and pattern types of a 'Number' are those of the wrapped
-- type, kept under the same flag tag.
type instance
   MV.Decomposed f (Number flags pa) = Number flags (MV.Decomposed f pa)
type instance
   MV.PatternTuple (Number flags pa) = Number flags (MV.PatternTuple pa)
-- | Integer constants come from the underlying type and are re-tagged;
-- no flags are attached.
instance
   (Flags flags, MultiValue a, MV.IntegerConstant a) =>
      MV.IntegerConstant (Number flags a) where
   fromInteger' n = mvNumber (MV.fromInteger' n)

-- | Rational constants come from the underlying type and are re-tagged;
-- no flags are attached.
instance
   (Flags flags, MultiValue a, MV.RationalConstant a) =>
      MV.RationalConstant (Number flags a) where
   fromRational' q = mvNumber (MV.fromRational' q)
-- | Addition, subtraction and negation with fast-math flags on the result.
instance
   (Flags flags, MultiValue a, MV.Additive a) =>
      MV.Additive (Number flags a) where
   neg x = liftNumberM MV.neg x
   add x y = liftNumberM2 MV.add x y
   sub x y = liftNumberM2 MV.sub x y

-- | Multiplication with fast-math flags on the result.
instance
   (Flags flags, MultiValue a, MV.PseudoRing a) =>
      MV.PseudoRing (Number flags a) where
   mul x y = liftNumberM2 MV.mul x y

-- | Division with fast-math flags on the result.
instance
   (Flags flags, MultiValue a, MV.Field a) =>
      MV.Field (Number flags a) where
   fdiv x y = liftNumberM2 MV.fdiv x y
-- The scalar type of a flagged vector is the flagged scalar type.
type instance MV.Scalar (Number flags a) = Number flags (MV.Scalar a)

-- | Scaling with fast-math flags on the result.
instance
   (Flags flags, MultiValue a, a ~ MV.Scalar v,
    MultiValue v, MV.PseudoModule v) =>
      MV.PseudoModule (Number flags v) where
   scale k x = liftNumberM2 MV.scale k x
-- | Ordering-related operations with fast-math flags on the result.
instance
   (Flags flags, MultiValue a, MV.Real a) =>
      MV.Real (Number flags a) where
   abs x = liftNumberM MV.abs x
   signum x = liftNumberM MV.signum x
   min x y = liftNumberM2 MV.min x y
   max x y = liftNumberM2 MV.max x y

-- | Integer/fractional part splitting with fast-math flags on the result.
instance
   (Flags flags, MultiValue a, MV.Fraction a) =>
      MV.Fraction (Number flags a) where
   truncate x = liftNumberM MV.truncate x
   fraction x = liftNumberM MV.fraction x
-- | Square root with fast-math flags on the result.
instance
   (Flags flags, MultiValue a, MV.Algebraic a) =>
      MV.Algebraic (Number flags a) where
   sqrt x = liftNumberM MV.sqrt x
-- | Transcendental functions with fast-math flags on the result.
-- 'pi' is taken from the underlying type and only re-tagged; no flags
-- are attached to it.
instance
   (Flags flags, MultiValue a, MV.Transcendental a) =>
      MV.Transcendental (Number flags a) where
   pi = mvNumber <$> MV.pi
   sin x = liftNumberM MV.sin x
   cos x = liftNumberM MV.cos x
   exp x = liftNumberM MV.exp x
   log x = liftNumberM MV.log x
   pow x y = liftNumberM2 MV.pow x y
-- | Selection between two flagged numbers; flags are attached to the result.
instance
   (Flags flags, MultiValue a, MV.Select a) =>
      MV.Select (Number flags a) where
   select cond = liftNumberM2 (MV.select cond)

-- | Comparison on the untagged values; no flags are attached.
instance
   (Flags flags, MultiValue a, MV.Comparison a) =>
      MV.Comparison (Number flags a) where
   cmp predicate x y = MV.cmp predicate (mvDenumber x) (mvDenumber y)

-- | Floating-point comparison on the untagged values; no flags are attached.
instance
   (Flags flags, MultiValue a, MV.FloatingComparison a) =>
      MV.FloatingComparison (Number flags a) where
   fcmp predicate x y = MV.fcmp predicate (mvDenumber x) (mvDenumber y)
-- | Types whose LLVM values can be marked with the fast-math flags
-- selected by @flags@.
class Tuple a where
   setTupleFlags ::
      (Flags flags) =>
      Proxy flags -> Bool -> a -> LLVM.CodeGenFunction r ()

-- | A single floating-point value forwards to 'setFlags'.
instance (LLVM.IsFloating a) => Tuple (LLVM.Value a) where
   setTupleFlags proxy enable value = setFlags proxy enable value
-- | A value of type @a@ tagged with the phantom fast-math flag type
-- @flags@.  The arithmetic instances for 'Context' in this module
-- attach those flags to the values they produce.
newtype Context flags a = Context a

-- | Recover the flag type of a 'Context' as a 'Proxy'; the wrapped value
-- is not inspected.
proxyFromContext :: Context flags a -> Proxy flags
proxyFromContext = const Proxy
-- | The zero tuple of the wrapped type, tagged.
instance
   (Flags flags, Class.Zero a, Tuple a) =>
      Class.Zero (Context flags a) where
   zeroTuple = Context $ Class.zeroTuple
-- | Additive operations; every generated result carries the context's
-- fast-math flags.  'zero' is the underlying constant, merely wrapped.
instance
   (Flags flags, Tuple a, A.Additive a) =>
      A.Additive (Context flags a) where
   zero = Context A.zero
   neg x = liftContext A.neg x
   add x y = liftContext2 A.add x y
   sub x y = liftContext2 A.sub x y
-- | Multiplication with the context's fast-math flags on the result.
instance
   (Flags flags, A.PseudoRing a, Tuple a) =>
      A.PseudoRing (Context flags a) where
   mul x y = liftContext2 A.mul x y
-- The scalar type of a flagged vector is the flagged scalar type.
type instance A.Scalar (Context flags a) = Context flags (A.Scalar a)

-- | Scaling with the context's fast-math flags on the result.
instance
   (Flags flags, A.PseudoModule v, Tuple v, A.Scalar v ~ a, Tuple a) =>
      A.PseudoModule (Context flags v) where
   scale k x = liftContext2 A.scale k x
-- | Integer constants of the wrapped type, tagged; no flags are attached.
instance
   (Flags flags, Tuple a, A.IntegerConstant a) =>
      A.IntegerConstant (Context flags a) where
   fromInteger' n = Context (A.fromInteger' n)

-- | Division with the context's fast-math flags on the result.
instance
   (Flags flags, Tuple v, A.Field v) =>
      A.Field (Context flags v) where
   fdiv x y = liftContext2 A.fdiv x y

-- | Rational constants of the wrapped type, tagged; no flags are attached.
instance
   (Flags flags, Tuple a, A.RationalConstant a) =>
      A.RationalConstant (Context flags a) where
   fromRational' q = Context (A.fromRational' q)
-- | Ordering-related operations with the context's fast-math flags
-- on the result.
instance
   (Flags flags, Tuple a, A.Real a) =>
      A.Real (Context flags a) where
   abs x = liftContext A.abs x
   signum x = liftContext A.signum x
   min x y = liftContext2 A.min x y
   max x y = liftContext2 A.max x y

-- | Integer/fractional part splitting with the context's fast-math flags
-- on the result.
instance
   (Flags flags, Tuple a, A.Fraction a) =>
      A.Fraction (Context flags a) where
   truncate x = liftContext A.truncate x
   fraction x = liftContext A.fraction x
-- | Comparison on the unwrapped values.  The result type is the wrapped
-- type's comparison result and no flags are attached.
instance
   (Flags flags, Tuple a, A.Comparison a) =>
      A.Comparison (Context flags a) where
   type CmpResult (Context flags a) = A.CmpResult a
   cmp predicate (Context x) (Context y) = A.cmp predicate x y

-- | Floating-point comparison on the unwrapped values; no flags are attached.
instance
   (Flags flags, Tuple a, A.FloatingComparison a) =>
      A.FloatingComparison (Context flags a) where
   fcmp predicate (Context x) (Context y) = A.fcmp predicate x y
-- | Square root with the context's fast-math flags on the result.
instance
   (Flags flags, Tuple a, A.Algebraic a) =>
      A.Algebraic (Context flags a) where
   sqrt x = liftContext A.sqrt x
-- | Transcendental functions with the context's fast-math flags on each
-- generated result.
--
-- 'pi' is obtained from the wrapped type's instance and only wrapped in
-- 'Context'.  This mirrors the 'MV.Transcendental' instance for 'Number'.
-- The previous definition @attachTupleFlags A.pi@ was wrong:
-- 'attachTupleFlags' has type @Id (CodeGenFunction r (Context flags a))@,
-- so that 'A.pi' resolved to this very method, and 'pi' looped forever.
instance
   (Flags flags, Tuple a, A.Transcendental a) =>
      A.Transcendental (Context flags a) where
   pi = fmap Context A.pi
   sin = liftContext A.sin
   cos = liftContext A.cos
   exp = liftContext A.exp
   log = liftContext A.log
   pow = liftContext2 A.pow
-- | Run a code generator yielding a 'Context', then enable the context's
-- fast-math flags on the wrapped value.  The context is returned unchanged.
attachTupleFlags ::
   (Flags flags, Tuple a) =>
   Id (LLVM.CodeGenFunction r (Context flags a))
attachTupleFlags act =
   act >>= \ctx@(Context x) ->
      setTupleFlags (proxyFromContext ctx) True x >> return ctx
-- | Lift a unary code generator to 'Context', attaching the fast-math
-- flags to its result.
liftContext ::
   (Flags flags, Tuple b) =>
   (a -> LLVM.CodeGenFunction r b) ->
   Context flags a -> LLVM.CodeGenFunction r (Context flags b)
liftContext op (Context x) = attachTupleFlags $ fmap Context $ op x
-- | Lift a binary code generator to 'Context' by partially applying it to
-- the first unwrapped operand and delegating to 'liftContext'.
liftContext2 ::
   (Flags flags, Tuple c) =>
   (a -> b -> LLVM.CodeGenFunction r c) ->
   Context flags a -> Context flags b ->
   LLVM.CodeGenFunction r (Context flags c)
liftContext2 op (Context x) y = liftContext (op x) y