{-# LANGUAGE CPP               #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Marshal a limited subset of J arrays into Repa arrays.
--
-- = Tutorial
--
-- Suppose we wish to perform linear regression. In J we could do:
--
-- @
-- xs =: 1 2 3
-- ys =: 2 4 6
--
-- reg_result =: ys %. xs ^/ i.2
-- @
--
-- To do this with Haskell data:
--
-- @
-- do
--    jenv <- 'jinit' 'libLinux'
--
--    let hsArr0 = R.fromListUnboxed (R.ix1 3) [1.0,2.0,3.0]
--        hsArr1 = R.fromListUnboxed (R.ix1 3) [2.0,4.0,6.0]
--        jArr0 = 'JDoubleArr' $ R.copyS $ R.map (realToFrac :: Double -> CDouble) hsArr0
--        jArr1 = 'JDoubleArr' $ R.copyS $ R.map (realToFrac :: Double -> CDouble) hsArr1
--
--    'setJData' jenv "xs" jArr0
--    'setJData' jenv "ys" jArr1
--
--    'bsDispatch' jenv "reg_result =: ys %. xs ^/ i.2"
--
--    'JDoubleArr' res <- 'getJData' jenv "reg_result"
--    pure (R.toList res)
-- @
--
-- There are three steps to do the calculation, plus one to get a J environment.
--
--     1. Use 'jinit' with the appropriate library path for your platform.
--
--     2. Marshal Haskell values and send them to the J environment. To do so, we
--     use 'setJData', which takes a 'JData' containing a repa array or
--     a string.
--
--     3. Perform calculations within the J environment. Here, we use
--     'bsDispatch' to compute some results and assign them within J.
--
--     4. Marshal J values back to Haskell. We use 'getJData'.
--
--
--  Since marshaling data between J and Haskell is expensive, it's best to do as
--  much computation as possible in J.
--
--  == Loading Profile
--
--  If you would like to use user libraries, you need to use 'jLoad' on the
--  'JEnv'. As an example:
--
--  @
--  do
--      jenv <- 'jinit' 'libLinux'
--      'jLoad' jenv ('linuxProfile' "9.01")
--      'bsDispatch' jenv "load'tables/csv'"
--  @
--
--  This will load the CSV addon, assuming it is installed.
--
--  = FFI
--
--  If you want to marshal data yourself, say to use a @Vector@, look at 'JEnv'.
module Language.J ( -- * Environment
                    JEnv (..)
                  , jinit
                  , jLoad
                  , Profile (..)
                  , linuxProfile
                  , macProfile
                  , windowsProfile
#ifndef mingw32_HOST_OS
                  , libLinux
                  , libMac
                  , profLinux
#else
                  , libWindows
#endif
                  , bsDispatch
                  , bsOut
                  , JVersion
                  -- * Repa
                  , JData (..)
                  , getJData
                  , setJData
                  -- * FFI
                  , J
                  , JDoType
                  , JGetMType
                  , JGetRType
                  , JSetAType
                  ) where

import           Control.Applicative             (pure, (<$>), (<*>))
import           Control.Exception               (bracket)
import qualified Data.Array.Repa                 as R
import qualified Data.Array.Repa.Repr.ForeignPtr as RF
import qualified Data.ByteString                 as BS
import qualified Data.ByteString.Char8           as ASCII
import qualified Data.ByteString.Internal        as BS
import           Data.Complex                    (Complex (..))
import           Data.Functor                    (void)
import           Data.Semigroup                  ((<>))
import           Foreign.C.String                (CString)
import           Foreign.C.Types                 (CChar, CDouble, CInt (..), CLLong (..))
import           Foreign.ForeignPtr              (ForeignPtr, castForeignPtr, mallocForeignPtrBytes, withForeignPtr)
import           Foreign.Marshal                 (alloca, copyArray, free, mallocBytes, peekArray, pokeArray)
import           Foreign.Ptr                     (FunPtr, Ptr, castPtrToFunPtr, plusPtr)
import           Foreign.Storable                (Storable, peek, pokeByteOff, sizeOf)
import           System.Info                     (arch)
#ifndef mingw32_HOST_OS
import           System.Posix.ByteString         (RTLDFlags (RTLD_LAZY), RawFilePath, dlopen, dlsym)
#else
import           System.Win32.DLL                (getProcAddress, loadLibrary)
#endif

-- Upstream reference
-- https://github.com/jsoftware/stats_jserver4r/blob/4c94fc6df351fab34791aa9d78d158eaefd33b17/source/lib/j2r.c
-- https://github.com/jsoftware/stats_jserver4r/blob/4c94fc6df351fab34791aa9d78d158eaefd33b17/source/lib/r2j.c

-- | Abstract context: the opaque J instance a 'JEnv' points at.
data J

-- | A J instance together with the entry points 'jinit' resolves from the
-- J shared library.
--
-- The fields are exported so data can be marshaled by hand.
data JEnv = JEnv
    { context   :: Ptr J      -- ^ Instance pointer returned by @JInit@
    , evaluator :: JDoType    -- ^ @JDo@, used by 'bsDispatch'
    , reader    :: JGetMType  -- ^ @JGetM@, used by 'getJData'
    , out       :: JGetRType  -- ^ @JGetR@, used by 'bsOut'
    , setter    :: JSetAType  -- ^ @JSetA@, used by 'setJData'
    }

-- | @JDo@: run a J sentence in the given instance, returning a status code.
type JDoType = Ptr J -> CString -> IO CInt

-- | @JGetM@: look up a noun by name. The four out-parameters receive, in
-- order, the type code, the rank, a pointer to the shape and a pointer to
-- the data.
type JGetMType
    =  Ptr J
    -> CString
    -> Ptr CLLong         -- type code
    -> Ptr CLLong         -- rank
    -> Ptr (Ptr CLLong)   -- shape
    -> Ptr (Ptr CChar)    -- data
    -> IO CInt

-- | @JGetR@: fetch the last output as a C string.
type JGetRType = Ptr J -> IO CString

-- | @JSetA@: bind a name to a serialized noun. After the instance come the
-- name length, the name, the buffer length and the buffer.
type JSetAType = Ptr J -> CLLong -> CString -> CLLong -> Ptr () -> IO CInt

-- "dynamic" imports: turn the function pointers 'jinit' looks up in the J
-- library into callable Haskell functions.

foreign import ccall "dynamic"
    mkJInit :: FunPtr (IO (Ptr J)) -> IO (Ptr J)

foreign import ccall "dynamic"
    mkJDo :: FunPtr JDoType -> JDoType

foreign import ccall "dynamic"
    mkJGetM :: FunPtr JGetMType -> JGetMType

foreign import ccall "dynamic"
    mkJGetR :: FunPtr JGetRType -> JGetRType

foreign import ccall "dynamic"
    mkJSetA :: FunPtr JSetAType -> JSetAType

-- | A J version as a list of numeric components, e.g. @[9, 0, 1]@.
type JVersion = [Int]

-- | Concatenate the components of a 'JVersion' with no separator, so
-- @[9, 0, 1]@ becomes @901@.
squashVersion :: JVersion -> String
squashVersion = foldr (\component rest -> show component ++ rest) ""

-- | 'squashVersion' as an ASCII 'BS.ByteString'.
squashVersionBS :: JVersion -> BS.ByteString
squashVersionBS v = ASCII.pack (squashVersion v)

#ifndef mingw32_HOST_OS
-- | Expected 'RawFilePath' to the library on a Linux machine: @libj.so@ in
-- the @-linux-gnu@ library directory named after the host architecture
-- ('arch').
libLinux :: RawFilePath
libLinux = mconcat ["/usr/lib/", ASCII.pack arch, "-linux-gnu/libj.so"]

-- | Expected 'RawFilePath' to the library on Mac: @libj.dylib@ in the
-- @bin@ directory of the versioned @j64-@ install under Applications.
libMac :: JVersion -> RawFilePath
libMac v = mconcat ["/Applications/j64-", squashVersionBS v, "/bin/libj.dylib"]
#else
-- | Expected path to @j.dll@ for a default Windows install of the given
-- J version.
--
-- @since 0.1.1.0
libWindows :: JVersion -> FilePath
libWindows v = concat ["C:\\Program Files\\J", squashVersion v, "\\bin\\j.dll"]
#endif

-- | Location of @profile.ijs@ under @\/etc\/j@ for the given version string.
profLinux :: BS.ByteString -> BS.ByteString
profLinux ver = mconcat ["/etc/j/", ver, "/profile.ijs"]

-- | Binary directory used by 'linuxProfile'.
binpathLinux :: BS.ByteString
binpathLinux = ASCII.pack "/usr/bin"

-- | Versioned shared-library file name used by 'linuxProfile'; e.g.
-- @9.01@ gives @libj.so.9.01@.
dllLinux :: BS.ByteString -> BS.ByteString
dllLinux = ("libj.so." <>)

-- | Profile for a Linux install of J, built from 'profLinux',
-- 'binpathLinux' and 'dllLinux'.
--
-- @since 0.1.2.0
linuxProfile :: BS.ByteString -- ^ J version, e.g. @"9.01"@
             -> Profile
linuxProfile ver = Profile
    { profPath = profLinux ver
    , binPath  = binpathLinux
    , dllName  = dllLinux ver
    }

-- | Profile for a default Mac install. @profile.ijs@ and @libj.dylib@ are
-- both expected in the @bin@ directory of the versioned @j64-@ install
-- under Applications.
--
-- @since 0.1.2.0
macProfile :: JVersion
           -> Profile
macProfile v = Profile
    { profPath = binDir <> "/profile.ijs"
    , binPath  = binDir
    , dllName  = binDir <> "/libj.dylib"
    }
  where
    binDir :: BS.ByteString
    binDir = mconcat ["/Applications/j64-", squashVersionBS v, "/bin"]

-- | Profile for a default Windows install. @profile.ijs@ and @j.dll@ are
-- both expected in the install's @bin@ directory (the same DLL location
-- 'libWindows' uses).
--
-- @since 0.1.2.0
windowsProfile :: JVersion
               -> Profile
windowsProfile v =
    Profile (binPathWindows <> "\\profile.ijs")
            binPathWindows
            -- The separator is required: binPathWindows ends in "\\bin".
            (binPathWindows <> "\\j.dll")
  where
    binPathWindows :: BS.ByteString
    binPathWindows = "C:\\Program Files\\J" <> squashVersionBS v <> "\\bin"

-- | Paths handed to 'jLoad' when loading a user profile.
--
-- @since 0.1.2.0
data Profile = Profile
    { profPath :: BS.ByteString -- ^ @profile.ijs@
    , binPath  :: BS.ByteString -- ^ assigned to @BINPATH_z_@ by 'jLoad'
    , dllName  :: BS.ByteString -- ^ assigned to @LIBFILE_z_@ by 'jLoad'
    }

-- | Load user profile.
--
-- Sets @ARGV_z_@ to empty, @LIBFILE_z_@ to the profile's 'dllName' and
-- @BINPATH_z_@ to its 'binPath', then runs the 'profPath' script with
-- @0!:0@.
--
-- Single quotes in the paths are doubled, so the J string literals they are
-- spliced into stay well-formed.
--
-- @since 0.1.2.0
jLoad :: JEnv
      -> Profile
      -> IO ()
jLoad jenv (Profile fp bin dll) =
    bsDispatch jenv $
           "(3 : '0!:0 y')<'" <> quote fp
        <> "'[BINPATH_z_=:'" <> quote bin
        <> "'[LIBFILE_z_=:'" <> quote dll
        <> "'[ARGV_z_=:''"
  where
    -- Inside a J string literal, a quote is written as two quotes.
    quote :: BS.ByteString -> BS.ByteString
    quote = ASCII.intercalate "''" . ASCII.split '\''

-- | Get a J environment
--
-- Passing the resultant 'JEnv' between threads can cause unexpected bugs.
#ifndef mingw32_HOST_OS
jinit :: RawFilePath -- ^ Path to J library
      -> IO JEnv
jinit libFp = do
    -- Open the library lazily; every entry point is resolved via dlsym.
    libj <- dlopen libFp [RTLD_LAZY]
    let sym :: String -> IO (FunPtr a)
        sym = dlsym libj
    -- Create the instance first, then resolve JDo, JGetM, JGetR and JSetA
    -- in that order.
    jt <- sym "JInit" >>= mkJInit
    JEnv jt
        <$> (mkJDo   <$> sym "JDo")
        <*> (mkJGetM <$> sym "JGetM")
        <*> (mkJGetR <$> sym "JGetR")
        <*> (mkJSetA <$> sym "JSetA")
#else
jinit :: FilePath -- ^ Path to the J DLL, e.g. from 'libWindows'
      -> IO JEnv
jinit libFp = do
    libj <- loadLibrary libFp
    -- getProcAddress yields a data pointer; reinterpret it as a function
    -- pointer.
    let sym :: String -> IO (FunPtr a)
        sym name = castPtrToFunPtr <$> getProcAddress libj name
    -- Create the instance first, then resolve JDo, JGetM, JGetR and JSetA
    -- in that order.
    jt <- sym "JInit" >>= mkJInit
    JEnv jt
        <$> (mkJDo   <$> sym "JDo")
        <*> (mkJGetM <$> sym "JGetM")
        <*> (mkJGetR <$> sym "JGetR")
        <*> (mkJSetA <$> sym "JSetA")
#endif


-- | Send some J code to the environment.
--
-- The status code returned by @JDo@ is discarded.
bsDispatch :: JEnv -> BS.ByteString -> IO ()
bsDispatch JEnv { context = ctx, evaluator = runSentence } sentence =
    void (BS.useAsCString sentence (runSentence ctx))

-- | Read last output
--
-- For debugging. The C string returned by @JGetR@ is copied into a fresh
-- 'BS.ByteString'.
bsOut :: JEnv -> IO BS.ByteString
bsOut JEnv { context = ctx, out = lastOutput } = lastOutput ctx >>= BS.packCString

-- | Read the J noun bound to a name and convert it to 'JData'.
--
-- \( O(n) \) in the array size
getJData :: R.Shape sh
         => JEnv
         -> BS.ByteString -- ^ Name of the value in question
         -> IO (JData sh)
getJData jenv name = do
    atom <- getAtomInternal jenv name
    pure (jData atom)

-- | Query a noun with @JGetM@ and copy its element data into a
-- Haskell-managed buffer. The copy ensures the result does not alias J's
-- memory.
--
-- Throws an 'IOError' if @JGetM@ returns a non-zero status (for instance,
-- when the name is not defined). In that case the out-parameters are not
-- read.
getAtomInternal :: JEnv
                -> BS.ByteString -- ^ Name of the value in question
                -> IO JAtom
getAtomInternal (JEnv ctx _ jget _ _) bs =
    BS.useAsCString bs $ \name ->
    alloca $ \t ->
    alloca $ \s ->
    alloca $ \r ->
    alloca $ \d -> do
        rc <- jget ctx name t r s d
        if rc /= 0
            then ioError (userError ("getJData: JGetM failed for " ++ show bs
                                     ++ " with error code " ++ show rc))
            else do
                ty' <- intToJType <$> peek t
                rank' <- peek r
                shape' <- peekArray (fromIntegral rank') =<< peek s
                -- Bytes of element data: element width times element count
                -- (product [] == 1, so a scalar gets a single element).
                let arrSz = jTypeWidth ty' * fromIntegral (product shape')
                -- Allocate exactly as many bytes as are copied below.
                res <- mallocForeignPtrBytes arrSz
                withForeignPtr res $ \dst -> do
                    src <- peek d
                    copyArray dst src arrSz
                pure (JAtom ty' shape' res)

-- | A noun copied out of J by 'getAtomInternal'.
data JAtom = JAtom
    !JType              -- element type
    ![CLLong]           -- shape, as reported by JGetM
    !(ForeignPtr CChar) -- Haskell-owned copy of the element data

-- | J data backed by repa array
--
-- Numeric and boolean nouns are storable ('RF.F') repa arrays; character
-- lists are carried as a 'BS.ByteString'.
data JData sh
    = JIntArr     !(R.Array RF.F sh CLLong)           -- ^ J integers
    | JDoubleArr  !(R.Array RF.F sh CDouble)          -- ^ J floating point
    | JComplexArr !(R.Array RF.F sh (Complex CDouble)) -- ^ J complex
    | JBoolArr    !(R.Array RF.F sh CChar)            -- ^ J booleans, one byte each
    | JString     !BS.ByteString                      -- ^ J character list

-- | Bind a name in the J environment to the given value.
--
-- The value is serialized by 'repaArr' (arrays) or 'strArr' (strings) into a
-- temporary buffer and handed to @JSetA@. The buffer is freed afterwards,
-- even if an exception occurs. Returns the status code reported by @JSetA@.
--
-- \( O(n) \) in the array size
setJData :: (R.Shape sh) => JEnv -> BS.ByteString -- ^ Name
                         -> JData sh -> IO CInt
setJData jenv name val = case val of
    JIntArr arr     -> assign (repaArr JInteger arr)
    JDoubleArr arr  -> assign (repaArr JDouble arr)
    JComplexArr arr -> assign (repaArr JComplex arr)
    JBoolArr arr    -> assign (repaArr JBool arr)
    JString bs      -> assign (strArr bs)
  where
    -- Serialize, pass the buffer to JSetA, and always release it afterwards.
    -- The serializers malloc the buffer and previously nothing freed it.
    -- NOTE(review): freeing relies on JSetA building its own J noun from a
    -- copy of the bytes (as J's jlib implementation does); confirm for the
    -- J version in use.
    assign :: IO (CLLong, Ptr ()) -> IO CInt
    assign serialize =
        BS.useAsCStringLen name $ \(n, nameLen) ->
            bracket serialize (free . snd) $ \(bufLen, buf) ->
                setter jenv (context jenv) (fromIntegral nameLen) n bufLen buf

-- | Serialize a repa array into a freshly 'mallocBytes'-allocated buffer
-- suitable to be passed to @JSetA@. Returns the buffer length in bytes and
-- the buffer; the caller must free the buffer.
--
-- Layout, in 'CLLong' words: a header flag, the J type code, the element
-- count and the rank; then one word per dimension; then the raw element
-- data, padded up to a whole word.
--
-- To be used on boolean, integer, double, and complex arrays
repaArr :: (R.Shape sh, Storable e) => JType -> R.Array RF.F sh e -> IO (CLLong, Ptr ())
repaArr jty arr = do
    let (rank', sh) = repaSize arr
        sz = product sh
        word = sizeOf (0 :: CLLong)
        dimOff = 4 * word
        -- Every dimension takes a full word, whatever the element width.
        -- Sizing the shape by element width overflowed the buffer for
        -- 1-byte booleans.
        dataOff = dimOff + fromIntegral rank' * word
        dataBytes = jTypeWidth jty * fromIntegral sz
        -- Round the payload up to a whole number of words; this is exact
        -- for integer, double and complex data.
        paddedBytes = word * ((dataBytes + word - 1) `div` word)
        wid = dataOff + paddedBytes
    ptr <- mallocBytes wid
    -- NOTE(review): 0xE3 (227) appears to be J's binary-representation flag
    -- for a 64-bit, byte-reversed (little-endian) layout; confirm.
    pokeByteOff ptr 0 (227 :: CLLong)
    pokeByteOff ptr word (jTypeToInt jty)
    pokeByteOff ptr (2 * word) sz
    pokeByteOff ptr (3 * word) rank'
    pokeArray (ptr `plusPtr` dimOff) sh
    withForeignPtr (RF.toForeignPtr arr) $ \src ->
        copyArray (ptr `plusPtr` dataOff) src (fromIntegral sz)
    pure (fromIntegral wid, ptr)

-- | Serialize a 'BS.ByteString' as a rank-1 J character list into a freshly
-- 'mallocBytes'-allocated buffer suitable for @JSetA@. Returns the buffer
-- length in bytes and the buffer; the caller must free the buffer.
--
-- Layout, in 'CLLong' words: flag, type ('JChar'), element count, rank (1),
-- the single dimension; then the characters.
strArr :: BS.ByteString -> IO (CLLong, Ptr ())
strArr bs = do
    let n = BS.length bs
        count = fromIntegral n :: CLLong
        word = sizeOf (0 :: CLLong)
        header = [227, jTypeToInt JChar, count, 1, count] :: [CLLong]
        -- Five header words, then room for the characters rounded up past
        -- the next word boundary.
        bufSize = 40 + 8 * (1 + n `div` 8)
    buf <- mallocBytes bufSize
    mapM_ (\(i, w) -> pokeByteOff buf (i * word) w) (zip [0 ..] header)
    BS.useAsCStringLen bs $ \(src, srcLen) ->
        copyArray (buf `plusPtr` (length header * word)) src srcLen
    pure (fromIntegral bufSize, buf)

-- | Rank and shape of a repa array, with dimensions listed outermost first
-- as J expects.
--
-- 'R.listOfShape' lists dimensions innermost first (an extent of
-- @Z :. 2 :. 3@ yields @[3, 2]@). J describes the same row-major data with
-- shape @2 3@, so the list is reversed.
repaSize :: (R.Source r e, R.Shape sh) => R.Array r sh e -> (CLLong, [CLLong])
repaSize arr =
    let ext = R.extent arr
    in (fromIntegral (R.rank ext), map fromIntegral (reverse (R.listOfShape ext)))

-- | J types handled by this module. Numeric codes are given by
-- 'jTypeToInt'.
data JType
    = JBool     -- ^ code 1, one byte per element
    | JChar     -- ^ code 2, one byte per element
    | JInteger  -- ^ code 4, stored as 'CLLong'
    | JDouble   -- ^ code 8, stored as 'CDouble'
    | JComplex  -- ^ code 16, stored as @'Complex' 'CDouble'@

-- | Decode a J type code (as reported by @JGetM@); the inverse of
-- 'jTypeToInt'. Calls 'error' on codes this module does not handle.
intToJType :: CLLong -> JType
intToJType code = case code of
    1  -> JBool
    2  -> JChar
    4  -> JInteger
    8  -> JDouble
    16 -> JComplex
    _  -> error "Unsupported type!"

-- | Encode a 'JType' as the numeric type code written into buffers for
-- @JSetA@.
jTypeToInt :: JType -> CLLong
jTypeToInt ty = case ty of
    JBool    -> 1
    JChar    -> 2
    JInteger -> 4
    JDouble  -> 8
    JComplex -> 16

-- | Size in bytes of a single element of the given J type.
jTypeWidth :: JType -> Int
jTypeWidth ty = case ty of
    JBool    -> byteWidth
    JChar    -> byteWidth
    JInteger -> sizeOf (0 :: CLLong)
    JDouble  -> sizeOf (0 :: CDouble)
    JComplex -> sizeOf (0 :: Complex CDouble)
  where
    -- Booleans and characters are both stored one CChar per element.
    byteWidth = sizeOf (0 :: CChar)

-- | Convert a 'JAtom' copied out of J into 'JData'.
--
-- The element buffer is shared with the 'JAtom', not copied. J lists
-- dimensions outermost first, while 'R.shapeOfList' expects them innermost
-- first, so the shape is reversed before building the repa extent.
-- Character data is only supported at rank 1; anything else calls 'error'.
jData :: R.Shape sh => JAtom -> JData sh
jData (JAtom ty dims fp) = case ty of
    JInteger -> JIntArr (RF.fromForeignPtr shape (castForeignPtr fp))
    JDouble  -> JDoubleArr (RF.fromForeignPtr shape (castForeignPtr fp))
    JComplex -> JComplexArr (RF.fromForeignPtr shape (castForeignPtr fp))
    JBool    -> JBoolArr (RF.fromForeignPtr shape (castForeignPtr fp))
    JChar    -> case dims of
        [l] -> JString (BS.fromForeignPtr (castForeignPtr fp) 0 (fromIntegral l))
        _   -> error "Not supported."
  where
    shape = R.shapeOfList (reverse (fromIntegral <$> dims))