{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE TupleSections #-}
-- | Helper functions shared by the e-graph based symbolic regression search:
-- fitness evaluation, data set splitting, and small e-graph queries/updates.
module Util where

import qualified Data.Map.Strict as Map
import Data.Massiv.Array as MA hiding (forM_, forM)
import Data.SRTree
import Data.SRTree.Eval
import Algorithm.SRTree.Opt
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.Build
import Algorithm.EqSat.Info
import Algorithm.SRTree.NonlinearOpt
import System.Random
import Random
import Algorithm.SRTree.Likelihoods
--import Algorithm.SRTree.ModelSelection
--import Algorithm.SRTree.Opt
import qualified Data.IntMap.Strict as IM
import Control.Monad.State.Strict
import Control.Monad ( when, replicateM, forM, forM_ )
import Data.Maybe ( fromJust )
import Data.List ( maximumBy )
import Data.Function ( on )
import List.Shuffle ( shuffle )
import Data.List.Split ( splitOn )
import Data.Char ( toLower )
import qualified Data.IntSet as IntSet
import Data.SRTree.Datasets
import Algorithm.EqSat.Queries

-- | E-graph computation stacked on top of a random generator state and 'IO'.
type RndEGraph a = EGraphST (StateT StdGen IO) a

-- | A data set as a triple: input matrix, target vector, and an optional
-- third vector (bound to names such as @mYErr@ throughout this module).
type DataSet = (SRMatrix, PVector, Maybe PVector)

-- | Header line of the CSV report.
csvHeader :: String
csvHeader = "id,Expression,theta,size,MSE_train,MSE_val,MSE_test,R2_train,R2_val,R2_test,nll_train,nll_val,nll_test,mdl_train,mdl_val,mdl_test"

-- | Lift an 'IO' action through both monad layers of 'RndEGraph'.
io :: IO a -> RndEGraph a
io = lift . lift
{-# INLINE io #-}

-- | Lift a random-state action into 'RndEGraph'.
rnd :: StateT StdGen IO a -> RndEGraph a
rnd = lift
{-# INLINE rnd #-}

-- | Cost of a node whose children already hold their own costs:
-- leaves cost 1, binary nodes 2 plus both children, unary nodes 3 plus the child.
myCost :: SRTree Int -> Int
myCost (Var _)     = 1
myCost (Const _)   = 1
myCost (Param _)   = 1
myCost (Bin _ l r) = 2 + l + r
myCost (Uni _ t)   = 3 + t

-- | Monadic while loop: as long as the predicate holds for the current value,
-- run the program on it and continue with the value it returns.
while :: Monad f => (t -> Bool) -> t -> (t -> f t) -> f ()
while p arg prog = do
  when (p arg) do
    arg' <- prog arg
    while p arg' prog

-- | Fit the parameters of an expression starting from @thetaOrig@ and return
-- the negated 'nll' on the validation set together with the fitted parameters.
--
-- If either the training or the validation value is NaN the returned fitness
-- is negative infinity.
fitnessFun :: Int -> Distribution -> DataSet -> DataSet -> Fix SRTree -> PVector -> (Double, PVector)
fitnessFun nIter distribution (x, y, mYErr) (x_val, y_val, mYErr_val) _tree thetaOrig
  | isNaN val || isNaN tr = (-(1/0), theta) -- infinity
  | otherwise             = (val, theta) -- (min tr val, theta)
  where
    tree = relabelParams _tree
    -- parameters of the expression plus the extra ones counted for the distribution
    nParams = countParams tree + if distribution == ROXY then 3 else if distribution == Gaussian then 1 else 0
    (theta, _, _) = minimizeNLL' VAR1 distribution mYErr nIter x y tree thetaOrig
    -- with nothing to optimize, evaluate at the initial vector instead of the optimizer output
    evalF a b c = negate $ nll distribution c a b tree $ if nParams == 0 then thetaOrig else theta
    tr = evalF x y mYErr
    val = evalF x_val y_val mYErr_val
{-# INLINE fitnessFun #-}

-- | Run 'fitnessFun' from @nRep@ random initial parameter vectors and keep the
-- result with the greatest fitness.
--
-- Values of @nRep@ below 1 are treated as 1: the original applied 'maximumBy'
-- to an empty list in that case, which is a runtime error.
fitnessFunRep :: Int -> Int -> Distribution -> DataSet -> DataSet -> Fix SRTree -> RndEGraph (Double, PVector)
fitnessFunRep nRep nIter distribution dataTrain dataVal _tree = do
  let tree = relabelParams _tree
      nParams = countParams tree + if distribution == ROXY then 3 else if distribution == Gaussian then 1 else 0
      restarts = max 1 nRep
  thetaOrigs <- replicateM restarts (rnd $ randomVec nParams)
  let fits = Prelude.map (fitnessFun nIter distribution dataTrain dataVal _tree) thetaOrigs
  pure (maximumBy (compare `on` fst) fits)
{-# INLINE fitnessFunRep #-}

--fitnessMV :: Int -> Int -> Distribution -> [DataSet] -> [DataSet] -> Fix SRTree -> RndEGraph (Double, [PVector])
--fitnessMV nRep nIter distribution dataTrains dataVals _tree = do
--  response <- forM (zip dataTrains dataVals) $ \(dt, dv) -> fitnessFunRep nRep nIter distribution dt dv _tree
--  pure (minimum (map fst response), map snd response)

-- helper query functions
-- TODO: move to egraph lib

-- | Fitness stored in the info of e-class @c@ (errors if @c@ is not a key of
-- the e-class map, because of 'IM.!').
getFitness :: EClassId -> RndEGraph (Maybe Double)
getFitness c = gets (_fitness . _info . (IM.! c) . _eClass)
{-# INLINE getFitness #-}

-- | Parameter vector stored in the info of e-class @c@.
getTheta :: EClassId -> RndEGraph (Maybe PVector)
getTheta c = gets (_theta . _info . (IM.! c) . _eClass)
{-# INLINE getTheta #-}

-- | Size stored in the info of e-class @c@.
getSize :: EClassId -> RndEGraph Int
getSize c = gets (_size . _info . (IM.! c) . _eClass)
{-# INLINE getSize #-}

-- | Test the stored size of an e-class against a predicate.
isSizeOf :: (Int -> Bool) -> EClass -> Bool
isSizeOf p = p . _size . _info
{-# INLINE isSizeOf #-}

-- | Fitness of the (canonicalized) e-class returned by 'getGreatest' on the
-- fitness range database.
getBestFitness :: RndEGraph (Maybe Double)
getBestFitness = do
  bec <- gets (snd . getGreatest . _fitRangeDB . _eDB) >>= canonical
  getFitness bec

-- TODO: move to dataset lib

-- | Split a list into consecutive chunks of @i@ elements (the last chunk may
-- be shorter). NOTE: a non-positive @i@ never terminates on a non-empty list,
-- since dropping @i@ elements makes no progress.
chunksOf :: Int -> [e] -> [[e]]
chunksOf i ls = Prelude.map (Prelude.take i) (build (splitter ls))
  where
    splitter :: [e] -> ([e] -> a -> a) -> a -> a
    splitter [] _ n = n
    splitter l c n = l `c` splitter (Prelude.drop i l) c n

    build :: ((a -> [a] -> [a]) -> [a] -> [a]) -> [a]
    build g = g (:) []

-- | Randomly split a data set into two parts.
--
-- The row indices are shuffled and cut into chunks of @k@; the first index of
-- every chunk goes to the second data set and the remaining ones to the first.
-- With @k <= 1@ no split is made and the full data set is returned for both
-- parts (the original only handled @k == 1@ and crashed for smaller values).
splitData :: DataSet -> Int -> State StdGen (DataSet, DataSet)
splitData (x, y, mYErr) k
  | k <= 1 = pure ((x, y, mYErr), (x, y, mYErr))
  | otherwise = do
      ixs' <- (state . shuffle) [0 .. sz-1]
      let ixs = chunksOf k ixs'
          (x_tr, x_te) = splitX ixs x
          (y_tr, y_te) = splitY ixs y
          mY = fmap (splitY ixs) mYErr
          (y_err_tr, y_err_te) = (fmap fst mY, fmap snd mY)
      pure ((x_tr, y_tr, y_err_tr), (x_te, y_te, y_err_te))
  where
    (MA.Sz sz) = MA.size y
    comp_x = MA.getComp x
    comp_y = MA.getComp y

    -- Rows are placed in an IntMap so every lookup is logarithmic instead of
    -- walking the list with (!!).
    splitX :: [[Int]] -> SRMatrix -> (SRMatrix, SRMatrix)
    splitX ixs xs' =
      let xs = MA.toLists xs' :: [MA.ListItem MA.Ix2 Double]
          rows = IM.fromList (Prelude.zip [0 ..] xs)
       in ( MA.fromLists' comp_x [rows IM.! ix | (_ : rest) <- ixs, ix <- rest]
          , MA.fromLists' comp_x [rows IM.! ix | (ix : _) <- ixs]
          )

    splitY :: [[Int]] -> PVector -> (PVector, PVector)
    splitY ixs ys =
      ( MA.fromList comp_y [ys MA.! ix | (_ : rest) <- ixs, ix <- rest]
      , MA.fromList comp_y [ys MA.! ix | (ix : _) <- ixs]
      )

-- | Pick the first two components of the first tuple and the first component
-- of the second tuple out of a nested 4-tuple.
getTrain :: ((a, b1, c1, d1), (c2, b2), c3, d2) -> (a, b1, c2)
getTrain ((a, b, _, _), (c, _), _, _) = (a,b,c)

-- | First component of a 'DataSet'.
getX :: DataSet -> SRMatrix
getX (a, _, _) = a

-- | Second component of a 'DataSet'.
getTarget :: DataSet -> PVector
getTarget (_, b, _) = b

-- | Third component of a 'DataSet'.
getError :: DataSet -> Maybe PVector
getError (_, _, c) = c

-- | Apply 'getTrain' to the result of 'loadDataset'.
loadTrainingOnly fname b = getTrain <$> loadDataset fname b

-- | Parse a comma separated list of operator names into non-terminal nodes.
--
-- Names are compared case-insensitively against the 'show' of the binary
-- operators in @[Add .. AQ]@ and the unary functions in @[Abs .. Cube]@.
-- An unknown name raises an 'error'.
parseNonTerms :: String -> [SRTree ()]
parseNonTerms = Prelude.map toNonTerm . splitOn ","
  where
    binTerms = Map.fromList [ (Prelude.map toLower (show op), op) | op <- [Add .. AQ]]
    uniTerms = Map.fromList [ (Prelude.map toLower (show f), f) | f <- [Abs .. Cube]]
    toNonTerm xs' =
      let xs = Prelude.map toLower xs'
       in case binTerms Map.!? xs of
            Just op -> Bin op () ()
            Nothing -> case uniTerms Map.!? xs of
              Just f  -> Uni f ()
              Nothing -> error $ "invalid non-terminal " <> show xs

-- RndEGraph utils
-- fitFun fitnessFunRep rep iter distribution x y mYErr x_val y_val mYErr_val

-- | Insert an expression into the e-graph, evaluate it with @fitFun@, store
-- the resulting fitness and parameters, print the fitness, and return the
-- canonical e-class id.
insertExpr :: Fix SRTree -> (Fix SRTree -> RndEGraph (Double, PVector)) -> RndEGraph EClassId
insertExpr t fitFun = do
  ecId <- fromTree myCost t >>= canonical
  (f, p) <- fitFun t
  insertFitness ecId f p
  io . putStrLn $ "Best fit global: " <> show f
  pure ecId

-- | Evaluate e-class @ec@ only when it has no fitness yet.
-- Returns 'True' when an evaluation was performed.
updateIfNothing fitFun ec = do
  mf <- getFitness ec
  case mf of
    Nothing -> do
      t <- getBestExpr ec
      (f, p) <- fitFun t
      insertFitness ec f p
      pure True
    Just _ -> pure False

-- | Pick a random e-class from the unevaluated set; the result is 'Nothing'
-- when the set is empty or when the picked class is not marked 'NotConst'.
pickRndSubTree :: RndEGraph (Maybe EClassId)
pickRndSubTree = do
  ecIds <- gets (IntSet.toList . _unevaluated . _eDB)
  if not (null ecIds)
    then do
      rndId' <- rnd $ randomFrom ecIds
      rndId <- canonical rndId'
      constType <- gets (_consts . _info . (IM.! rndId) . _eClass)
      case constType of
        NotConst -> pure $ Just rndId
        _        -> pure Nothing
    else pure Nothing

-- | Concatenate the results of 'getTopFitEClassWithSize' for every size from
-- 1 up to @maxSize@, asking for @n@ classes per size.
getParetoEcsUpTo n maxSize = concat <$> forM [1..maxSize] (\i -> getTopFitEClassWithSize i n)

-- | Look up the top e-class of size @n@. The result is either empty or a
-- singleton @(((expression, classes), fitness), theta)@ for the first class.
getBestExprWithSize n = do
  ecs <- getTopFitEClassWithSize n 1 >>= traverse canonical
  case ecs of
    [] -> pure []
    (best : _) -> do
      bestFit <- getFitness best
      bestP <- getTheta best
      expr <- getBestExpr best
      pure [(((expr, ecs), bestFit), bestP)]

-- | Generate a random tree with a random size budget and insert it into the
-- e-graph. The budget is drawn from @[4 .. maxSize]@ when @maxSize > 4@ and
-- from @[1 .. maxSize]@ otherwise.
insertRndExpr maxSize rndTerm rndNonTerm = do
  grow <- rnd toss
  n <- rnd (randomFrom [(if maxSize > 4 then 4 else 1) .. maxSize])
  t <- rnd $ Random.randomTree 3 8 n rndTerm rndNonTerm grow
  fromTree myCost t >>= canonical

-- | Run the printing action (with index 0) on the e-class returned by
-- 'getGreatest' on the fitness range database.
printBest :: (Int -> EClassId -> RndEGraph ()) -> RndEGraph ()
printBest printExprFun = do
  bec <- gets (snd . getGreatest . _fitRangeDB . _eDB) >>= canonical
  printExprFun 0 bec

-- | Walk sizes 1 to @maxSize@ and print the best e-class of each size whose
-- fitness is strictly greater than every fitness seen at smaller sizes.
-- The first argument of the printing action is a running index of the
-- printed entries. Sizes without a class, or whose best class has no stored
-- fitness, are skipped.
paretoFront :: Int -> (Int -> EClassId -> RndEGraph ()) -> RndEGraph ()
paretoFront maxSize printExprFun = go 1 0 (-(1.0/0.0))
  where
    go :: Int -> Int -> Double -> RndEGraph ()
    go n ix f
      | n > maxSize = pure ()
      | otherwise = do
          ecList <- getBestExprWithSize n
          case ecList of
            (((_, ec), Just fit), _) : _ -> do
              let improved = fit > f
              ec' <- traverse canonical ec
              case ec' of
                (c : _) | improved -> printExprFun ix c
                _                  -> pure ()
              go (n + 1) (if improved then ix + 1 else ix) (max f fit)
            _ -> go (n + 1) ix f

-- | Evaluate every e-class in the unevaluated set and store its fitness.
evaluateUnevaluated fitFun = do
  ec <- gets (IntSet.toList . _unevaluated . _eDB)
  forM_ ec $ \c -> do
    t <- getBestExpr c
    (f, p) <- fitFun t
    insertFitness c f p

-- | Evaluate one randomly chosen e-class from the unevaluated set, store its
-- fitness, and return its id.
-- NOTE(review): behaviour on an empty unevaluated set depends on 'randomFrom';
-- callers presumably guarantee the set is non-empty -- confirm.
evaluateRndUnevaluated fitFun = do
  ec <- gets (IntSet.toList . _unevaluated . _eDB)
  c <- rnd . randomFrom $ ec
  t <- getBestExpr c
  (f, p) <- fitFun t
  insertFitness c f p
  pure c