module RL_Glue.Agent (
Agent(Agent), loadAgentDebug, loadAgent
) where
import Control.Monad (unless)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.State.Lazy
import Data.Binary.Get
import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Version (showVersion)
import Data.Word
import Network.Simple.TCP
import System.Exit
import Paths_rlglue (version)
import RL_Glue.Network
-- | The set of callbacks that implement an RL-Glue agent.
--
-- Every callback runs in @'StateT' a 'IO'@, so it can read and update the
-- agent's private state of type @a@. That state is seeded from the initial
-- value passed to 'loadAgent' / 'loadAgentDebug'.
data Agent a = Agent
  { onAgentInit :: BS.ByteString -> StateT a IO ()
    -- ^ Called when RL-Glue initialises the agent; receives the task
    -- specification string sent by RL-Glue.
  , onAgentStart :: Observation -> StateT a IO Action
    -- ^ Called at the start of an episode with the first observation; the
    -- returned action is sent back to RL-Glue.
  , onAgentStep :: (Reward, Observation) -> StateT a IO Action
    -- ^ Called on every step with the received reward and observation; the
    -- returned action is sent back to RL-Glue.
  , onAgentEnd :: Reward -> StateT a IO ()
    -- ^ Called when RL-Glue ends an episode, with the reward it sent.
  , onAgentCleanup :: StateT a IO ()
    -- ^ Called when RL-Glue asks the agent to clean up.
  , onAgentMessage :: BS.ByteString -> StateT a IO BS.ByteString
    -- ^ Called with a free-form message from RL-Glue; the returned bytes
    -- are sent back as the reply (an empty result is sent as an empty
    -- string).
  }
-- | Connect to RL-Glue and run @agent@ starting from @initialState@, with
-- debug logging turned off.
--
-- This is 'loadAgentDebug' with a debug level of @0@.
loadAgent :: Agent a -> a -> IO ()
loadAgent agent initialState = loadAgentDebug 0 agent initialState
-- | Connect to RL-Glue and run @agent@ starting from @initState@.
--
-- Once connected, the function prints a version banner and sends the agent
-- handshake (the connection code followed by a zero payload size). It then
-- serves RL-Glue requests until the event loop stops.
--
-- A @debugLvl@ of 1 or more logs each request type as it arrives.
loadAgentDebug :: Int -> Agent a -> a -> IO ()
loadAgentDebug debugLvl agent initState = glueConnect onConnected
  where
    onConnected (sock, _addr) = do
      putStrLn $ "RL-Glue Haskell Agent Codec (Version " ++ showVersion version ++ ")"
      sendLazy sock handshake
      evalStateT (eventLoop agent sock debugLvl) initState

    -- Identifies this endpoint as an agent; it carries no payload.
    handshake = runPut $ do
      putWord32be kAgentConnection
      putWord32be (0 :: Word32)
-- | Serve RL-Glue requests until a termination code arrives.
--
-- Each iteration reads one request header and dispatches it through
-- 'handleState'. The loop returns when the state code is 'kRLTerm'.
-- If no header can be read, an error is printed and the process exits
-- with status 1.
eventLoop :: Agent a -> Socket -> Int -> StateT a IO ()
eventLoop agent sock debugLvl = loop
  where
    loop = lift (runMaybeT (getAgentState sock)) >>= maybe receiveFailed dispatch

    receiveFailed = lift $ do
      putStrLn "Error: Failed to receive state."
      exitWith (ExitFailure 1)

    -- The payload size in the header is not needed here; the handlers
    -- read their own payloads.
    dispatch (msgState, _payloadSize)
      | msgState == kRLTerm = return ()
      | otherwise = handleState sock agent msgState debugLvl >> loop
-- | Dispatch one RL-Glue request to the matching 'Agent' callback and send
-- the reply.
--
-- Every reply uses the same framing: the request's state code, a 32-bit
-- big-endian payload size, then the payload. The local @reply@ helper
-- builds that frame, so each branch only states what is specific to it.
--
-- A @debugLvl@ of 1 or more prints the name of each request received.
-- An unknown state code is fatal: an error is printed and the process
-- exits with status 1.
handleState :: Socket -> Agent a -> Word32 -> Int -> StateT a IO ()
handleState sock agent state debugLvl
  | state == kAgentInit = do
      debug "kAgentInit received"
      taskSpec <- lift $ getStringOrDie "Error: Could not get task spec" sock
      onAgentInit agent taskSpec
      replyEmpty kAgentInit
  | state == kAgentStart = do
      debug "kAgentStart received"
      obs <- lift $ getObservationOrDie sock
      action <- onAgentStart agent obs
      reply kAgentStart (fromIntegral (sizeOfAction action)) (putAction action)
  | state == kAgentStep = do
      debug "kAgentStep received"
      rewardObs <- lift $ getRewardObservationOrDie sock
      action <- onAgentStep agent rewardObs
      reply kAgentStep (fromIntegral (sizeOfAction action)) (putAction action)
  | state == kAgentEnd = do
      debug "kAgentEnd received"
      reward <- lift $ getRewardOrDie sock
      onAgentEnd agent reward
      replyEmpty kAgentEnd
  | state == kAgentCleanup = do
      debug "kAgentCleanup received"
      onAgentCleanup agent
      replyEmpty kAgentCleanup
  | state == kAgentMessage = do
      debug "kAgentMessage received"
      msg <- lift $ getStringOrDie "Error: Could not read message" sock
      resp <- onAgentMessage agent msg
      -- An empty response is still framed as a string: a 4-byte zero
      -- length with no body.
      if BS.null resp
        then reply kAgentMessage 4 (putWord32be 0)
        else reply kAgentMessage (fromIntegral $ 4 + BS.length resp) (putString resp)
  | state == kRLTerm =
      debug "kRLTerm received"
  | otherwise = lift $ do
      putStrLn $ "Error: Unknown state: " ++ show state
      exitWith (ExitFailure 1)
  where
    -- Print a trace line only when debugging is enabled.
    debug :: String -> StateT s IO ()
    debug msg = lift $ unless (debugLvl < 1) $ putStrLn msg

    -- Send one framed reply: state code, payload size, payload.
    reply :: Word32 -> Word32 -> Put -> StateT s IO ()
    reply code size payload = sendLazy sock $ runPut $ do
      putWord32be code
      putWord32be size
      payload

    -- Acknowledge a request that carries no reply payload.
    replyEmpty :: Word32 -> StateT s IO ()
    replyEmpty code = reply code 0 (return ())
-- | Read the 8-byte request header sent by RL-Glue: a 32-bit big-endian
-- state code followed by a 32-bit big-endian payload size.
--
-- 'recv' may return fewer bytes than requested, because TCP can split the
-- header across segments. This function therefore keeps reading until the
-- whole header has arrived. It fails (yields 'Nothing' through 'MaybeT')
-- if the connection closes first, instead of letting 'runGet' throw on a
-- truncated buffer.
getAgentState :: Socket -> MaybeT IO (Word32, Word32)
getAgentState sock = do
  header <- recvExactly headerSize
  -- Safe: 'header' is exactly 8 bytes, which is what 'parseBytes' consumes.
  return $ runGet parseBytes (LBS.fromStrict header)
  where
    headerSize = 4 * 2

    -- Accumulate chunks until exactly @n@ bytes have been read.
    recvExactly :: Int -> MaybeT IO BS.ByteString
    recvExactly n
      | n <= 0 = return BS.empty
      | otherwise = do
          chunk <- MaybeT $ recv sock n
          -- Guard against a zero-length read so the loop cannot spin forever.
          if BS.null chunk
            then MaybeT (return Nothing)
            else BS.append chunk <$> recvExactly (n - BS.length chunk)

    parseBytes = do
      envState <- getWord32be
      dataSize <- getWord32be
      return (envState, dataSize)