module Spark.Core.Internal.Caching(
) where
import Control.Monad.Identity
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.Vector as V
import Control.Arrow((&&&))
import Data.Foldable
import Data.Set(Set)
import Data.Maybe(mapMaybe)
import Debug.Trace(trace)
import Data.Text(Text)
import Formatting
import Spark.Core.Internal.DAGFunctions
import Spark.Core.Internal.DAGStructures
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.Utilities
-- | The role a vertex plays in the caching analysis.
data NodeCachingType
  = AutocacheOp VertexId
    -- ^ An automatically inserted cache operation; its id seeds the
    -- expansion sets computed by '_expansions'.
  | CacheOp VertexId
    -- ^ An explicit cache operation; its id seeds the expansion sets.
  | UncacheOp VertexId VertexId
    -- ^ An uncache operation: the id of the uncache vertex itself, then the
    -- id of the cache vertex it undoes.
  | Through
    -- ^ A vertex that simply forwards what it receives from its parents.
  | Stop
    -- ^ A vertex that does not forward cache operations to its children.
  deriving (Show, Eq)
-- | A caching problem reported by 'checkCaching'.
data CachingFailure = CachingFailure
  { cachingNode   :: !VertexId
    -- ^ The vertex holding the cache operation.
  , uncachingNode :: !VertexId
    -- ^ The vertex holding the matching uncache operation.
  , escapingNode  :: !VertexId
    -- ^ A vertex that the cache operation reaches but that is not covered
    -- by the uncache operation.
  } deriving (Show, Eq)
-- | Result of a caching computation: either a 'Text' (presumably an error
-- message) or a value.
type CacheTry = Either Text

-- | A structure graph whose vertices are tagged with their caching role.
type CacheGraph v = Graph (v, NodeCachingType) StructureEdge
-- | Callbacks used by the autocaching pass to synthesize the extra
-- vertices it inserts into the graph.
data AutocacheGen v = AutocacheGen {
  -- | Derives, from an autocache vertex, the vertex that uncaches it.
  deriveUncache :: Vertex v -> Vertex v,
  -- | Derives, from a stop vertex, the identity vertex that takes its
  -- place in the rewired graph.
  deriveIdentity :: Vertex v -> Vertex v
}
-- | Annotates every vertex with its caching role (using the given
-- classifier, which may fail) and reports the cache operations that reach
-- a vertex without being covered by their uncache operation.
checkCaching :: (Show v) =>
  Graph v StructureEdge ->
  (v -> CacheTry NodeCachingType) ->
  CacheTry [CachingFailure]
checkCaching g fun = _checkCaching =<< _cacheGraph g fun
-- | Classifies every vertex with @cacheFun@, inserts the uncache and
-- identity vertices required by the autocache operations (see
-- '_fillAutoCache'), then drops the caching annotations again.
fillAutoCache :: (Show v) =>
  (v -> CacheTry NodeCachingType) ->
  AutocacheGen v ->
  Graph v StructureEdge ->
  DagTry (Graph v StructureEdge)
fillAutoCache cacheFun acGen g = do
  tagged <- graphMapVertices g annotate
  filled <- _fillAutoCache acGen tagged
  return $ graphMapVertices' fst filled
  where
    -- Pairs each vertex payload with its caching role.
    annotate vx _ = (,) vx <$> cacheFun vx
-- | A vertex whose caching role is 'AutocacheOp'.
newtype AutocacheVertex v = AutocacheVertex (Vertex v)

-- | A vertex whose caching role is 'Stop'.
newtype StopVertex v = StopVertex { unStopVertex :: Vertex v }

-- | An identity vertex, derived from a stop vertex with 'deriveIdentity'.
newtype IdentityVertex v = IdentityVertex (Vertex v)

-- | An uncache vertex, together with the id that is used as the cache
-- vertex id when the corresponding 'UncacheOp' is built.
data UncacheVertex v = UncacheVertex (Vertex v) VertexId
  deriving (Show)

-- | The id of a vertex carrying a cache or autocache operation.
newtype AnyCacheOp = AnyCacheOp { unAnyCacheOp :: VertexId }
  deriving (Eq, Ord, Show)

-- | Outcome of '_findOrCreateUncache': 'Left' when the uncache vertex is
-- already present in the graph, 'Right' with the newly derived vertex and
-- the edge linking it to its autocache vertex.
type CreateUncache v =
  Either (UncacheVertex v) (UncacheVertex v, Edge StructureEdge)
-- | Inserts the uncache and identity vertices required by the autocache
-- vertices of the graph.
--
-- Steps:
--
-- * For each autocache vertex returned by '_autoCachingCandidates', an
--   uncache vertex is derived with 'deriveUncache'. Autocache vertices
--   whose uncache vertex already exists are dropped, with a trace message.
-- * Each stop vertex on the fringe of an autocache vertex is replaced by an
--   identity vertex (derived with 'deriveIdentity').
-- * The identity vertex is ordered after every uncache vertex that
--   concerns that stop.
_fillAutoCache :: forall v. (Show v) =>
  AutocacheGen v ->
  Graph (v, NodeCachingType) StructureEdge ->
  DagTry (Graph (v, NodeCachingType) StructureEdge)
_fillAutoCache acGen cg =
  let
    -- All the vertices of the input graph, indexed by id.
    vxMap = M.fromList ((vertexId &&& id) <$> toList (gVertices cg))
    -- Reuses the identity vertex if the graph already contains it.
    findOrCreateIdentity :: StopVertex v -> IdentityVertex v
    findOrCreateIdentity (StopVertex vx) =
      let uvx = deriveIdentity acGen vx
      in case M.lookup (vertexId uvx) vxMap of
           Just vx' -> IdentityVertex $ fst <$> vx'
           Nothing -> IdentityVertex uvx
    acNodesAndScopes = _autoCachingCandidates cg
    findOrCreateUncache' (acv, l) = case _findOrCreateUncache vxMap acGen acv of
      Left x -> trace ("findOrCreateUncache: dropping autocache node " ++ show x) Nothing
      Right (ucv, ed) -> Just (ed, (acv, ucv, l))
    acWithUncache' = mapMaybe findOrCreateUncache' acNodesAndScopes
    acWithUncache = snd <$> acWithUncache'
    acEdges = fst <$> acWithUncache'
    -- Uncache steps grouped by the id of the stop vertex they apply to.
    tups = myGroupBy [ (vertexId (unStopVertex svx), (cvx, uvx, svx))
                     | (cvx, uvx, l) <- acWithUncache, svx <- l ]
    group ((_, uvx, svx) : t) = (svx, findOrCreateIdentity svx, uvx : [uvx' | (_, uvx', _) <- t])
    group [] = failure "_fillAutoCache:group: empty: should not happen"
    stopsWithCachingSteps :: [(StopVertex v, IdentityVertex v, [UncacheVertex v])]
    stopsWithCachingSteps = (group . snd) <$> M.toList tups
    -- Each stop vertex is rewired exactly once:
    -- * The first uncache step goes through '_performEdgeTransform', which
    --   redirects the stop's incoming edges onto the identity vertex.
    -- * The remaining steps only add their two ordering edges.
    -- Running '_performEdgeTransform' once per step would also redirect the
    -- edges created by the previous step, producing an identity self-loop
    -- and an identity -> uncache -> identity cycle.
    rewireStop eds (_, _, []) = eds
    rewireStop eds (svx, ivx, uvx : uvxs) =
      concatMap (orderingEdges svx ivx) uvxs ++ _performEdgeTransform svx ivx uvx eds
    -- The same logical edges '_performEdgeTransform' adds between the
    -- identity, uncache and stop vertices.
    orderingEdges (StopVertex stopVx) (IdentityVertex idVx) (UncacheVertex ucVx _) =
      [ Edge { edgeFrom = vertexId idVx, edgeTo = vertexId ucVx, edgeData = LogicalEdge }
      , Edge { edgeFrom = vertexId ucVx, edgeTo = vertexId stopVx, edgeData = LogicalEdge } ]
    startEdges = veEdge <$> [ve | (_, v) <- M.toList (gEdges cg), ve <- V.toList v]
    edges = acEdges ++ foldl' rewireStop startEdges stopsWithCachingSteps
    startVertices = V.toList (gVertices cg)
    ucVertices = acWithUncache <&> \(_, UncacheVertex vx cacheVid, _) ->
      let op = UncacheOp (vertexId vx) cacheVid
      in (id &&& const op) <$> vx
    -- One identity vertex per stop, and only when it is not already part of
    -- the input graph (in which case it is in startVertices).
    idVertices = [ (id &&& const Stop) <$> vx
                 | (_, IdentityVertex vx, _) <- stopsWithCachingSteps
                 , M.notMember (vertexId vx) vxMap ]
    allVertices = startVertices ++ ucVertices ++ idVertices
  in buildGraphFromList allVertices edges
-- | Derives the uncache vertex of an autocache vertex with 'deriveUncache'.
--
-- * If the graph already holds a vertex with that id and it is an uncache
--   operation, returns 'Left' with that existing vertex.
-- * If that id belongs to a vertex with any other caching role, fails via
--   'failure'.
-- * Otherwise returns 'Right' with the new vertex and a 'ParentEdge' going
--   from the uncache vertex to the autocache vertex.
--
-- The id stored in the resulting 'UncacheVertex' is the id of the cache
-- vertex being undone. It ends up as the second field of the generated
-- 'UncacheOp', which is what '_expansions' and '_checkCaching' match on.
_findOrCreateUncache :: (HasCallStack, Show v) =>
  M.Map VertexId (Vertex (v, NodeCachingType)) ->
  AutocacheGen v ->
  AutocacheVertex v -> CreateUncache v
_findOrCreateUncache vxMap acGen (AutocacheVertex acv) =
  let uvx = deriveUncache acGen acv
      acVid = vertexId acv
      uVid = vertexId uvx
      look = vertexData <$> M.lookup uVid vxMap
  in case look of
       Just (x, UncacheOp _ cacheVid) ->
         -- Report the cache vertex the existing uncache operation refers to.
         Left $ UncacheVertex (Vertex uVid x) cacheVid
       Just _ ->
         failure $ sformat ("_findOrCreateUncache:"%sh%"->"%sh) acv look
       Nothing ->
         let ed' = Edge uVid acVid ParentEdge
         -- Previously this stored uVid (the uncache vertex's own id), which
         -- produced an 'UncacheOp' pointing at itself.
         in Right (UncacheVertex uvx acVid, ed')
-- | Puts an identity vertex in place of a stop vertex and orders an uncache
-- vertex between them.
--
-- Every edge whose target is the stop vertex is retargeted to the identity
-- vertex. Three edges are then prepended:
--
-- * identity -> stop ('ParentEdge')
-- * identity -> uncache ('LogicalEdge')
-- * uncache -> stop ('LogicalEdge')
_performEdgeTransform ::
  StopVertex v -> IdentityVertex v -> UncacheVertex v -> [Edge StructureEdge] -> [Edge StructureEdge]
_performEdgeTransform (StopVertex svx) (IdentityVertex ivx) (UncacheVertex uvx _) eds =
  let stopVid = vertexId svx
      idenVid = vertexId ivx
      ucVid = vertexId uvx
      -- Edges leaving the identity vertex are kept as they are. If the
      -- identity vertex already exists and points at the stop, retargeting
      -- that edge would turn it into a self-loop on the identity vertex.
      f ed
        | edgeTo ed == stopVid && edgeFrom ed /= idenVid = ed { edgeTo = idenVid }
        | otherwise = ed
      joinEd = Edge { edgeFrom = idenVid, edgeTo = stopVid, edgeData = ParentEdge }
      id1Ed = Edge { edgeFrom = idenVid, edgeTo = ucVid, edgeData = LogicalEdge }
      id2Ed = Edge { edgeFrom = ucVid, edgeTo = stopVid, edgeData = LogicalEdge }
  in id1Ed : id2Ed : joinEd : (f <$> eds)
-- | Finds every autocache vertex, paired with the stop vertices on its
-- fringe.
--
-- The fringe of an autocache vertex is the set of 'Stop' vertices whose
-- expansion set (see '_expansions') contains that autocache operation.
-- Autocache vertices with an empty fringe are left out.
_autoCachingCandidates :: forall v. (Show v) =>
  Graph (v, NodeCachingType) StructureEdge ->
  [(AutocacheVertex v, [StopVertex v])]
_autoCachingCandidates cg =
  let
    cg' = graphMapVertices' snd cg
    exps = gVertices $ _expansions cg'
    extractAutocache vx = case snd (vertexData vx) of
      AutocacheOp _ -> [AutocacheVertex (fst <$> vx)]
      _ -> []
    acVxs = concatMap extractAutocache (gVertices cg)
    -- For each stop vertex, one (cache op, stop id) pair per cache op that
    -- reaches it.
    extractFringe vx = case vertexData vx of
      (Stop, set) -> (id &&& const (vertexId vx)) <$> toList set
      _ -> []
    acWithFringe = myGroupBy $ concatMap extractFringe (toList exps)
    vmap = vertexMap cg
    vmap' = vertexMap cg'
    findStop :: VertexId -> Maybe (StopVertex v)
    findStop vid = do
      vx <- M.lookup vid vmap
      _ <- M.lookup vid vmap'
      return $ StopVertex (Vertex vid (fst vx))
    combineWithFringe :: AutocacheVertex v -> (AutocacheVertex v, [StopVertex v])
    combineWithFringe acv@(AutocacheVertex vx) =
      let vids = M.findWithDefault [] (AnyCacheOp (vertexId vx)) acWithFringe
      in (acv, mapMaybe findStop vids)
    acWithFringeVx = filter (not . null . snd) $ combineWithFringe <$> acVxs
  in acWithFringeVx
-- | Reports, for every vertex, the cache operations that are not properly
-- undone.
--
-- A failure is reported when a cache operation has a matching
-- 'UncacheOp' in the graph and both of the following hold:
--
-- * the cache operation appears in the vertex's expansion set (see
--   '_expansions');
-- * the vertex's removal set (see '_removals') does not contain it.
_checkCaching :: Graph NodeCachingType StructureEdge -> CacheTry [CachingFailure]
_checkCaching cg =
  let
    expands = snd <$> vertexMap (_expansions cg)
    removals = vertexMap $ _removals cg
    -- Maps a cache vertex id to the id of the vertex that uncaches it.
    f :: NodeCachingType -> [(VertexId, VertexId)]
    f (UncacheOp uncacheNid cacheNid) = [(cacheNid, uncacheNid)]
    f _ = []
    removedNodes = M.fromList $ concatMap f (vertexData <$> gVertices cg)
    -- Cache vertices that have an uncache operation somewhere in the graph.
    removedNodeSet = S.fromList $ M.keys removedNodes
    checkErrors :: VertexId -> [CachingFailure]
    checkErrors nid =
      let rems = S.intersection
                   (M.findWithDefault S.empty nid removals)
                   removedNodeSet
          exps = S.intersection
                   (S.map unAnyCacheOp (M.findWithDefault S.empty nid expands))
                   removedNodeSet
          bad = S.toList $ S.difference exps rems
          badWithUncache = flip mapMaybe bad $ \cacheNid ->
            M.lookup cacheNid removedNodes <&> \uncacheNid ->
              CachingFailure cacheNid uncacheNid nid
      in badWithUncache
  in return $ concatMap checkErrors (vertexId <$> gVertices cg)
-- | Replaces every vertex payload with its caching role, as computed by
-- @f@. The first failure of @f@ aborts the whole mapping.
_cacheGraph :: (Show v) => Graph v StructureEdge ->
  (v -> CacheTry NodeCachingType) ->
  CacheTry (Graph NodeCachingType StructureEdge)
_cacheGraph g f = graphMapVertices g (const . f)
-- | Tags every vertex with the set of cache operations visible at it.
--
-- A vertex sees its own 'CacheOp' / 'AutocacheOp' plus everything its
-- parents forward. A 'Stop' parent forwards nothing. An 'UncacheOp'
-- parent forwards everything except the cache operation it undoes.
_expansions :: (Show e) =>
  Graph NodeCachingType e ->
  Graph (NodeCachingType, Set AnyCacheOp) e
_expansions g = runIdentity (graphMapVertices g step)
  where
    step x parents = pure (x, S.union (ownOp x) (S.unions (forwarded <$> parents)))
    ownOp (CacheOp nid) = S.singleton (AnyCacheOp nid)
    ownOp (AutocacheOp nid) = S.singleton (AnyCacheOp nid)
    ownOp _ = S.empty
    forwarded ((Stop, _), _) = S.empty
    forwarded ((UncacheOp _ cacheVid, s), _) = S.delete (AnyCacheOp cacheVid) s
    forwarded ((_, s), _) = s
-- | Tags every vertex with a set of cache vertex ids.
--
-- The graph is traversed in reverse. Each vertex collects the cache id
-- undone by its own 'UncacheOp' (if any), plus the sets of its neighbours
-- in the reversed graph.
_removals :: (Show e) =>
  Graph NodeCachingType e -> Graph (Set VertexId) e
_removals g = runIdentity (graphMapVertices (reverseGraph g) step)
  where
    step x neighbours = pure (S.union (undone x) (S.unions (fst <$> neighbours)))
    undone (UncacheOp _ cacheNid) = S.singleton cacheNid
    undone _ = S.empty