{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Some special operations on X86 processors.
If you want to use them in algorithms,
you will always have to prepare an alternative implementation
in terms of plain LLVM instructions.
You then run them with 'Ext.run',
and this driver function selects the more advanced of the two implementations.
Functions that are written this way can be found in "LLVM.Extra.Vector".
Availability of extensions is checked with the @CPUID@ instruction.
However, this only works if you compile code for the host machine;
that is, cross compilation will fail!
For cross compilation we would need access to the SubTarget detection of LLVM,
which in version 2.6 is only available in the C++ interface.
-}
module LLVM.Extra.Extension.X86 (
   X86.maxss, X86.minss, X86.maxps, X86.minps,
   X86.maxsd, X86.minsd, X86.maxpd, X86.minpd,
   cmpss, cmpps, cmpsd, cmppd, cmpps256, cmppd256,
   pcmpgtb,  pcmpgtw,  pcmpgtd,  pcmpgtq,
   pcmpugtb, pcmpugtw, pcmpugtd, pcmpugtq,
   pminsb, pminsw, pminsd,
   pmaxsb, pmaxsw, pmaxsd,
   pminub, pminuw, pminud,
   pmaxub, pmaxuw, pmaxud,
   pabsb, pabsw, pabsd,
   pmuludq, pmuldq,
   pmulld,
   cvtps2dq, cvtpd2dq,
   cvtdq2ps, cvtdq2pd,
   ldmxcsr, stmxcsr, withMXCSR,
   X86.haddps, X86.haddpd, X86.dpps, X86.dppd,
   roundss, X86.roundps, roundsd, X86.roundpd,
   absss, abssd, absps, abspd,
   X86.psllw128, X86.pslld128, X86.psllq128,
   X86.psrlw128, X86.psrld128, X86.psrlq128,
   X86.psraw128, X86.psrad128,
   X86.paddsb128, X86.paddsw128, X86.paddusb128, X86.paddusw128,
   X86.psubsb128, X86.psubsw128, X86.psubusb128, X86.psubusw128,
   ) where

import qualified LLVM.Extra.Extension.X86Auto as X86
import qualified LLVM.Extra.Extension as Ext
import LLVM.Extra.Extension.X86Auto (
          V2Double, V4Float,
          V2Int64, V2Word64,
          V4Int32, V4Word32,
          V8Int16, V8Word16,
          V16Int8, V16Word8,
          )
import LLVM.Extra.ExtensionCheck.X86
          (sse1, sse2, sse41, sse42, )

import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, Vector, valueOf, constOf, vector,
    CodeGenFunction, FPPredicate, )

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Empty as Empty
import Data.NonEmpty ((!:), )

import Data.Bits (clearBit, complement, )
import Data.Word (Word8, Word32, Word64, )

import qualified Control.Monad.HT as M
import Control.Monad.HT ((<=<), )
import Control.Applicative (pure, )

import Foreign.Ptr (Ptr, )

import Prelude2010
import Prelude ()


{- |
Implement an arbitrary LLVM floating-point predicate
on top of an X86 comparison intrinsic that takes an 8-bit immediate
selecting the comparison (the encoding of the @CMPPS@ family:
0 = EQ, 1 = LT, 2 = LE, 3 = UNORD, 4 = NEQ, 5 = NLT, 6 = NLE, 7 = ORD).

Greater-than style predicates are obtained by swapping the operands,
and 'LLVM.FPONE' / 'LLVM.FPUEQ' combine two comparisons.
'LLVM.FPFalse' and 'LLVM.FPT' never call the intrinsic;
they return an all-zero and an all-@(-1)@ vector, respectively.
The raw intrinsic result is bitcast to an integer vector of the same size.
-}
switchFPPred ::
   (Num i, LLVM.IsConst i, LLVM.IsInteger i, LLVM.IsPrimitive i,
    LLVM.IsFirstClass v,
    TypeNum.Positive n,
    LLVM.IsSized v, LLVM.IsSized (Vector n i),
    LLVM.SizeOf v ~ LLVM.SizeOf (Vector n i)) =>
   (Value v -> Value v -> Value Word8 -> CodeGenFunction r (Value v)) ->
   FPPredicate -> Value v -> Value v -> CodeGenFunction r (Value (Vector n i))
switchFPPred cmp predicate a b =
   case predicate of
      LLVM.FPFalse -> return (LLVM.value LLVM.zero)
      LLVM.FPT     -> return (valueOf $ pure (-1))
      -- ordered predicates
      LLVM.FPOEQ   -> direct 0
      LLVM.FPOLT   -> direct 1
      LLVM.FPOLE   -> direct 2
      LLVM.FPOGT   -> swapped 1
      LLVM.FPOGE   -> swapped 2
      LLVM.FPORD   -> direct 7
      LLVM.FPONE   -> M.liftJoin2 A.and (direct 7) (direct 4)
      -- unordered predicates
      LLVM.FPUNO   -> direct 3
      LLVM.FPUEQ   -> M.liftJoin2 A.or (direct 3) (direct 0)
      LLVM.FPUNE   -> direct 4
      LLVM.FPUGT   -> direct 6
      LLVM.FPUGE   -> direct 5
      LLVM.FPULT   -> swapped 6
      LLVM.FPULE   -> swapped 5
   where
      -- call the intrinsic with the given immediate and reinterpret the mask
      run imm u w = LLVM.bitcast =<< cmp u w (valueOf imm)
      direct imm = run imm a b
      swapped imm = run imm b a

-- | 'X86.cmpss' extended to every 'FPPredicate' by 'switchFPPred'.
-- NOTE(review): being the scalar (@ss@) variant,
-- presumably only lane 0 of the result is meaningful — confirm.
cmpss :: Ext.T (FPPredicate -> V4Float -> V4Float -> CodeGenFunction r V4Int32)
cmpss = fmap switchFPPred X86.cmpss

-- | 'X86.cmpps' (packed single precision) extended to every 'FPPredicate'
-- by 'switchFPPred'.
cmpps :: Ext.T (FPPredicate -> V4Float -> V4Float -> CodeGenFunction r V4Int32)
cmpps = fmap switchFPPred X86.cmpps

-- | 'X86.cmpsd' extended to every 'FPPredicate' by 'switchFPPred'.
-- NOTE(review): being the scalar (@sd@) variant,
-- presumably only lane 0 of the result is meaningful — confirm.
cmpsd :: Ext.T (FPPredicate -> V2Double -> V2Double -> CodeGenFunction r V2Int64)
cmpsd = fmap switchFPPred X86.cmpsd

-- | 'X86.cmppd' (packed double precision) extended to every 'FPPredicate'
-- by 'switchFPPred'.
cmppd :: Ext.T (FPPredicate -> V2Double -> V2Double -> CodeGenFunction r V2Int64)
cmppd = fmap switchFPPred X86.cmppd

-- | 256-bit packed single-precision comparison ('X86.cmpps256')
-- extended to every 'FPPredicate' by 'switchFPPred'.
cmpps256 :: Ext.T (FPPredicate -> X86.V8Float -> X86.V8Float -> CodeGenFunction r X86.V8Int32)
cmpps256 = fmap switchFPPred X86.cmpps256

-- | 256-bit packed double-precision comparison ('X86.cmppd256')
-- extended to every 'FPPredicate' by 'switchFPPred'.
cmppd256 :: Ext.T (FPPredicate -> X86.V4Double -> X86.V4Double -> CodeGenFunction r X86.V4Int64)
cmppd256 = fmap switchFPPred X86.cmppd256


-- | Lane-wise signed greater-than of 16 8-bit integers
-- using the @pcmpgt.b@ intrinsic, guarded by the SSE2 check.
-- Presumably each result lane is all ones where @x > y@ and zero otherwise,
-- as for the x86 PCMPGTB instruction.
pcmpgtb :: Ext.T (V16Int8 -> V16Int8 -> CodeGenFunction r V16Int8)
pcmpgtb = Ext.intrinsic sse2 "pcmpgt.b"

-- | Lane-wise signed greater-than of 8 16-bit integers
-- using the @pcmpgt.w@ intrinsic, guarded by the SSE2 check.
pcmpgtw :: Ext.T (V8Int16 -> V8Int16 -> CodeGenFunction r V8Int16)
pcmpgtw = Ext.intrinsic sse2 "pcmpgt.w"

-- | Lane-wise signed greater-than of 4 32-bit integers
-- using the @pcmpgt.d@ intrinsic, guarded by the SSE2 check.
pcmpgtd :: Ext.T (V4Int32 -> V4Int32 -> CodeGenFunction r V4Int32)
pcmpgtd = Ext.intrinsic sse2 "pcmpgt.d"

-- | Lane-wise signed greater-than of 2 64-bit integers
-- using the @pcmpgtq@ intrinsic.
-- Unlike the narrower variants this one is guarded by the SSE4.2 check.
pcmpgtq :: Ext.T (V2Int64 -> V2Int64 -> CodeGenFunction r V2Int64)
pcmpgtq = Ext.intrinsic sse42 "pcmpgtq"


{- |
Derive an unsigned greater-than comparison from a signed one.

Both operands are biased by the constant with only the most significant bit set
(@1 + maxBound \`div\` 2@ of the unsigned element type).
Modulo the word size this toggles the top bit,
which maps unsigned order onto signed order.
The biased operands are reinterpreted as signed vectors,
compared with the given signed comparison,
and the resulting mask is reinterpreted as the unsigned vector type.
-}
pcmpuFromPcmp ::
   (TypeNum.Positive n,
    LLVM.IsPrimitive s,
    LLVM.IsPrimitive u, LLVM.IsArithmetic u, LLVM.IsConst u,
    Bounded u, Integral u,
    LLVM.IsSized (Vector n s), LLVM.IsSized (Vector n u),
    LLVM.SizeOf (Vector n s) ~ LLVM.SizeOf (Vector n u)) =>
   Ext.T (Value (Vector n s) -> Value (Vector n s) -> CodeGenFunction r (Value (Vector n s))) ->
   Ext.T (Value (Vector n u) -> Value (Vector n u) -> CodeGenFunction r (Value (Vector n u)))
pcmpuFromPcmp pcmp =
   Ext.with pcmp $ \signedGreater a b -> do
      let bias = valueOf $ pure (1 + div maxBound 2)
      aSigned <- A.sub a bias >>= LLVM.bitcast
      bSigned <- A.sub b bias >>= LLVM.bitcast
      signedGreater aSigned bSigned >>= LLVM.bitcast

-- | Lane-wise unsigned greater-than of 16 8-bit words,
-- derived from the signed 'pcmpgtb' by 'pcmpuFromPcmp'.
pcmpugtb :: Ext.T (V16Word8 -> V16Word8 -> CodeGenFunction r V16Word8)
pcmpugtb = pcmpuFromPcmp pcmpgtb

-- | Lane-wise unsigned greater-than of 8 16-bit words,
-- derived from the signed 'pcmpgtw' by 'pcmpuFromPcmp'.
pcmpugtw :: Ext.T (V8Word16 -> V8Word16 -> CodeGenFunction r V8Word16)
pcmpugtw = pcmpuFromPcmp pcmpgtw

-- | Lane-wise unsigned greater-than of 4 32-bit words,
-- derived from the signed 'pcmpgtd' by 'pcmpuFromPcmp'.
pcmpugtd :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V4Word32)
pcmpugtd = pcmpuFromPcmp pcmpgtd

-- | Lane-wise unsigned greater-than of 2 64-bit words,
-- derived from the signed 'pcmpgtq' by 'pcmpuFromPcmp'
-- (and thus subject to the same SSE4.2 requirement).
pcmpugtq :: Ext.T (V2Word64 -> V2Word64 -> CodeGenFunction r V2Word64)
pcmpugtq = pcmpuFromPcmp pcmpgtq


-- | Lane-wise signed minimum and maximum of 16 8-bit integers,
-- re-exported from "LLVM.Extra.Extension.X86Auto".
pminsb, pmaxsb :: Ext.T (V16Int8 -> V16Int8 -> CodeGenFunction r V16Int8)
pminsb = X86.pminsb128
pmaxsb = X86.pmaxsb128

-- | Lane-wise signed minimum and maximum of 8 16-bit integers.
pminsw, pmaxsw :: Ext.T (V8Int16 -> V8Int16 -> CodeGenFunction r V8Int16)
pminsw = X86.pminsw128
pmaxsw = X86.pmaxsw128

-- | Lane-wise signed minimum and maximum of 4 32-bit integers.
pminsd, pmaxsd :: Ext.T (V4Int32 -> V4Int32 -> CodeGenFunction r V4Int32)
pminsd = X86.pminsd128
pmaxsd = X86.pmaxsd128


-- | Lane-wise unsigned minimum and maximum of 16 8-bit words.
pminub, pmaxub :: Ext.T (V16Word8 -> V16Word8 -> CodeGenFunction r V16Word8)
pminub = X86.pminub128
pmaxub = X86.pmaxub128

-- | Lane-wise unsigned minimum and maximum of 8 16-bit words.
pminuw, pmaxuw :: Ext.T (V8Word16 -> V8Word16 -> CodeGenFunction r V8Word16)
pminuw = X86.pminuw128
pmaxuw = X86.pmaxuw128

-- | Lane-wise unsigned minimum and maximum of 4 32-bit words.
pminud, pmaxud :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V4Word32)
pminud = X86.pminud128
pmaxud = X86.pmaxud128


-- | Lane-wise absolute value of 16 signed 8-bit integers
-- (re-exported 'X86.pabsb128').
pabsb :: Ext.T (V16Int8 -> CodeGenFunction r V16Int8)
pabsb = X86.pabsb128

-- | Lane-wise absolute value of 8 signed 16-bit integers
-- (re-exported 'X86.pabsw128').
pabsw :: Ext.T (V8Int16 -> CodeGenFunction r V8Int16)
pabsw = X86.pabsw128

-- | Lane-wise absolute value of 4 signed 32-bit integers
-- (re-exported 'X86.pabsd128').
pabsd :: Ext.T (V4Int32 -> CodeGenFunction r V4Int32)
pabsd = X86.pabsd128


-- | Widening unsigned multiplication: four 32-bit lanes in, two 64-bit products out.
-- NOTE(review): presumably lanes 0 and 2 of each operand are multiplied,
-- as for the x86 PMULUDQ instruction — confirm.
pmuludq :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V2Word64)
pmuludq = X86.pmuludq128

-- | Widening signed multiplication: four 32-bit lanes in, two 64-bit products out.
-- NOTE(review): presumably lanes 0 and 2 of each operand are used, as for PMULDQ.
pmuldq :: Ext.T (V4Int32 -> V4Int32 -> CodeGenFunction r V2Int64)
pmuldq = X86.pmuldq128

-- | Lane-wise 32-bit multiplication keeping the low 32 bits of each product.
-- Implemented with the generic 'LLVM.mul', merely guarded by the SSE4.1 check;
-- presumably the code generator selects PMULLD for it when SSE4.1 is enabled.
pmulld :: Ext.T (V4Word32 -> V4Word32 -> CodeGenFunction r V4Word32)
pmulld = Ext.wrap sse41 LLVM.mul
-- former variant calling the intrinsic directly:
-- pmulld = Ext.intrinsic sse41 "pmulld"


-- | Convert four floats to four 32-bit integers (re-exported 'X86.cvtps2dq').
-- NOTE(review): presumably rounds according to the current MXCSR mode,
-- which can be changed with 'withMXCSR' — confirm.
cvtps2dq :: Ext.T (V4Float -> CodeGenFunction r V4Int32)
cvtps2dq = X86.cvtps2dq

-- | Convert two doubles to 32-bit integers.
-- The upper two integers are set to zero;
-- there is no instruction that converts to Int64.
cvtpd2dq :: Ext.T (V2Double -> CodeGenFunction r V4Int32)
cvtpd2dq = X86.cvtpd2dq


-- | Convert four 32-bit integers to four floats (re-exported 'X86.cvtdq2ps').
cvtdq2ps :: Ext.T (V4Int32 -> CodeGenFunction r V4Float)
cvtdq2ps = X86.cvtdq2ps

-- | Convert the lower two 32-bit integers to two doubles.
-- The upper two integers are ignored;
-- there is no instruction that converts from Int64.
cvtdq2pd :: Ext.T (V4Int32 -> CodeGenFunction r V2Double)
cvtdq2pd = X86.cvtdq2pd


-- | Discard a unit-typed LLVM value without inspecting it.
valueUnit :: Value () -> ()
valueUnit = const ()

{- |
MXCSR is not really supported by LLVM-2.6.
LLVM does not know about the dependency of all floating point operations
on this status register.

'ldmxcsr' loads the MXCSR register from the 32-bit word behind the pointer.
It calls the @ldmxcsr@ intrinsic, guarded by the SSE1 check,
without extra function attributes, and drops the unit result value.
-}
ldmxcsr :: Ext.T (Value (Ptr Word32) -> CodeGenFunction r ())
ldmxcsr =
   fmap (\call ptr -> fmap valueUnit (call ptr)) $
   Ext.intrinsicAttr [] sse1 "ldmxcsr"

{- |
Store the MXCSR register into the 32-bit word behind the pointer.
This is the counterpart of 'ldmxcsr' and uses the @stmxcsr@ intrinsic the same way.
-}
stmxcsr :: Ext.T (Value (Ptr Word32) -> CodeGenFunction r ())
stmxcsr =
   fmap (\call ptr -> fmap valueUnit (call ptr)) $
   Ext.intrinsicAttr [] sse1 "stmxcsr"

{- |
Generate code that runs with MXCSR temporarily set to the given value.

The current MXCSR contents are first saved to a stack slot via 'stmxcsr'.
Then the new value is written to a second stack slot and loaded via 'ldmxcsr'.
After the wrapped code is generated, the saved value is loaded back.
The result of the wrapped code is passed through.
-}
withMXCSR :: Word32 -> Ext.T (CodeGenFunction r a -> CodeGenFunction r a)
withMXCSR mxcsr =
   Ext.with2 ldmxcsr stmxcsr $ \ loadCSR storeCSR body -> do
      savedSlot <- LLVM.alloca
      storeCSR savedSlot
      -- A global constant (LLVM.createGlobal) would avoid this stack slot,
      -- but createGlobal lives in the CodeGenModule monad,
      -- not in CodeGenFunction.
      newSlot <- LLVM.alloca
      LLVM.store (valueOf mxcsr) newSlot
      loadCSR newSlot
      result <- body
      loadCSR savedSlot
      return result

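{-
The scalar and packed double-precision min/max operations
are now taken from "LLVM.Extra.Extension.X86Auto" (exported as X86.maxsd etc.).
An earlier hand-written variant was:

[maxsd, minsd, maxpd, minpd] =
   map (Ext.intrinsic sse2)
     ["max.sd", "min.sd", "max.pd", "min.pd"]
-}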

-- | Scalar single-precision rounding based on 'X86.roundss',
-- with an undefined value fixed as the intrinsic's first operand.
-- The remaining arguments are the operand to round and the rounding-mode word.
-- NOTE(review): if X86.roundss follows llvm.x86.sse41.round.ss,
-- that first operand supplies lanes 1..3, so those lanes are undefined — confirm.
roundss :: Ext.T (V4Float -> Value Word32 -> CodeGenFunction r V4Float)
roundss = fmap ($ LLVM.value LLVM.undef) X86.roundss

-- | Scalar double-precision rounding based on 'X86.roundsd',
-- with an undefined value fixed as the intrinsic's first operand
-- (see the note on 'roundss' about the resulting upper lane).
roundsd :: Ext.T (V2Double -> Value Word32 -> CodeGenFunction r V2Double)
roundsd = fmap ($ LLVM.value LLVM.undef) X86.roundsd



{-
Not an LLVM intrinsic but implementation specific:
We expect that floating point values are in IEEE format
and thus the most significant bit is the sign.
The absolute value can be computed very efficiently by clearing the sign bit.
Actually, LLVM's codegen implements neg by an XOR on the sign bit.
-}
-- | Absolute value of lane 0 of a float vector, computed by clearing bit 31.
-- The other three lanes are ANDed with all ones and so are left unchanged.
-- Guarded by the SSE1 check.
absss :: Ext.T (V4Float -> CodeGenFunction r V4Float)
absss =
   Ext.wrap sse1 $ \x -> do
      bits <- LLVM.bitcast x
      masked <- A.and signMask bits
      LLVM.bitcast masked
   where
      signMask :: V4Word32
      signMask =
         LLVM.valueOf $ vector $
            clearBit (complement 0) 31 !: NonEmptyC.repeat (complement 0)

{-
This function works on a single Float,
but I like to do the masking in an XMM register
because usually the value is there anyway.

absss =
   flip LLVM.extractelement (valueOf 0)
     . flip asTypeOf (undefined :: V4Float)
     <=< LLVM.bitcast
--        <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF] :: V4Word32)
--        <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF, LLVM.undef, LLVM.undef, LLVM.undef] :: V4Word32)
     <=< A.and (LLVM.value $ constVector [constOf 0x7FFFFFFF, LLVM.zero, LLVM.zero, LLVM.zero] :: V4Word32)
     <=< LLVM.bitcast
     . flip asTypeOf (undefined :: V4Float)
     <=< flip (LLVM.insertelement (LLVM.value LLVM.undef)) (valueOf 0)
-}
{- This moves the value to a general purpose register and performs the bit masking there
absss =
   LLVM.bitcast
     <=< A.and (valueOf 0x7FFFFFFF :: Value Word32)
     <=< LLVM.bitcast
-}

-- | Absolute value of lane 0 of a double vector, computed by clearing bit 63.
-- Lane 1 is ANDed with all ones and so is left unchanged.
-- Guarded by the SSE2 check.
abssd :: Ext.T (V2Double -> CodeGenFunction r V2Double)
abssd =
   Ext.wrap sse2 $ \x -> do
      bits <- LLVM.bitcast x
      LLVM.bitcast =<< A.and lowLaneMask bits
   where
      lowLaneMask :: V2Word64
      lowLaneMask =
         LLVM.valueOf $ vector $
            clearBit (complement 0) 63 !: complement 0 !: Empty.Cons


-- | Bitwise AND of every lane of an integer vector with the same constant.
mask ::
   (TypeNum.Positive n, LLVM.IsConst w, LLVM.IsPrimitive w, LLVM.IsInteger w) =>
   w -> Value (Vector n w) -> CodeGenFunction r (Value (Vector n w))
mask = A.and . LLVM.valueOf . pure

-- | Lane-wise absolute value of a float vector of any positive length,
-- computed by clearing bit 31 of every lane.
-- Guarded by the SSE1 check.
absps ::
   (TypeNum.Positive n) =>
   Ext.T (Value (Vector n Float) -> CodeGenFunction r (Value (Vector n Float)))
absps =
   Ext.wrap sse1 $ \x -> do
      bits <- LLVM.bitcastElements x
      cleared <- mask (clearBit (complement 0) 31 :: Word32) bits
      LLVM.bitcastElements cleared

-- | Lane-wise absolute value of a double vector of any positive length,
-- computed by clearing bit 63 of every lane.
-- Guarded by the SSE2 check.
abspd ::
   (TypeNum.Positive n) =>
   Ext.T (Value (Vector n Double) -> CodeGenFunction r (Value (Vector n Double)))
abspd =
   Ext.wrap sse2 $ \x -> do
      bits <- LLVM.bitcastElements x
      cleared <- mask (clearBit (complement 0) 63 :: Word64) bits
      LLVM.bitcastElements cleared

{- |
Cumulative sum of four floats:
@(a,b,c,d) -> (a,a+b,a+b+c,a+b+c+d)@

This is an attempt to make clever use of the horizontal add 'X86.haddps',
but the generic version in the Vector module is better.
-}
_cumulate1s :: Ext.T (V4Float -> CodeGenFunction r V4Float)
_cumulate1s =
   Ext.with X86.haddps $ \hadd v -> do
      -- (a+b, c+d, _, _)
      pairSums <- hadd v (LLVM.value LLVM.undef)
      -- (a, a+b, c, c+d)
      partial <-
         LLVM.shufflevector v pairSums $
            constOf $ vector $ 0!:4!:2!:5!:Empty.Cons
      -- (0, 0, a+b, a+b)
      carry <-
         LLVM.shufflevector pairSums (LLVM.value LLVM.zero) $
            constOf $ vector $ 4!:5!:0!:0!:Empty.Cons
      A.add partial carry