{-# LANGUAGE Rank2Types #-}

{-# LANGUAGE FlexibleInstances #-}



module Control.Search.Stat

  ( appStat

  , constStat

  , depthStat

  , nodesStat

  , discrepancyStat

  , solutionsStat

  , failsStat

  , timeStat

  , notStat

  , Stat(..), IValue, varStat

  , (#>), (#<), (#>=), (#<=), (#=), (#/)

  , readStat, evalStat

  ) where



import Text.PrettyPrint hiding (space)

import Prelude hiding ((<>))



import Control.Search.Language

import Control.Search.GeneratorInfo

import Control.Search.Memo

import Control.Search.Generator



-- ========================================================================== --

-- IVALUE

-- ========================================================================== --



-- | A 'Value' that is computed from an 'Info'.
type IValue = Info -> Value



-- | Placeholder 'Show' instance: every function of this type is rendered as
-- the same fixed tag, regardless of its contents.
instance Show (Info -> Value) where
  show = const "<IValue>"

-- | Placeholder 'Eq' instance: the comparison ignores both operands and is
-- always 'False' (even when a value is compared with itself).
instance Eq (Info -> Value) where
  (==) _ _ = False



-- | Point-wise arithmetic: each operation feeds the same 'Info' to its
-- operands and combines the resulting 'Value's. Integer literals become
-- functions that ignore the 'Info' and yield an 'IVal'.
instance Num (Info -> Value) where
  (x + y) info  = x info + y info
  (x - y) info  = x info - y info
  (x * y) info  = x info * y info
  abs x         = abs . x
  signum x      = signum . x
  fromInteger n = const (IVal (fromInteger n))



-- ========================================================================== --

-- STATS

-- ========================================================================== --



-- | A statistic is a pair of
--
--   * a transformer applied to an evaluator (extracted with 'evalStat'), and
--   * a computation producing the 'IValue' that reads the statistic
--     (extracted with 'readStat').
--
-- Both components are polymorphic in the 'Evalable' monad.
data Stat
  = Stat
      (forall m. Evalable m => Eval m -> Eval m)
      (forall m. Evalable m => m IValue)



-- | Placeholder 'Show' instance: every 'Stat' is rendered as the same tag.
instance Show Stat where
  show = const "<Stat>"



-- | Placeholder 'Eq' instance: the comparison ignores both operands and is
-- always 'False'.
instance Eq Stat where
  (==) _ _ = False



-- | Project out the second component of a 'Stat': the computation that
-- yields the statistic's 'IValue'.
readStat :: Evalable m => Stat -> m IValue
readStat stat = case stat of
  Stat _ reader -> reader



-- | Project out the first component of a 'Stat': the evaluator transformer.
evalStat :: Evalable m => Stat -> Eval m -> Eval m
evalStat stat = case stat of
  Stat transform _ -> transform



-- -------------------------------------------------------------------------- --



-- | Arithmetic on statistics. Binary operators are lifted with 'liftStat',
-- unary ones with 'appStat', and integer literals become constant
-- statistics via 'constStat'.
instance Num Stat where
  (+)           = liftStat (+)
  (-)           = liftStat (-)
  (*)           = liftStat (*)
  abs s         = appStat abs s
  signum s      = appStat signum s
  fromInteger n = constStat (fromInteger n)



-- | The bounds are constant statistics that always read as 'MinVal' and
-- 'MaxVal' respectively.
instance Bounded Stat where
  minBound = constStat (\_ -> MinVal)
  maxBound = constStat (\_ -> MaxVal)



-- | Apply a unary 'Value' function to the value read from a statistic.
-- The evaluator transformer of the statistic is kept as it is.
appStat :: (Value -> Value) -> Stat -> Stat
appStat f (Stat transform reader) =
  Stat
    transform
    ( do
        value <- reader
        return (f . value)
    )



-- | Combine two statistics with a binary 'Value' operator.
--
-- The resulting evaluator transformer is the composition of both
-- transformers (the right operand's is applied first); the resulting reader
-- obtains the left value reader, then the right one, and applies the
-- operator to what they yield for the same 'Info'.
liftStat :: (Value -> Value -> Value) -> Stat -> Stat -> Stat
liftStat op (Stat lhsEval lhsRead) (Stat rhsEval rhsRead) =
  Stat
    (lhsEval . rhsEval)
    ( do
        l <- lhsRead
        r <- rhsRead
        return (\info -> op (l info) (r info))
    )



-- | A statistic that leaves the evaluator untouched and always reads as the
-- given 'IValue'.
constStat :: IValue -> Stat
constStat value = Stat unchanged (return value)
  where
    unchanged super = super



-- | Lift the '@>' operator on 'Value's to statistics.
(#>) :: Stat -> Stat -> Stat
lhs #> rhs = liftStat (@>) lhs rhs



-- | Lift the '@==' operator on 'Value's to statistics.
(#=) :: Stat -> Stat -> Stat
lhs #= rhs = liftStat (@==) lhs rhs



-- | Lift the '@<' operator on 'Value's to statistics.
(#<) :: Stat -> Stat -> Stat
lhs #< rhs = liftStat (@<) lhs rhs



-- | Lift the '@>=' operator on 'Value's to statistics.
(#>=) :: Stat -> Stat -> Stat
lhs #>= rhs = liftStat (@>=) lhs rhs



-- | Lift the '@<=' operator on 'Value's to statistics.
(#<=) :: Stat -> Stat -> Stat
lhs #<= rhs = liftStat (@<=) lhs rhs



-- | Lift 'divValue' to statistics.
(#/) :: Stat -> Stat -> Stat
lhs #/ rhs = liftStat divValue lhs rhs



-- | Wrap the value read from a statistic in the 'Not' constructor.
notStat :: Stat -> Stat
notStat stat = appStat Not stat

-- -------------------------------------------------------------------------- --



-- | Statistic backed by a tree-state entry named @"depth"@.
--
-- The entry is an 'Int' initialised with @assign 0@. Both the left and the
-- right push handler forward to the wrapped evaluator with an extra
-- on-commit update that adds one to @"depth"@. The statistic is read from
-- the tree state of the 'Info'.
depthStat :: Stat
depthStat = Stat track (return current)
  where
    track :: Evalable m => Eval m -> Eval m
    track super =
      commentEval $
        super
          { treeState_ = entry ("depth", Int, assign 0) : treeState_ super
          , pushLeftH  = descend pushLeft
          , pushRightH = descend pushRight
          , toString   = "stat_depth:" ++ toString super
          }
      where
        -- Shared by both branches: delegate to the wrapped evaluator's push
        -- in direction @dir@, incrementing the counter on commit.
        descend dir info =
          dir super (info `onCommit` mkUpdate info "depth" (+ 1))

    current info = tstate info @-> "depth"



-- | Statistic backed by a tree-state entry named @"discrepancy"@.
--
-- The entry is an 'Int' initialised with @assign 0@. Pushing left copies
-- the entry ('mkCopy'); pushing right adds one to it ('mkUpdate'). In both
-- cases the change is attached with 'onCommit' and the push is delegated
-- to the wrapped evaluator.
discrepancyStat :: Stat
discrepancyStat = Stat track (return current)
  where
    track :: Evalable m => Eval m -> Eval m
    track super =
      let keep info =
            pushLeft super (info `onCommit` mkCopy info "discrepancy")
          bump info =
            pushRight super (info `onCommit` mkUpdate info "discrepancy" (+ 1))
       in commentEval $
            super
              { treeState_ = entry ("discrepancy", Int, assign 0) : treeState_ super
              , pushLeftH  = keep
              , pushRightH = bump
              , toString   = "stat_discr:" ++ toString super
              }

    current info = tstate info @-> "discrepancy"



-- | Statistic backed by an eval-state entry named @"nodes"@ (an 'Int'
-- starting at 0). The body handler increments the entry and then runs the
-- wrapped evaluator's body.
nodesStat :: Stat
nodesStat = eStat ("nodes", Int, const 0) countNodes
  where
    countNodes :: Evalable m => Eval m -> Eval m
    countNodes super = super { bodyH = visit }
      where
        visit info = return (inc (nodes info)) @>>>@ bodyE super info

    nodes info = estate info @=> "nodes"



-- | Statistic backed by an eval-state entry named @"solutions"@ (an 'Int'
-- starting at 0). The return handler delegates to the wrapped evaluator
-- with an on-commit increment of the entry attached to the 'Info'.
solutionsStat :: Stat
solutionsStat = eStat ("solutions", Int, const 0) countSolutions
  where
    countSolutions :: Evalable m => Eval m -> Eval m
    countSolutions super =
      super
        { returnH = \info ->
            returnE super (info `onCommit` inc (estate info @=> "solutions"))
        }



-- | Statistic that reads the eval-state field @"var" ++ show i@ of the
-- 'Info' obtained from 'lookupVarInfo' for the given variable. The 'Info'
-- supplied when the statistic is read is ignored, and the evaluator is left
-- untouched.
varStat :: VarId -> Stat
varStat v@(VarId i) =
  Stat id (lookupVarInfo v >>= \inf -> return (\_ -> estate inf @=> field))
  where
    field = "var" ++ show i



-- | Statistic backed by an eval-state entry named @"fails"@ (an 'Int'
-- starting at 0). The failure handler first runs the wrapped evaluator's
-- failure handler and then increments the entry.
failsStat :: Stat
failsStat =
  eStat ("fails", Int, const 0) $
          -- Delegate to the wrapped evaluator's *failure* handler. The
          -- previous code called 'returnH' here, which ran the solution
          -- handler of the wrapped evaluator on every failure and skipped
          -- its failure handling altogether.
          -- NOTE(review): the sibling statistics delegate through wrappers
          -- ('bodyE', 'returnE'); if an analogous wrapper exists for
          -- failures, prefer it here — confirm against the Generator module.
          \super -> super { failH = \i -> failH super i @>>>@ return (inc (fails i)) }
  where fails i = estate i @=> "fails"



-- | Build a statistic backed by a single eval-state entry.
--
-- The triple gives the entry's name, its type and its initial value as a
-- function of the 'Info'. The entry is prepended to the evaluator's eval
-- state and the evaluator's 'toString' is prefixed with @"stat_<name>:"@;
-- the supplied transformer is then applied to that extended evaluator and
-- the result is passed through 'commentEval'. The statistic is read from
-- the eval-state field of the same name.
eStat :: (String, Type, Info -> Value) -> (forall m. Evalable m => Eval m -> Eval m) -> Stat
eStat (name, typ, val) f = Stat install (return current)
  where
    install :: Evalable m => Eval m -> Eval m
    install super = commentEval (f extended)
      where
        extended =
          super
            { evalState_ = (name, typ, return . val) : evalState_ super
            , toString   = "stat_" ++ name ++ ":" ++ toString super
            }

    current info = estate info @=> name



-- TIMER STATISTIC

--

-- Based on Gecode::Support::Timer.

--

--

--

-- | Timer statistic built on a @Gecode::Support::Timer@ hook.
--
-- Three eval-state entries are added:
--
--   * @"total"@   – an 'Int' accumulator, initially 0;
--   * @"timer"@   – the timer object (initial value 'Null');
--   * @"running"@ – a 'Bool' flag, initially @false@.
--
-- The body handler starts the timer (and sets the flag) when it is not
-- running, then runs the wrapped evaluator's body. The next-diff handler
-- stops a running timer, adds the result of @.stop()@ (cast to @int@) to
-- the total, and then runs the wrapped evaluator's next-diff handler.
--
-- Reading the statistic yields the total plus, if the timer is currently
-- running, the result of @.stop()@.
timeStat :: Stat
timeStat = Stat instrument (return elapsed)
  where
    instrument :: Evalable m => Eval m -> Eval m
    instrument super =
      commentEval $
        super
          { evalState_ =
              ("total", Int, const $ return 0)
                : ("timer", THook "Gecode::Support::Timer", const $ return Null)
                : ("running", Bool, const $ return false)
                : evalState_ super
          , nextDiffH = \i ->
              return
                ( ifthen
                    (running i)
                    ((running i <== false) >>> (total i <== (total i + stopped i)))
                )
                -- Chain to the wrapped evaluator so that its own next-diff
                -- code is not silently dropped; every other handler
                -- override in this module delegates in the same way.
                -- NOTE(review): if a wrapper analogous to 'bodyE' exists
                -- for next-diff, prefer it — confirm against Generator.
                @>>>@ nextDiffH super i
          , bodyH = \i ->
              return
                ( ifthen
                    (Not $ running i)
                    ((running i <== true) >>> SHook (render (pretty (timer i)) ++ ".start();"))
                )
                @>>>@ bodyE super i
          , toString = "stat_time:" ++ toString super
          }

    -- Value read for the statistic: the accumulated total, plus the timer's
    -- stop() result while the timer is flagged as running.
    elapsed i = total i + Cond (running i) (stopped i) 0

    -- Hook expression "static_cast<int>(<timer>.stop())".
    stopped i =
      VHook (render $ text "static_cast<int>" <> parens (pretty (timer i) <> text ".stop()"))

    running i = estate i @=> "running"
    timer   i = estate i @=> "timer"
    total   i = estate i @=> "total"