module Language.Haskell.LSP.Test.Replay
( replaySession
)
where
import Prelude hiding (id)
import Control.Concurrent
import Control.Monad.IO.Class
import qualified Data.ByteString.Lazy.Char8 as B
import qualified Data.Text as T
import Language.Haskell.LSP.Capture
import Language.Haskell.LSP.Messages
import Language.Haskell.LSP.Types as LSP hiding (error)
import Data.Aeson
import Data.Default
import Data.List
import Data.Maybe
import Control.Lens hiding (List)
import Control.Monad
import System.FilePath
import Language.Haskell.LSP.Test
import Language.Haskell.LSP.Test.Files
import Language.Haskell.LSP.Test.Decoding
import Language.Haskell.LSP.Test.Messages
import Language.Haskell.LSP.Test.Server
import Language.Haskell.LSP.Test.Session
-- | Replay a recorded client session against a freshly started server and
-- check that the server's messages match the recorded ones.
--
-- Reads @session.log@ from the session directory (one JSON-encoded event per
-- line), starts the server with 'withServer', and rewrites the recorded
-- events with 'swapFiles' and 'swapCommands' (the latter using the new
-- server's pid). It then forks a session that runs 'sendMessages' (replaying
-- client messages) alongside 'listenServer' (checking server messages).
-- Recorded server messages for which 'shouldSkip' holds are not expected.
-- This returns after 'listenServer' signals, through @passSema@, that every
-- expected message arrived, and then kills the session thread.
replaySession :: String   -- ^ Command used to launch the server.
              -> FilePath -- ^ Session directory containing @session.log@.
              -> IO ()
replaySession serverExe sessionDir = do
  entries <- B.lines <$> B.readFile (sessionDir </> "session.log")
  -- Decoding is lazy: a malformed entry only fails when it is first used,
  -- and the error then names the offending line instead of a bare fromJust.
  let unswappedEvents = zipWith decodeEntry [1 :: Int ..] entries
  withServer serverExe False $ \serverIn serverOut pid -> do
    events <- swapCommands pid <$> swapFiles sessionDir unswappedEvents
    -- Pattern-matching generators drop non-matching events, so these
    -- projections are total (no partial lambdas needed).
    let clientMsgs = [msg | FromClient _ msg <- events]
        serverMsgs = filter (not . shouldSkip) [msg | FromServer _ msg <- events]
        clientRequestMap = getRequestMap clientMsgs
    reqSema <- newEmptyMVar
    rspSema <- newEmptyMVar
    passSema <- newEmptyMVar
    mainThread <- myThreadId
    sessionThread <- forkIO $
      runSessionWithHandles serverIn
                            serverOut
                            (listenServer serverMsgs clientRequestMap reqSema rspSema passSema mainThread)
                            def
                            fullCaps
                            sessionDir
                            (sendMessages clientMsgs reqSema rspSema)
    takeMVar passSema
    killThread sessionThread
  where
    -- Decode one line of session.log, reporting its 1-based line number
    -- and the decoder's message on failure.
    decodeEntry lineNo entry = case eitherDecode entry of
      Right event -> event
      Left err ->
        error $ "replaySession: could not decode line " ++ show lineNo
             ++ " of session.log: " ++ err
-- | Replay the recorded client messages in order, one at a time.
--
-- * Notifications are sent right away (the exit notification is special,
--   see below).
-- * Requests go through 'sendRequestMessage'. Replay then blocks on
--   @rspSema@ until a response id arrives, and that id must match the
--   request's id.
-- * Responses are sent only after blocking on @reqSema@ for the id of the
--   server request they answer, and that id must match too.
--
-- An id mismatch aborts replay with 'error'.
sendMessages :: [FromClientMessage] -> MVar LspId -> MVar LspIdRsp -> Session ()
sendMessages [] _ _ = return ()
sendMessages (nextMsg:remainingMsgs) reqSema rspSema =
  handleClientMessage request response notification nextMsg
  where
    -- The exit notification is sent after a 10-second pause (threadDelay
    -- takes microseconds). Replay then stops by raising "Done", so nothing
    -- after it is sent.
    -- NOTE(review): presumably the pause lets the server finish outstanding
    -- work before exiting — confirm.
    notification msg@(NotificationMessage _ Exit _) = do
      liftIO $ putStrLn "Will send exit notification soon"
      liftIO $ threadDelay 10000000
      sendMessage msg
      liftIO $ error "Done"
    notification msg@(NotificationMessage _ m _) = do
      sendMessage msg
      liftIO $ putStrLn $ "Sent a notification " ++ show m
      sendMessages remainingMsgs reqSema rspSema

    request msg@(RequestMessage _ id m _) = do
      sendRequestMessage msg
      liftIO $ putStrLn $ "Sent a request id " ++ show id ++ ": " ++ show m ++ "\nWaiting for a response"
      rsp <- liftIO $ takeMVar rspSema
      when (responseId id /= rsp) $
        error $ "Expected id " ++ show id ++ ", got " ++ show rsp
      sendMessages remainingMsgs reqSema rspSema

    response msg@(ResponseMessage _ id _ _) = do
      liftIO $ putStrLn $ "Waiting for request id " ++ show id ++ " from the server"
      reqId <- liftIO $ takeMVar reqSema
      -- Report the id we were waiting for (the recorded response's id)
      -- next to the id the server actually sent. The old message printed
      -- reqId twice and never showed the expected id.
      when (responseId reqId /= id) $
        error $ "Expected id " ++ show id ++ ", got " ++ show reqId
      sendResponse msg
      liftIO $ putStrLn $ "Sent response to request id " ++ show id
      sendMessages remainingMsgs reqSema rspSema
-- | Send a client request. Before sending, its id and method are recorded in
-- the request map held (in an MVar) by the session context.
sendRequestMessage :: (ToJSON a, ToJSON b) => RequestMessage ClientMethod a b -> Session ()
sendRequestMessage req = do
  ctx <- ask
  let reqId = req ^. LSP.id
      reqMethod = req ^. method
      record m = return (updateRequestMap m reqId reqMethod)
  liftIO (modifyMVar_ (requestMap ctx) record)
  sendMessage req
-- | Server messages treated as notifications when checking replay order:
-- published diagnostics, log messages, show messages, and cancel requests
-- sent from the server. Every other message returns 'False'.
isNotification :: FromServerMessage -> Bool
isNotification msg =
  case msg of
    NotPublishDiagnostics _      -> True
    NotLogMessage _              -> True
    NotShowMessage _             -> True
    NotCancelRequestFromServer _ -> True
    _                            -> False
-- | Read messages from the server and check them against the expected
-- (recorded) server messages.
--
-- Each decoded message is first passed to 'handleServerMessage':
--
-- * A response puts its id into @rspSema@.
-- * A server request puts its id into @reqSema@.
--
-- These are the MVars 'sendMessages' blocks on. Messages for which
-- 'shouldSkip' holds are otherwise ignored. Any other message must be one
-- 'inRightOrder' accepts, and is then removed from the expected list.
--
-- When the expected list is empty, @passSema@ is filled to signal success.
-- An unacceptable message is reported to @mainThreadId@ as a
-- 'ReplayOutOfOrder' exception carrying the messages that would have been
-- accepted, and listening stops. @ctx@ is only passed along the recursion.
listenServer [] _ _ _ passSema _ _ _ = putMVar passSema ()
listenServer expectedMsgs reqMap reqSema rspSema passSema mainThreadId serverOut ctx = do
  msgBytes <- getNextMessage serverOut
  let msg = decodeFromServerMsg reqMap msgBytes
  handleServerMessage request response notification msg
  if shouldSkip msg
    then continueWith expectedMsgs
    else if inRightOrder msg expectedMsgs
      then continueWith (delete msg expectedMsgs)
      else throwTo mainThreadId (ReplayOutOfOrder msg acceptableNext)
  where
    -- Recurse with every argument except the expected list unchanged.
    continueWith remaining =
      listenServer remaining reqMap reqSema rspSema passSema mainThreadId serverOut ctx

    -- The messages 'inRightOrder' would have accepted next: any of the
    -- leading notifications, or the first non-notification (if any).
    -- 'span' plus 'take 1' keeps this total, where the old version used
    -- 'head' and also mis-selected the leading non-notifications.
    acceptableNext =
      let (leadingNotifications, rest) = span isNotification expectedMsgs
      in leadingNotifications ++ take 1 rest

    response :: ResponseMessage a -> IO ()
    response res = do
      putStrLn $ "Got response for id " ++ show (res ^. id)
      -- Unblocks 'sendMessages', which waits on rspSema after a request.
      putMVar rspSema (res ^. id)

    request :: RequestMessage ServerMethod a b -> IO ()
    request req = do
      putStrLn
        $ "Got request for id "
        ++ show (req ^. id)
        ++ " "
        ++ show (req ^. method)
      -- Unblocks 'sendMessages', which waits on reqSema before responding.
      putMVar reqSema (req ^. id)

    notification :: NotificationMessage ServerMethod a -> IO ()
    notification n = putStrLn $ "Got notification " ++ show (n ^. method)
-- | Whether @received@ may come next, given the remaining expected server
-- messages.
--
-- Leading notifications (per 'isNotification') are skipped over when looking
-- for a match. So @received@ is in order if it equals one of those leading
-- notifications, or the first non-notification message.
--
-- Running off the end of the list happens when every remaining expected
-- message is a notification and none matched. That case yields 'False'
-- instead of crashing with an unexplained 'error'.
inRightOrder :: FromServerMessage -> [FromServerMessage] -> Bool
inRightOrder _ [] = False
inRightOrder received (expected : msgs)
  | received == expected    = True
  | isNotification expected = inRightOrder received msgs
  | otherwise               = False
-- | Server messages ignored during replay: log message and show message
-- notifications, and show message requests. Every other message returns
-- 'False'.
shouldSkip :: FromServerMessage -> Bool
shouldSkip msg =
  case msg of
    NotLogMessage _  -> True
    NotShowMessage _ -> True
    ReqShowMessage _ -> True
    _                -> False
-- | Rewrite command names in recorded events with 'swapPid', using the
-- given process id.
--
-- Two kinds of event are affected:
--
-- * A client 'ReqExecuteCommand': its command name is rewritten.
-- * A server 'RspInitialize': each command listed under the
--   @executeCommandProvider@ capability is rewritten (if present).
--
-- Every other event is returned unchanged.
swapCommands :: Int -> [Event] -> [Event]
swapCommands pid = map swapEvent
  where
    swapEvent (FromClient t (ReqExecuteCommand req)) =
      FromClient t (ReqExecuteCommand (req & params . command %~ swapPid pid))
    swapEvent (FromServer t (RspInitialize rsp)) =
      FromServer t $ RspInitialize $
        rsp & result . _Just . LSP.capabilities . executeCommandProvider . _Just . commands
            %~ fmap (swapPid pid)
    swapEvent other = other
-- | Heuristic check for a process-id prefix: true when the text contains at
-- least two colons.
-- NOTE(review): presumably command names look like @pid:plugin:command@ —
-- confirm.
hasPid :: T.Text -> Bool
hasPid cmd = T.count (T.singleton ':') cmd >= 2
-- | Replace everything before the first colon of @t@ with the decimal
-- rendering of @pid@. This only happens when 'hasPid' accepts @t@;
-- otherwise @t@ is returned unchanged.
swapPid :: Int -> T.Text -> T.Text
swapPid pid t
  | not (hasPid t) = t
  | otherwise      = T.pack (show pid) `T.append` fromFirstColon
  where
    -- Suffix of t starting at (and including) the first ':'.
    (_, fromFirstColon) = T.break (== ':') t