module Spark.Core.Internal.ContextInternal(
FinalResult,
prepareExecution1,
buildComputationGraph,
performGraphTransforms,
getTargetNodes,
storeResults,
) where
import Control.Monad.State(get, put)
import Control.Monad(forM)
import Data.Text(pack)
import Debug.Trace(trace)
import Data.Foldable(toList)
import Control.Arrow((&&&))
import Formatting
import qualified Data.Map.Strict as M
import qualified Data.Vector as V
import Spark.Core.Dataset
import Spark.Core.Try
import Spark.Core.Row
import Spark.Core.Types
import Spark.Core.Internal.Caching
import Spark.Core.Internal.CachingUntyped
import Spark.Core.Internal.ContextStructures
import Spark.Core.Internal.Client
import Spark.Core.Internal.ComputeDag
import Spark.Core.Internal.PathsUntyped
import Spark.Core.Internal.Paths()
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.TypesFunctions(arrayType)
import Spark.Core.Internal.DAGFunctions(buildVertexList)
import Spark.Core.Internal.DAGStructures
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.Utilities
-- | Outcome reported for one node of a computation: 'Left' carries the
-- failure, 'Right' the successful result whose data '_extract1' decodes.
type FinalResult = Either NodeComputationFailure NodeComputationSuccess
-- | Prepares the 'Computation' that evaluates the given local observable.
--
-- The graph of @ld@ is built ('buildComputationGraph'), transformed
-- ('performGraphTransforms') and turned into a 'Computation'
-- ('_buildComputation'). That computation is identified by the session's
-- current command counter.
--
-- The command counter is incremented only when every step succeeds. On
-- failure the session is not modified and the error is returned as-is.
prepareExecution1 :: LocalData a -> SparkStatePure (Try Computation)
prepareExecution1 ld = do
  session <- get
  let comp = buildComputationGraph ld
        >>= performGraphTransforms
        >>= _buildComputation session
  -- Only a successfully built computation consumes a command id.
  either (const (return ())) (const _increaseCompCounter) comp
  return comp
-- | Builds the computation graph for the given node (via 'buildCGraph' on its
-- untyped view), then assigns paths to its nodes with 'assignPathsUntyped'.
buildComputationGraph :: ComputeNode loc a -> Try ComputeGraph
buildComputationGraph ld =
  assignPathsUntyped =<< tryEither (buildCGraph (untyped ld))
-- | Applies the graph transforms that must run before execution.
--
-- First it fills in automatic caching ('fillAutoCache'). Then it validates the
-- caching annotations of the result ('checkCaching'). If the validation
-- reports any error, the whole transform fails with those errors.
performGraphTransforms :: ComputeGraph -> Try ComputeGraph
performGraphTransforms cg =
  tryEither autoCached >>= \cached ->
    tryEither (checkCaching cached cachingType) >>= \errs ->
      if null errs
        then return (graphToComputeGraph cached)
        else tryError $ sformat ("Found some caching errors: "%sh) errs
  where
    -- Debug-traced intermediate graphs.
    asGraph = traceHint "_performGraphTransforms g=" $ computeGraphToGraph cg
    autoCached = traceHint "_performGraphTransforms: After autocaching:" $ fillAutoCache cachingType autocacheGen asGraph
-- | Assembles the 'Computation' to execute from a transformed graph.
--
-- The graph is first tied with 'tieNodes'. Every vertex of the tied graph
-- becomes a node of the computation. The computation id is the session's
-- current command counter, and the session id is taken from the session.
--
-- The graph must have exactly one output. Its name is used both as the
-- singleton list of terminal nodes and as the last field of 'Computation'.
-- Any other number of outputs is reported as a programming error. That error
-- lists the terminal node names found and the original graph.
_buildComputation :: SparkSession -> ComputeGraph -> Try Computation
_buildComputation session cg =
  let sid = ssId session
      cid = (ComputationID . pack . show . ssCommandCounter) session
      tiedCg = tieNodes cg
      allNodes = vertexData <$> toList (cdVertices tiedCg)
      terminalNodeNames = nodeName . vertexData <$> toList (cdOutputs tiedCg)
  in case terminalNodeNames of
       [name] ->
         return $ Computation sid cid allNodes [name] name
       -- The error used to mention a non-existent "_build1" and to omit
       -- the offending outputs; name this function and show them.
       _ -> tryError $ sformat ("Programming error in _buildComputation: expected exactly one terminal node, got "%sh%" cg="%sh) terminalNodeNames cg
-- | Increments the session's command counter ('ssCommandCounter') by one.
_increaseCompCounter :: SparkStatePure ()
_increaseCompCounter = do
  s <- get
  put s { ssCommandCounter = ssCommandCounter s + 1 }
-- | Collects the untyped nodes of the graph built from the given local
-- observable with 'buildVertexList', lifting its error into 'Try'.
_gatherNodes :: LocalData a -> Try [UntypedNode]
_gatherNodes ld = tryEither (buildVertexList (untyped ld))
-- | SQL type of an array whose elements have the type @sqlt@ ('arrayType').
-- The underlying type is re-wrapped so that the phantom parameter can be
-- chosen freely before 'arrayType' is applied.
_extractionType :: SQLType a -> SQLType [a]
_extractionType sqlt = arrayType (SQLType (unSQLType sqlt))
-- | Simplifies a decoded result cell.
--
-- For a 'RowArray', each element that is itself a one-element 'RowArray'
-- holding an 'IntElement' or a 'StringElement' is replaced by that element.
-- All other elements, and any cell that is not a 'RowArray', are returned
-- unchanged.
_postprocessBasic :: (HasCallStack) => Cell -> Cell
_postprocessBasic cell = case cell of
    RowArray rows -> RowArray (V.map unwrapSingleton rows)
    other -> other
  where
    unwrapSingleton c@(RowArray arr) = case V.toList arr of
      [single@(IntElement _)] -> single
      [single@(StringElement _)] -> single
      _ -> c
    unwrapSingleton c = c
-- | Decodes the result of a single node computation into a 'Cell'.
--
-- A failed computation ('Left') becomes an error. For a successful one, the
-- returned data ('ncsData') is decoded with 'jsonToCell' against the array
-- type wrapping @sqlt@ (see '_extractionType'). The decoded array must hold
-- exactly one element, which is then simplified with '_postprocessBasic'.
-- Any other decoded shape is an error.
_extract1 :: FinalResult -> SQLType Cell -> Try Cell
_extract1 (Left nf) _ = tryError $ sformat ("got an error "%shown) nf
_extract1 (Right ncs) sqlt = res0 where
  wrappingType = _extractionType sqlt
  trow = tryEither $ jsonToCell (unSQLType wrappingType) (ncsData ncs)
  -- Total pattern guard instead of a length check followed by 'V.head'.
  res = trow >>= \l -> case l of
    RowArray arr | [single] <- V.toList arr -> Right $ _postprocessBasic single
    x -> tryError $ sformat ("ContextInternal:_extract1: Expected one element, got "%shown) x
  -- Debug output of the whole decoding step.
  res0 = trace ("_extract1: wrappingType = " ++ show wrappingType ++ " ncs = " ++ show ncs ++ " res = " ++ show res) res
-- | Returns the terminal nodes of a computation as local observables.
--
-- Each name in 'cTerminalNodes' is looked up among the computation's nodes
-- ('cNodes'), indexed by 'nodeName'. The node found is then cast to a local
-- observable. A missing name or a failed cast is handed to 'failure' (assumed
-- to abort with an error, since no fallback value is produced).
getTargetNodes :: (HasCallStack) => Computation -> [UntypedLocalData]
getTargetNodes comp = map toLocalData targetNodes
  where
    toLocalData :: (HasCallStack) => UntypedNode -> UntypedLocalData
    toLocalData node = case asLocalObservable <$> castLocality node of
      Right (Right x) -> x
      err -> failure $ sformat ("_getNodes:fun2: err="%shown%" n="%shown) err node
    terminalNames = traceHint "_getTargetNodes: finalNodeNames=" $ cTerminalNodes comp
    nodesByName = traceHint "_getTargetNodes: dct=" $ M.fromList $ (nodeName &&& id) <$> cNodes comp
    targetNodes = lookupNode <$> terminalNames
    lookupNode name = M.findWithDefault missing name nodesByName
      where
        missing = failure $ sformat ("Could not find "%sh%" in "%sh) name nodesByName
-- | Decodes the results returned for a computation and yields the value of
-- the first one.
--
-- Each @(node, result)@ pair is decoded with '_extract1' against the node's
-- SQL type ('nodeType'). If any pair fails to decode, the first failure in
-- list order is returned. Otherwise the cell decoded from the first pair is
-- returned. An empty result list is an error that mentions the computation.
--
-- Despite its name, this only @return@s a value and does not modify the
-- session state.
storeResults :: Computation -> [(LocalData Cell, FinalResult)] -> SparkStatePure (Try Cell)
storeResults comp [] = return e where
  e = tryError $ sformat ("No result returned for computation "%shown) comp
storeResults _ (firstRes : otherRes) = return $ do
    -- Matching the non-empty list replaces the former partial 'head'.
    -- All results are still decoded, so that a failure in any of them is
    -- reported exactly as before.
    firstCell <- fun4 firstRes
    _ <- forM otherRes fun4
    return firstCell
  where
    fun4 :: (LocalData Cell, FinalResult) -> Try Cell
    fun4 (node, fresult) =
      trace ("_storeResults node=" ++ show node ++ " final = " ++ show fresult) $
        _extract1 fresult (nodeType node)