module RL_Glue.Agent (
  Agent(Agent), loadAgentDebug, loadAgent
  ) where

import Control.Monad (unless)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.State.Lazy
import Data.Binary.Get
import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Version (showVersion)
import Data.Word
import Network.Simple.TCP
import System.Exit

import Paths_rlglue (version)
import RL_Glue.Network

-- | Callbacks implementing an RL-Glue agent.
--
-- Every callback runs in @'StateT' a 'IO'@, so the agent's private state of
-- type @a@ is threaded from one callback to the next, starting from the
-- initial state given to 'loadAgent' / 'loadAgentDebug'.
data Agent a = Agent
  { onAgentInit :: BS.ByteString -> StateT a IO ()
    -- ^ Called on @kAgentInit@ with the task specification string read
    -- from RL-Glue.
  , onAgentStart :: Observation -> StateT a IO Action
    -- ^ Called on @kAgentStart@ with the received observation; the returned
    -- action is serialised and sent back to RL-Glue.
  , onAgentStep :: (Reward, Observation) -> StateT a IO Action
    -- ^ Called on @kAgentStep@ with the received (reward, observation) pair;
    -- the returned action is serialised and sent back to RL-Glue.
  , onAgentEnd :: Reward -> StateT a IO ()
    -- ^ Called on @kAgentEnd@ with the received reward.
  , onAgentCleanup :: StateT a IO ()
    -- ^ Called on @kAgentCleanup@.
  , onAgentMessage :: BS.ByteString -> StateT a IO BS.ByteString
    -- ^ Called on @kAgentMessage@ with the received message; the returned
    -- bytes are sent back as the response (an empty response is allowed).
  }

-- | Connect to RL-Glue and serve the given 'Agent', starting from the given
-- initial state, with debug tracing disabled (same as @'loadAgentDebug' 0@).
loadAgent :: Agent a -> a -> IO ()
loadAgent agent initState = loadAgentDebug 0 agent initState

-- | Connect to RL-Glue and serve the given 'Agent', starting from
-- @initState@.
--
-- A debug level of 1 or higher prints a trace line to stdout for received
-- state codes; 0 disables tracing.
loadAgentDebug :: Int -> Agent a -> a -> IO ()
loadAgentDebug debugLvl agent initState = glueConnect session
  where
    -- Announce the codec, register with RL-Glue as an agent, then run the
    -- message loop with the agent state seeded from 'initState'.
    session (sock, _addr) = do
      putStrLn banner
      sendLazy sock handshake
      evalStateT (eventLoop agent sock debugLvl) initState

    banner =
      "RL-Glue Haskell Agent Codec (Version " ++ showVersion version ++ ")"

    -- Connection header: the agent-connection code followed by a zero
    -- payload size.
    handshake = runPut $ do
      putWord32be kAgentConnection
      putWord32be (0 :: Word32)

-- | Read message headers from RL-Glue and dispatch each one to
-- 'handleState' until @kRLTerm@ has been handled.
--
-- @kRLTerm@ is passed to 'handleState' like any other state, so its debug
-- trace is emitted, and then the loop stops. If a header cannot be read,
-- an error is printed and the process exits with code 1.
eventLoop :: Agent a -> Socket -> Int -> StateT a IO ()
eventLoop agent sock debugLvl = do
  header <- lift $ runMaybeT (getAgentState sock)
  case header of
    Nothing -> do
      lift $ putStrLn "Error: Failed to receive state."
      lift $ exitWith (ExitFailure 1)
    -- The payload size is not needed here: each handler reads its own
    -- payload from the socket.
    Just (state, _dataSize) -> do
      handleState sock agent state debugLvl
      unless (state == kRLTerm) $
        eventLoop agent sock debugLvl

-- | Dispatch one RL-Glue state code to the matching 'Agent' callback and
-- send the reply over the socket.
--
-- Every reply starts with the echoed state code followed by a 'Word32'
-- payload size. For an unknown state code, an error is printed and the
-- process exits with code 1.
handleState :: Socket -> Agent a -> Word32 -> Int -> StateT a IO ()
handleState sock agent state debugLvl
  | state == kAgentInit = do
      debugTrace "kAgentInit received"
      spec <- lift $ getStringOrDie "Error: Could not get task spec" sock
      onAgentInit agent spec
      replyEmpty kAgentInit
  | state == kAgentStart = do
      debugTrace "kAgentStart received"
      observation <- lift $ getObservationOrDie sock
      onAgentStart agent observation >>= replyAction kAgentStart
  | state == kAgentStep = do
      debugTrace "kAgentStep received"
      rewardAndObs <- lift $ getRewardObservationOrDie sock
      onAgentStep agent rewardAndObs >>= replyAction kAgentStep
  | state == kAgentEnd = do
      debugTrace "kAgentEnd received"
      r <- lift $ getRewardOrDie sock
      onAgentEnd agent r
      replyEmpty kAgentEnd
  | state == kAgentCleanup = do
      debugTrace "kAgentCleanup received"
      onAgentCleanup agent
      replyEmpty kAgentCleanup
  | state == kAgentMessage = do
      debugTrace "kAgentMessage received"
      request <- lift $ getStringOrDie "Error: Could not read message" sock
      response <- onAgentMessage agent request
      reply kAgentMessage $
        if BS.null response
          -- Empty response: payload size 4, holding a zero length field.
          then putWord32be 4 >> putWord32be 0
          else
            putWord32be (fromIntegral $ 4 + BS.length response) >>
            putString response
  | state == kRLTerm =
      debugTrace "kRLTerm received"
  | otherwise = lift $ do
      putStrLn $ "Error: Unknown state: " ++ show state
      exitWith (ExitFailure 1)
  where
    -- Print a trace line only when debugging is enabled (level >= 1).
    debugTrace msg = lift $ unless (debugLvl < 1) $ putStrLn msg

    -- Serialise the echoed state code followed by 'body', then send it.
    reply code body = sendLazy sock $ runPut $ putWord32be code >> body

    -- Reply with a zero payload size and no payload.
    replyEmpty code = reply code (putWord32be 0)

    -- Reply with the action's encoded size followed by the action itself.
    replyAction code action =
      reply code $ do
        putWord32be (fromIntegral (sizeOfAction action))
        putAction action

-- | Read the 8-byte message header sent by RL-Glue.
--
-- The header is a big-endian 'Word32' state code followed by a big-endian
-- 'Word32' payload size.
--
-- A TCP 'recv' may return fewer bytes than requested, so the header is read
-- in a loop until all 8 bytes have arrived. Parsing therefore never sees a
-- truncated buffer. The result is 'Nothing' if 'recv' returns 'Nothing'
-- (e.g. the peer closed the connection) or returns an empty chunk before
-- the header is complete.
getAgentState :: Socket -> MaybeT IO (Word32, Word32)
getAgentState sock = do
  bs <- recvExactly headerSize []
  return $ runGet parseBytes (LBS.fromStrict bs)
  where
    -- Two big-endian Word32 fields.
    headerSize :: Int
    headerSize = 4 * 2

    -- Keep receiving until 'remaining' bytes have been collected. Chunks are
    -- accumulated in reverse and concatenated once at the end.
    recvExactly :: Int -> [BS.ByteString] -> MaybeT IO BS.ByteString
    recvExactly remaining chunks
      | remaining <= 0 = return (BS.concat (reverse chunks))
      | otherwise = do
          chunk <- MaybeT $ recv sock remaining
          -- A zero-length read would never make progress; treat it as failure.
          if BS.null chunk
            then MaybeT (return Nothing)
            else recvExactly (remaining - BS.length chunk) (chunk : chunks)

    parseBytes = do
      envState <- getWord32be
      dataSize <- getWord32be
      return (envState, dataSize)