{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.Scalar where

import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Util.Loop as Loop
import LLVM.Util.Loop (Phi, )

import qualified Control.Monad as Monad


{- |
Wrapper that tags a value of type @a@ as a scalar,
even if @a@ could also be read as a vector type.
Generic vector code can be written against the 'A.PseudoModule' class
and then be specialised to scalar types by means of the 'A.PseudoRing' class.
Seen from a different angle,
the 'T' type constructor is the marker
at which the 'A.Scalar' type function
stops reducing nested vector types to scalar types.

'Cons' wraps a value and 'decons' retrieves it again.
-}
newtype T a =
   Cons {
      decons :: a
   }

{- |
Run a monadic function on the value stored in a 'T'
and wrap the function's result in 'Cons' again.
-}
liftM :: (Monad m) => (a -> m b) -> T a -> m (T b)
liftM f x = do
   y <- f (decons x)
   return (Cons y)

{- |
Two-argument variant of 'liftM':
unwrap both arguments, run the monadic function on them
and wrap the result in 'Cons'.
-}
liftM2 :: (Monad m) => (a -> b -> m c) -> T a -> T b -> m (T c)
liftM2 f x y =
   f (decons x) (decons y) >>= return . Cons


{- |
Turn a monadic function on wrapped values
into one on plain values:
the argument is wrapped with 'Cons' before the call
and the result is unwrapped with 'decons' afterwards.
-}
unliftM :: (Monad m) => (T a -> m (T r)) -> a -> m r
unliftM f =
   Monad.liftM decons . f . Cons

{- |
Like 'unliftM' for functions of two arguments:
both arguments are wrapped with 'Cons',
the result is unwrapped with 'decons'.
-}
unliftM2 :: (Monad m) => (T a -> T b -> m (T r)) -> a -> b -> m r
unliftM2 f a b = do
   r <- f (Cons a) (Cons b)
   return (decons r)

{- |
Like 'unliftM' for functions of three arguments:
every argument is wrapped with 'Cons',
the result is unwrapped with 'decons'.
-}
unliftM3 ::
   (Monad m) =>
   (T a -> T b -> T c -> m (T r)) ->
   (a -> b -> c -> m r)
unliftM3 f a b c =
   f (Cons a) (Cons b) (Cons c) >>= return . decons

{- |
Like 'unliftM' for functions of four arguments:
every argument is wrapped with 'Cons',
the result is unwrapped with 'decons'.
-}
unliftM4 ::
   (Monad m) =>
   (T a -> T b -> T c -> T d -> m (T r)) ->
   (a -> b -> c -> d -> m r)
unliftM4 f a b c d =
   decons `Monad.liftM` f (Cons a) (Cons b) (Cons c) (Cons d)

{- |
Like 'unliftM' for functions of five arguments:
every argument is wrapped with 'Cons',
the result is unwrapped with 'decons'.
-}
unliftM5 ::
   (Monad m) =>
   (T a -> T b -> T c -> T d -> T e -> m (T r)) ->
   (a -> b -> c -> d -> e -> m r)
unliftM5 f a b c d e = do
   r <- f (Cons a) (Cons b) (Cons c) (Cons d) (Cons e)
   return (decons r)


{- | The zero tuple of a scalar is the wrapped zero tuple of the underlying type. -}
instance (Class.Zero a) => Class.Zero (T a) where
   zeroTuple =
      Cons Class.zeroTuple

{- | The undefined tuple of a scalar is the wrapped undefined tuple of the underlying type. -}
instance (Class.Undefined a) => Class.Undefined (T a) where
   undefTuple =
      Cons Class.undefTuple

{- |
Both methods delegate to the 'Phi' instance of the underlying type;
'phis' wraps the returned value in 'Cons' again.
-}
instance (Phi a) => Phi (T a) where
   phis bb (Cons a) = fmap Cons (Loop.phis bb a)
   addPhis bb x y = Loop.addPhis bb (decons x) (decons y)

{- | Integer constants are built in the underlying type and then wrapped. -}
instance (A.IntegerConstant a) => A.IntegerConstant (T a) where
   fromInteger' n = Cons (A.fromInteger' n)

{- | Rational constants are built in the underlying type and then wrapped. -}
instance (A.RationalConstant a) => A.RationalConstant (T a) where
   fromRational' q = Cons (A.fromRational' q)

{- |
Additive structure is inherited from the underlying type:
every method unwraps its arguments, calls the underlying method
and wraps the result.
-}
instance (A.Additive a) => A.Additive (T a) where
   zero = Cons A.zero
   add x y = liftM2 A.add x y
   sub x y = liftM2 A.sub x y
   neg x = liftM A.neg x

{- | Multiplication is forwarded to the underlying type. -}
instance (A.PseudoRing a) => A.PseudoRing (T a) where
   mul x y = liftM2 A.mul x y

{- | Division is forwarded to the underlying type. -}
instance (A.Field a) => A.Field (T a) where
   fdiv x y = liftM2 A.fdiv x y

{- | A wrapped type is its own scalar type. -}
type instance A.Scalar (T a) = T a

{- |
Since the scalar type of @T a@ is @T a@ itself,
scaling is the multiplication of the underlying type.
-}
instance (A.PseudoRing a) => A.PseudoModule (T a) where
   scale s x = liftM2 A.mul s x


{- | All methods are forwarded to the 'A.Real' instance of the underlying type. -}
instance (A.Real a) => A.Real (T a) where
   min x y = liftM2 A.min x y
   max x y = liftM2 A.max x y
   abs x = liftM A.abs x
   signum x = liftM A.signum x

{- | Both methods are forwarded to the 'A.Fraction' instance of the underlying type. -}
instance (A.Fraction a) => A.Fraction (T a) where
   truncate x = liftM A.truncate x
   fraction x = liftM A.fraction x

{- | The square root is forwarded to the underlying type. -}
instance (A.Algebraic a) => A.Algebraic (T a) where
   sqrt x = liftM A.sqrt x

{- |
All methods are forwarded to the 'A.Transcendental' instance
of the underlying type; results are wrapped in 'Cons'.
-}
instance (A.Transcendental a) => A.Transcendental (T a) where
   pi = Cons `fmap` A.pi
   sin x = liftM A.sin x
   cos x = liftM A.cos x
   exp x = liftM A.exp x
   log x = liftM A.log x
   pow x y = liftM2 A.pow x y