#if !MIN_VERSION_base(4,8,0)
#endif
module Foreign.Lua.FunctionCalling
( FromLuaStack (..)
, LuaCallFunc (..)
, ToHaskellFunction (..)
, HaskellFunction
, ToLuaStack (..)
, PreCFunction
, toHaskellFunction
, callFunc
, freeCFunction
, newCFunction
, pushHaskellFunction
, registerHaskellFunction
) where
import Control.Monad (when)
import Foreign.C (CInt (..))
import Foreign.Lua.Api
import Foreign.Lua.Types
import Foreign.Lua.Util (getglobal')
import Foreign.Ptr (castPtr, freeHaskellFunPtr)
import Foreign.StablePtr (deRefStablePtr, freeStablePtr, newStablePtr)
import qualified Foreign.Storable as F
-- | Haskell-side type of a Lua C function: it receives the interpreter state
-- and returns the number of results it left on the stack. This is the type
-- that 'mkWrapper' turns into a 'CFunction'.
type PreCFunction = LuaState -> IO NumResults
-- | A Haskell action usable as a Lua function; the returned 'NumResults' is
-- the number of values the action pushed as its results.
type HaskellFunction = Lua NumResults
-- | Values that can be run as a Lua-callable function.
--
-- Function arguments are read from the Lua stack one by one; the conversion
-- is told where the next argument lives.
class ToHaskellFunction a where
  toHsFun
    :: StackIndex      -- ^ stack index of the next argument to read
    -> a               -- ^ value to convert
    -> Lua NumResults  -- ^ action returning the number of pushed results
-- | Base case: an action that already yields its number of results is used
-- unchanged; no further arguments are read, so the index is ignored.
--
-- 'HaskellFunction' is @Lua NumResults@, whose head overlaps with the
-- @ToLuaStack a => ToHaskellFunction (Lua a)@ instance. GHC matches instance
-- heads without looking at contexts, so this instance must be marked as the
-- more specific one; otherwise every use at 'HaskellFunction' is rejected
-- as an overlapping-instances error.
-- NOTE(review): on base < 4.8 the per-instance pragma is unavailable and the
-- module-wide OverlappingInstances extension must be enabled instead.
#if MIN_VERSION_base(4,8,0)
instance {-# OVERLAPPING #-} ToHaskellFunction HaskellFunction where
#else
instance ToHaskellFunction HaskellFunction where
#endif
  toHsFun _ = id
-- | An action yielding a pushable value: run it, push the value and report
-- exactly one result.
instance ToLuaStack a => ToHaskellFunction (Lua a) where
  toHsFun _ action = do
    result <- action
    push result
    return 1
-- | A function of one more argument: read the argument at the current stack
-- index, apply the function, and continue with the next index.
--
-- A failure to read the argument is rethrown with the argument's position
-- added to the message.
instance (FromLuaStack a, ToHaskellFunction b) =>
         ToHaskellFunction (a -> b) where
  toHsFun narg f = do
      arg <- peek narg `catchLuaError` annotate
      toHsFun (narg + 1) (f arg)
    where
      annotate err = throwLuaError $
        "could not read argument " ++ show (fromStackIndex narg)
          ++ ": " ++ show err
-- | Convert a Haskell value into a 'HaskellFunction' whose arguments are read
-- from the stack starting at index 1.
--
-- Any Lua error raised while reading arguments or running the function is
-- caught; its message, prefixed with "Error during function call: ", is
-- pushed and control is handed to 'lerror'.
toHaskellFunction :: ToHaskellFunction a => a -> HaskellFunction
toHaskellFunction fn = catchLuaError (toHsFun 1 fn) reportError
  where
    reportError err = do
      push ("Error during function call: " ++ show err)
      fromIntegral <$> lerror
-- | Wrap a Haskell function into a freshly allocated C function pointer.
--
-- Pointers created by a "wrapper" import are not reclaimed automatically;
-- release them with 'freeCFunction' once Lua no longer needs them.
newCFunction :: ToHaskellFunction a => a -> Lua CFunction
newCFunction fn = liftIO (mkWrapper (\l -> runLuaWith l action))
  where
    -- converted once and shared by every invocation of the C function
    action = toHaskellFunction fn
-- | Create a C function pointer that calls the given 'PreCFunction'.
-- The pointer must be freed explicitly (see 'freeCFunction').
foreign import ccall "wrapper"
  mkWrapper :: PreCFunction -> IO CFunction
-- | Release a C function pointer created by 'newCFunction'. The pointer must
-- not be called afterwards.
freeCFunction :: CFunction -> Lua ()
freeCFunction fn = liftIO (freeHaskellFunPtr fn)
-- | Types of calls to a global Lua function, assembled argument by argument:
-- each function argument is pushed in turn, and the final @Lua a@ result
-- performs the call.
class LuaCallFunc a where
  callFunc'
    :: String   -- ^ name of the global Lua function
    -> Lua ()   -- ^ action pushing the arguments collected so far
    -> NumArgs  -- ^ number of arguments that action pushes
    -> a
-- | Base case: all arguments are collected, so perform the call.
--
-- The global function is pushed, followed by its arguments, and invoked in
-- protected mode requesting exactly one result. On success that result is
-- read from the top of the stack and popped; otherwise the error left on the
-- stack is rethrown via 'throwTopMessageAsError'.
instance (FromLuaStack a) => LuaCallFunc (Lua a) where
  callFunc' fnName pushArgs nargs = do
    getglobal' fnName
    pushArgs
    status <- pcall nargs 1 Nothing
    -- The single result is on top of the stack. Read it with a relative
    -- index: absolute index 1 is the bottom of the stack and is only the
    -- result when the stack happened to be empty before the call.
    if status == OK
      then peek (-1) <* pop 1
      else throwTopMessageAsError
-- | One more argument: extend the pushing action with this argument and
-- bump the argument count.
instance (ToLuaStack a, LuaCallFunc b) => LuaCallFunc (a -> b) where
  callFunc' fnName pushPrevious nargs arg =
    let pushAll = do
          pushPrevious
          push arg
    in callFunc' fnName pushAll (nargs + 1)
-- | Call the global Lua function with the given name. Apply the result to the
-- desired arguments; the final @Lua a@ value performs the call and converts
-- its single result.
callFunc :: (LuaCallFunc a) => String -> a
callFunc fnName = callFunc' fnName pushNothing 0
  where
    pushNothing :: Lua ()
    pushNothing = return ()
-- | Push a Haskell function onto the Lua stack as a callable value.
pushHaskellFunction :: ToHaskellFunction a => a -> Lua ()
pushHaskellFunction fn = pushPreCFunction (\l -> runLuaWith l action)
  where
    -- converted once and shared by every call from Lua
    action = toHaskellFunction fn
-- | Push a 'PreCFunction' onto the Lua stack as a callable userdata.
--
-- A stable pointer to the function is stored in a new userdata block.
-- The userdata's metatable ("HaskellImportedFunction") maps:
--
-- * @__call@ to 'hsMethodCall', which runs the stored function;
-- * @__gc@ to 'hsMethodGc', which frees the stable pointer.
--
-- All stack indices are relative to the top, so this works no matter how many
-- values are already on the stack.
pushPreCFunction :: PreCFunction -> Lua ()
pushPreCFunction f = do
  stableptr <- liftIO $ newStablePtr f
  p <- newuserdata (F.sizeOf stableptr)
  liftIO $ F.poke (castPtr p) stableptr
  -- stack now: ... userdata metatable
  v <- newmetatable "HaskellImportedFunction"
  -- the metatable is looked up by name; its fields only need to be set when
  -- it was just created
  when v $ do
    pushcfunction hsmethod__gc_addr
    setfield (-2) "__gc"
    pushcfunction hsmethod__call_addr
    setfield (-2) "__call"
  -- attach the metatable (top) to the userdata directly below it
  setmetatable (-2)
  return ()
-- | Make a Haskell function available to Lua as the global with the given
-- name.
registerHaskellFunction :: ToHaskellFunction a => String -> a -> Lua ()
registerHaskellFunction n f = pushHaskellFunction f *> setglobal n
-- Export the userdata metamethods to C and import their addresses, so that
-- 'pushPreCFunction' can install them as @__gc@ and @__call@.
foreign export ccall hsMethodGc :: PreCFunction
foreign import ccall "&hsMethodGc" hsmethod__gc_addr :: CFunction
foreign export ccall hsMethodCall :: PreCFunction
foreign import ccall "&hsMethodCall" hsmethod__call_addr :: CFunction
-- | @__gc@ metamethod for function userdata: read the stable pointer stored
-- in the userdata (the only argument) and free it. Returns no results.
hsMethodGc :: LuaState -> IO NumResults
hsMethodGc l = do
  udataPtr <- runLuaWith l (peek 1)
  F.peek (castPtr udataPtr) >>= freeStablePtr
  return 0
-- | @__call@ metamethod for function userdata.
--
-- Lua passes the called object (the userdata) as the first argument. It is
-- read and then removed, so the wrapped function sees only its own
-- arguments, starting at index 1. The stored function is then run on the
-- same state, and its result count is returned.
hsMethodCall :: LuaState -> IO NumResults
hsMethodCall l = do
  udataPtr <- runLuaWith l $ do
    ptr <- peek 1
    remove 1
    return ptr
  fn <- deRefStablePtr =<< F.peek (castPtr udataPtr)
  fn l