module Control.ERNet.Deployment.Local.Channel
(
ChannelLocal,
ChannelLocalAnyProt
)
where
import Control.ERNet.Deployment.Local.Logger
import Control.ERNet.Foundations.Protocol
import Control.ERNet.Foundations.Event
import qualified Control.ERNet.Foundations.Event.Logger as LG
import qualified Control.ERNet.Foundations.Channel as CH
import Control.Concurrent as Concurrent
import Control.Concurrent.STM
import Data.Number.ER.Misc.STM
import Data.Time.Clock
import qualified Data.Map as Map
import Data.Maybe
import Data.Typeable
-- | In-process implementation of the ERNet channel interface.
--
-- 'ChannelLocal' fills both typed channel positions and
-- 'ChannelLocalAnyProt' both protocol-erased positions; every method
-- delegates to a top-level function defined in this module.
instance CH.Channel ChannelLocal ChannelLocal ChannelLocalAnyProt ChannelLocalAnyProt where
    -- recovering typed channels from protocol-erased ones
    castIn = castChannel
    castOut = castChannel
    castInIO = castChannelIO
    castOutIO = castChannelIO
    -- query-making side: post queries, then wait for their answers
    makeQuery = makeQuery
    makeQueryAnyProt = makeQueryAnyProt
    waitForAnswer = waitForAnswer
    waitForAnswerMulti = waitForAnswerMulti
    -- query-answering side: receive queued queries, then post answers
    waitForQuery = waitForQuery
    waitForQueryMulti = waitForQueryMulti
    answerQuery = answerQuery
    answerQueryAnyProt = answerQueryAnyProt
-- | Lets the local scheduler allocate channels. 'newChannel' stores the
-- given 'LoggerLocal' inside the channel it creates.
instance CH.ChannelForScheduler LoggerLocal ChannelLocal ChannelLocal ChannelLocalAnyProt ChannelLocalAnyProt where
    new = newChannel
-- | A 'ChannelLocal' whose query and answer types are hidden, so channels of
-- different protocols can be handled uniformly. Recover the typed channel
-- with 'castChannel' or 'castChannelIO'.
data ChannelLocalAnyProt =
    forall q a. (QAProtocol q a, Show q, Show a) =>
    ChannelLocalAnyProt
    { chanyCH :: ChannelLocal q a
      -- ^ the underlying typed channel; the field is existential, so it can
      -- only be reached by pattern matching, not as a selector function
    , chanyChT :: ChannelType
      -- ^ descriptor of the channel's protocol (it carries values of the
      -- query and answer types, see 'newChannel')
    }
-- | An in-process query/answer channel. All of its mutable state lives in
-- one 'TVar', and every operation in this module accesses it through
-- 'atomically'.
data ChannelLocal q a =
    ChannelLocal
    { chTV :: TVar (ChannelState q a)
      -- ^ shared mutable channel state
    , chLogger :: LoggerLocal
      -- ^ logger that query/answer events for this channel are sent to
    , chDestination :: String
      -- ^ destination process name (the @procName@ given to 'newChannel');
      -- used by 'makeChName'
    , chID :: Int
      -- ^ channel number; the 'Eq', 'Ord' and 'Show' instances use only this
    }
    deriving (Typeable)
-- | Two channels are equal exactly when their 'chID's are equal.
instance Eq (ChannelLocal q a) where
    x == y = chID x == chID y
-- | Channels are ordered by their 'chID'.
instance Ord (ChannelLocal q a) where
    compare x y = chID x `compare` chID y
-- | Renders a channel as @CH@ followed by its 'chID', e.g. @CH3@.
instance Show (ChannelLocal q a) where
    show = ("CH" ++) . show . chID
-- | Human-readable channel name, used in log events and error messages.
type ChannelName = String

-- | Name a channel after its destination process, appending the channel
-- number unless it is 0 (e.g. @\"p\"@ for channel 0, @\"p2\"@ for channel 2).
makeChName ::
    ChannelLocal q a ->
    ChannelName
makeChName channel =
    case chID channel of
        0 -> destination
        n -> destination ++ show n
  where
    destination = chDestination channel
-- | Mutable state of one channel, stored in its 'chTV'.
data ChannelState q a =
    ChannelState
    { chStNextId :: QueryId
      -- ^ id that 'makeQuery' assigns to the next distinct query
    , chStQueriesNew :: [(QueryId, q)]
      -- ^ queries not yet taken by 'waitForQuery' / 'waitForQueryMulti';
      -- 'makeQuery' appends at the end and consumers take from the front
    , chStQueriesCache :: Map.Map QueryId (QAState q)
      -- ^ known queries together with their waiter counts
    , chStQueriesIndex :: Map.Map q QueryId
      -- ^ reverse index, so that identical queries share one id
    , chStAnswers :: Map.Map QueryId (a, Bool)
      -- ^ posted answers; the flag is the @useCache@ argument of
      -- 'answerQuery', which keeps the answer after its last waiter reads it
    }

-- | State of a freshly created channel: ids start at 1, nothing is pending.
initialChState :: ChannelState q a
initialChState =
    ChannelState 1 [] Map.empty Map.empty Map.empty
-- | Bookkeeping for one distinct query posted on a channel.
data QAState q =
    QAState
    { qaStQry :: q
      -- ^ the query itself
    , qaStWaiting :: Int
      -- ^ how many 'makeQuery' calls are waiting for this query's answer
    }

-- | Bookkeeping for a query that has just been made for the first time.
initialQAState :: q -> QAState q
initialQAState qry = QAState qry 1
-- | Create a new in-process channel with a fresh, empty state.
--
-- Both components of the returned pair are the same protocol-erased value,
-- so the two ends of the connection share a single state 'TVar'.
-- The channel's query and answer types are taken from the 'ChannelType'
-- descriptor.
newChannel ::
    LoggerLocal ->   -- logger stored in the channel ('chLogger')
    String ->        -- destination process name ('chDestination')
    Int ->           -- channel number ('chID')
    ChannelType ->   -- descriptor fixing the query and answer types
    IO (ChannelLocalAnyProt, ChannelLocalAnyProt)
newChannel logger procName cID chType =
    do
        cha <- auxNewCHA chType
        return (cha, cha)
  where
    -- The pattern signatures bring the descriptor's query type @q@ and
    -- answer type @a@ into scope, so the new channel can be given the
    -- concrete type @ChannelLocal q a@ before it is wrapped.
    -- (The inner @chType@ shadows the outer one; it is the same value.)
    auxNewCHA chType@(ChannelType (q :: q) (a :: a)) =
        do
            cTV <- newTVarIO initialChState
            return $
                ChannelLocalAnyProt
                    ((chan cTV) :: ChannelLocal q a)
                    chType
    -- Build the channel record around a freshly allocated state variable.
    chan cTV =
        ChannelLocal
        {
            chID = cID,
            chTV = cTV,
            chLogger = logger,
            chDestination = procName
        }
-- | Recover the typed channel hidden inside a 'ChannelLocalAnyProt'.
--
-- @locationDescr@ names the caller for the error message. If the channel's
-- query/answer types differ from the requested ones, the result is a
-- 'channelCastError'.
castChannel ::
    (QAProtocol q a) =>
    String ->
    ChannelLocalAnyProt ->
    (ChannelLocal q a)
castChannel locationDescr (ChannelLocalAnyProt ch chtp) =
    fromMaybe
        (channelCastError locationDescr (makeChName ch) chtp)
        (cast ch)
-- | IO variant of 'castChannel' that also checks the channel's recorded
-- 'ChannelType' descriptor.
--
-- How it works:
--
-- * A throw-away channel is built from the descriptor with 'newChannel'.
--   Its logger is 'undefined', so it must never be used beyond the type
--   check.
-- * Both that channel and the real one must cast to the requested type;
--   otherwise the result is a 'channelCastError'.
--
-- Each call allocates a new 'TVar' for the throw-away channel.
castChannelIO ::
    (QAProtocol q a) =>
    String ->
    ChannelLocalAnyProt ->
    IO (ChannelLocal q a)
castChannelIO locationDescr chA =
    do
        (chA_, _) <- newChannel undefined "" 0 chtp
        case (chA, chA_) of
            (ChannelLocalAnyProt ch _, ChannelLocalAnyProt ch_ _) ->
                -- Placing both casts in one list forces them to the same
                -- result type, namely the type this function returns.
                case [cast ch, cast ch_] of
                    [Just ch, Just _] ->
                        return ch
                    _ ->
                        channelCastError locationDescr (makeChName ch) chtp
  where
    chtp = chanyChT chA
-- | Abort with a message saying that channel @chnm@ could not be cast to
-- protocol type @chtp@. @locationDescr@ identifies the calling code.
channelCastError ::
    String ->
    ChannelName ->
    ChannelType ->
    a
channelCastError locationDescr chnm chtp =
    error $ concat
        [ locationDescr
        , " failed casting channel "
        , chnm
        , " to "
        , show chtp
        ]
-- | Post query @qry@ on @channel@ on behalf of @callingCh@ / @callingQryId@,
-- and return the id under which its answer will be posted.
--
-- If an identical query is already indexed on the channel, its id is
-- reused and its waiter count is incremented; the query is not queued
-- again. In both cases an 'ERNetEvQryMade' event is logged to the
-- channel's logger.
makeQuery callingCh callingQryId channel qry = do
    qryId <- atomically registerQuery
    now <- getCurrentTime
    LG.addEvent (chLogger channel) $
        ERNetEvQryMade
            { ernetevTime = now
            , ernetevQryId = qryId
            , ernetevFromId = makeChName callingCh
            , ernetevFromQryId = callingQryId
            , ernetevToId = makeChName channel
            , ernetevQry = qry
            }
    return qryId
  where
    tv = chTV channel
    -- Share the id of an identical known query, or allocate a fresh one.
    registerQuery = do
        st <- readTVar tv
        case Map.lookup qry (chStQueriesIndex st) of
            Just knownId -> addWaiter st knownId
            Nothing -> enqueueNew st
    -- Queue a brand-new query and register it in the cache and index.
    enqueueNew st = do
        let freshId = chStNextId st
        writeTVar tv st
            { chStNextId = freshId + 1
            , chStQueriesNew = chStQueriesNew st ++ [(freshId, qry)]
            , chStQueriesCache =
                Map.insert freshId (initialQAState qry) (chStQueriesCache st)
            , chStQueriesIndex =
                Map.insert qry freshId (chStQueriesIndex st)
            }
        return freshId
    -- One more caller waits for an already known query.
    addWaiter st knownId = do
        writeTVar tv st
            { chStQueriesCache =
                Map.adjust bumpWaiting knownId (chStQueriesCache st)
            }
        return knownId
    bumpWaiting qaSt = qaSt { qaStWaiting = qaStWaiting qaSt + 1 }
-- | Protocol-erased variant of 'makeQuery'.
--
-- Unwraps the calling channel and the query, casts @chA@ to the query's
-- protocol with 'castChannelIO', and posts the query on the result.
makeQueryAnyProt locationDescr (ChannelLocalAnyProt callingCH _) callingQryId chA (QueryAnyProt q) =
    castChannelIO locationDescr chA >>= \ch ->
        makeQuery callingCH callingQryId ch q
-- | Block until a query is queued on @channel@. Then remove the oldest
-- queued query from the queue and return it with its id.
waitForQuery channel =
    atomically popOldest
  where
    tv = chTV channel
    popOldest = do
        st <- readTVar tv
        case chStQueriesNew st of
            [] -> retry
            (oldest : rest) -> do
                writeTVar tv st { chStQueriesNew = rest }
                return oldest
-- | Block until a query is queued on any of @channels@, then dequeue it.
--
-- Returns the channel's position in the list together with the query id
-- and the protocol-erased query. Channels are tried in list order, so
-- earlier channels win when several have pending queries.
waitForQueryMulti channels =
    atomically (firstPending (zip [0 ..] channels))
  where
    firstPending [] = retry
    firstPending ((chN, ChannelLocalAnyProt ch _) : rest) = do
        found <- popQuery (chTV ch)
        case found of
            Just qryData -> return (chN, qryData)
            Nothing -> firstPending rest
    -- Take the oldest queued query from one channel, if there is one.
    popQuery tv = do
        st <- readTVar tv
        case chStQueriesNew st of
            [] -> return Nothing
            ((qryId, qry) : remaining) -> do
                writeTVar tv st { chStQueriesNew = remaining }
                return (Just (qryId, QueryAnyProt qry))
-- | Post answer @ans@ for query @qryId@ on @channel@.
--
-- @useCache@ is stored with the answer. When it is set,
-- 'waitForAnswerAUX' keeps the query registered after a waiter reads the
-- answer.
answerQuery useCache channel (qryId, ans) =
    atomically (modifyTVar (chTV channel) storeAnswer) >> return ()
  where
    storeAnswer st =
        st { chStAnswers = Map.insert qryId (ans, useCache) (chStAnswers st) }
-- | Protocol-erased variant of 'answerQuery'.
--
-- Unwraps the answer and casts @chA@ to the matching typed channel with
-- 'castChannel', which fails if the types disagree.
answerQueryAnyProt locationDescr useCache chA (qryId, AnswerAnyProt a) =
    answerQuery useCache (castChannel locationDescr chA) (qryId, a)
-- | Block until an answer to @qryId@ has been posted on @channel@, then
-- consume it.
--
-- The channel state is updated by 'waitForAnswerAUX', an
-- 'ERNetEvAnsReceived' event is logged to the channel's logger, and the
-- answer is returned.
waitForAnswer waitingCh waitingQryId channel qryId = do
    (qry, ans) <- atomically takeAnswer
    now <- getCurrentTime
    LG.addEvent (chLogger channel) $
        ERNetEvAnsReceived
            { ernetevTime = now
            , ernetevQryId = qryId
            , ernetevFromId = makeChName waitingCh
            , ernetevFromQryId = waitingQryId
            , ernetevToId = makeChName channel
            , ernetevAns = ans
            , ernetevQry = qry
            }
    return ans
  where
    tv = chTV channel
    takeAnswer = do
        st <- readTVar tv
        case Map.lookup qryId (chStAnswers st) of
            Nothing -> retry
            Just (ans, isCached) -> do
                let (st', qry) = waitForAnswerAUX qryId st isCached
                writeTVar tv st'
                return (qry, ans)
-- | Pure state update for a caller that has just read the answer to query
-- @qryId@. Returns the new channel state together with the original query.
--
-- * If other callers are still waiting (waiter count > 1), or the answer
--   was posted with @useCache@ (@isCached@), the query stays registered and
--   its waiter count is decremented by one.
-- * Otherwise this was the last waiter, and the query's cache entry, index
--   entry and answer are all removed.
--
-- Calls 'error' if @qryId@ has no entry in 'chStQueriesCache'.
waitForAnswerAUX ::
    (Ord q) =>
    QueryId ->
    ChannelState q a ->
    Bool ->
    (ChannelState q a, q)
waitForAnswerAUX qryId chSt isCached =
    (chSt', qry)
  where
    queries = chStQueriesCache chSt
    (qry, count) =
        case Map.lookup qryId queries of
            Just (QAState q n) -> (q, n)
            Nothing ->
                error "Control.ERNet.Deployment.Local.Channel.waitForAnswerAUX: answered query is missing from the query cache"
    chSt'
        | count > 1 || isCached =
            chSt
                { chStQueriesCache =
                    -- one caller fewer is now waiting for this answer
                    Map.insert qryId (QAState qry (count - 1)) queries
                }
        | otherwise =
            chSt
                { chStQueriesCache = Map.delete qryId queries
                , chStQueriesIndex = Map.delete qry (chStQueriesIndex chSt)
                , chStAnswers = Map.delete qryId (chStAnswers chSt)
                }
-- | Block until an answer is available for any of the @(channel, queryId)@
-- pairs in @channelIds@, then consume it.
--
-- The answer is consumed via 'waitForAnswerAUX', and an
-- 'ERNetEvAnsReceived' event is logged on that channel's logger.
-- Returns the pair's position in the list together with the
-- protocol-erased answer. Pairs are checked in list order, so earlier
-- pairs win when several answers are ready.
waitForAnswerMulti waitingCHA waitingQryId channelIds =
    do
        (chN, chn, qryId, qry, ans) <- atomically $ waitUpdateChSt $ zip [0..] channelIds
        timeNow <- getCurrentTime
        case (waitingCHA, chn, qry, ans) of
            (ChannelLocalAnyProt waitingCH _,
             ChannelLocalAnyProt ch _,
             QueryAnyProt q, AnswerAnyProt a) ->
                LG.addEvent (chLogger ch)
                    ERNetEvAnsReceived
                    {
                        ernetevTime = timeNow,
                        ernetevQryId = qryId,
                        ernetevFromId = makeChName waitingCH,
                        ernetevFromQryId = waitingQryId,
                        ernetevToId = makeChName ch,
                        ernetevAns = a,
                        -- q and a come from separately unpacked wrappers, so
                        -- the type checker cannot relate their types; the
                        -- cast re-links them (presumably the event requires
                        -- matching types - confirm in the Event module).
                        -- Both were taken from the same typed channel in
                        -- waitUpdateChSt, so the cast is expected to succeed;
                        -- fromJust would crash otherwise.
                        ernetevQry = fromJust $ cast q
                    }
        return (chN, ans)
  where
    -- Scan the pairs in order inside one STM transaction, and 'retry'
    -- when no answer is ready on any of them.
    waitUpdateChSt [] = retry
    waitUpdateChSt ((chN, (channel, qryId)) : otherQueryInfos) =
        case channel of
            (ChannelLocalAnyProt ch chtp) ->
                do
                    chSt <- readTVar cTV
                    case Map.lookup qryId (chStAnswers chSt) of
                        Nothing -> waitUpdateChSt otherQueryInfos
                        Just (ans, isCached) ->
                            do
                                writeTVar cTV chSt'
                                -- Re-wrap the channel, query and answer so
                                -- they can leave the existential scope of ch.
                                return (chN, ChannelLocalAnyProt ch chtp, qryId, QueryAnyProt qry, AnswerAnyProt ans)
                          where
                            (chSt', qry) =
                                waitForAnswerAUX qryId chSt isCached
              where
                cTV = chTV ch