{- 
 - 	Monadic Constraint Programming
 - 	http://www.cs.kuleuven.be/~toms/MCP/
 - 	Pieter Wuille
 -}



{-# LANGUAGE FlexibleInstances #-}

{-# LANGUAGE StandaloneDeriving #-}



module Data.Expr.Data (

  Expr(..),

  ColExpr(..),

  BoolExpr(..),

  ExprRel(..),

  (<<>>)

) where 



--------------------

--   Data types   --

--------------------



-- | Integer expressions.
--
-- @t@ is the payload type of 'Term', @c@ that of 'ColTerm' and @b@ that of
-- 'BoolTerm'.  Constructors holding Haskell functions ('Fold') are inspected
-- by this module's instances by applying them to 'ExprHole' placeholders.
data Expr t c b
  = Term t
    -- ^ a user-supplied term
  | ExprHole Int
    -- ^ numbered placeholder standing in for a function parameter
  | Const Integer
    -- ^ integer literal
  | Plus (Expr t c b) (Expr t c b)
  | Minus (Expr t c b) (Expr t c b)
  | Mult (Expr t c b) (Expr t c b)
  | Div (Expr t c b) (Expr t c b)
  | Mod (Expr t c b) (Expr t c b)
  | Abs (Expr t c b)
  | At (ColExpr t c b) (Expr t c b)
    -- ^ collection and index (presumably element lookup; confirm)
  | Fold (Expr t c b -> Expr t c b -> Expr t c b) (Expr t c b) (ColExpr t c b)
    -- ^ binary function, start value and collection (presumably a fold)
  | Cond (BoolExpr t c b) (Expr t c b) (Expr t c b)
    -- ^ condition, then-branch, else-branch
  | ColSize (ColExpr t c b)
    -- ^ size of a collection
  | Channel (BoolExpr t c b)
    -- ^ a boolean expression used as an integer (encoding not visible here)



-- | Collection expressions over the same term types as 'Expr'.
data ColExpr t c b
  = ColTerm c
    -- ^ a user-supplied collection term
  | ColList [Expr t c b]
    -- ^ explicit list of element expressions
  | ColRange (Expr t c b) (Expr t c b)
    -- ^ range given by two bound expressions
  | ColMap (Expr t c b -> Expr t c b) (ColExpr t c b)
    -- ^ function and collection (presumably element-wise mapping)
  | ColSlice (Expr t c b -> Expr t c b) (Expr t c b) (ColExpr t c b)
    -- ^ @ColSlice f n c@ denotes @c[f(0)..f(n-1)]@
  | ColCat (ColExpr t c b) (ColExpr t c b)
    -- ^ two collections (presumably concatenated)



-- | Relational operators used by 'Rel'.  The derived 'Ord' follows the
-- constructor order: 'EREqual' < 'ERDiff' < 'ERLess'.
data ExprRel = EREqual | ERDiff | ERLess
  deriving (Show, Eq, Ord)



-- | Boolean expressions over integer and collection expressions.
--
-- 'BoolAll' and 'BoolAny' hold a Haskell function from an element expression
-- to a boolean expression; this module's instances inspect it by applying it
-- to an 'ExprHole' placeholder.
data BoolExpr t c b
  = BoolTerm b
    -- ^ a user-supplied boolean term
  | BoolConst Bool
    -- ^ literal truth value
  | BoolAnd (BoolExpr t c b) (BoolExpr t c b)
  | BoolOr (BoolExpr t c b) (BoolExpr t c b)
  | BoolNot (BoolExpr t c b)
  | BoolCond (BoolExpr t c b) (BoolExpr t c b) (BoolExpr t c b)
    -- ^ condition, then-branch, else-branch
  | Rel (Expr t c b) ExprRel (Expr t c b)
    -- ^ relation between two integer expressions
  | BoolAll (Expr t c b -> BoolExpr t c b) (ColExpr t c b)
    -- ^ predicate and collection (presumably "holds for every element")
  | BoolAny (Expr t c b -> BoolExpr t c b) (ColExpr t c b)
    -- ^ predicate and collection (presumably "holds for some element")
  | ColEqual (ColExpr t c b) (ColExpr t c b)
  | BoolEqual (BoolExpr t c b) (BoolExpr t c b)
  | AllDiff Bool (ColExpr t c b)
    -- ^ the 'Bool' flag is ignored by this module's Show, Eq and Ord
  | Sorted Bool (ColExpr t c b)
    -- ^ meaning of the 'Bool' flag is not visible here
  | Dom (Expr t c b) (ColExpr t c b)
    -- ^ expression and collection (presumably a domain constraint)



-----------------------

--   Show instance   --

-----------------------



-- | Rendering of values that may contain Haskell functions.
class ShowFn t where
  -- | Render a value.  Functions are rendered as lambdas whose parameter is
  -- named after the numbered 'ExprHole' the function is applied to.
  showFn
    :: Int     -- ^ number of the next unused parameter placeholder
    -> t       -- ^ value to render
    -> String

instance (Show t, Show c, Show b) => ShowFn (Expr t c b) where
  -- Each constructor is printed as its name followed by its arguments, each
  -- wrapped in parentheses.  'Const' prints its literal bare, and an
  -- 'ExprHole' n prints as the parameter name "par" followed by n.
  showFn depth expr = case expr of
      Term x       -> node "Term" [show x]
      ExprHole n   -> "par" ++ show n
      Const n      -> "Const " ++ show n
      Plus x y     -> node "Plus" [sub x, sub y]
      Minus x y    -> node "Minus" [sub x, sub y]
      Mult x y     -> node "Mult" [sub x, sub y]
      Div x y      -> node "Div" [sub x, sub y]
      Mod x y      -> node "Mod" [sub x, sub y]
      Abs x        -> node "Abs" [sub x]
      At col idx   -> node "At" [sub col, sub idx]
      Fold f z col -> node "Fold" [sub f, sub z, sub col]
      ColSize col  -> node "ColSize" [sub col]
      Channel p    -> node "Channel" [sub p]
      Cond p x y   -> node "Cond" [sub p, sub x, sub y]
    where
      sub :: ShowFn a => a -> String
      sub = showFn depth
      node :: String -> [String] -> String
      node name args = name ++ concatMap (\arg -> " (" ++ arg ++ ")") args

-- | Lists render as @[x1,x2,...]@ (comma-separated, no spaces).  The empty
-- list renders as @[]@, so e.g. @ColList []@ can be shown.
instance (ShowFn l) => ShowFn [l] where
  showFn d l = "[" ++ commaJoin (map (showFn d) l) ++ "]"
    where
      -- Total on the empty list (unlike foldr1).
      commaJoin [] = ""
      commaJoin (s:ss) = s ++ concatMap (',' :) ss

instance (Show t, Show c, Show b) => ShowFn (ColExpr t c b) where
  -- Constructor name followed by each argument wrapped in parentheses;
  -- function-valued arguments are rendered through the function instance.
  showFn depth col = case col of
      ColTerm x       -> wrapArgs "ColTerm" [show x]
      ColList xs      -> wrapArgs "ColList" [sub xs]
      ColMap f xs     -> wrapArgs "ColMap" [sub f, sub xs]
      ColSlice f n xs -> wrapArgs "ColSlice" [sub f, sub n, sub xs]
      ColCat xs ys    -> wrapArgs "ColCat" [sub xs, sub ys]
      ColRange lo hi  -> wrapArgs "ColRange" [sub lo, sub hi]
    where
      sub :: ShowFn a => a -> String
      sub = showFn depth
      wrapArgs :: String -> [String] -> String
      wrapArgs name args = concat (name : [" (" ++ a ++ ")" | a <- args])

instance (Show t, Show c, Show b) => ShowFn (BoolExpr t c b) where
  -- Constructor name, then each argument separated by a space; expression
  -- arguments are wrapped in parentheses, plain values (the Bool of
  -- BoolConst/Sorted, the ExprRel of Rel) are shown bare.
  -- NOTE: the Bool flag of AllDiff is not rendered.
  showFn _ (BoolTerm b) = "BoolTerm ("++(show b)++")"
  showFn _ (BoolConst b) = "BoolConst "++(show b)
  showFn d (BoolAnd a b) = "BoolAnd ("++(showFn d a)++") ("++(showFn d b)++")"
  showFn d (BoolOr a b) = "BoolOr ("++(showFn d a)++") ("++(showFn d b)++")"
  showFn d (BoolNot a) = "BoolNot ("++(showFn d a)++")"
  showFn d (BoolEqual a b) = "BoolEqual ("++(showFn d a)++") ("++(showFn d b)++")"
  showFn d (Rel a r b) = "Rel ("++(showFn d a)++") "++(show r)++" ("++(showFn d b)++")"
  showFn d (BoolAll f c) = "BoolAll ("++(showFn d f)++") ("++(showFn d c)++")"
  showFn d (BoolAny f c) = "BoolAny ("++(showFn d f)++") ("++(showFn d c)++")"
  showFn d (ColEqual a b) = "ColEqual ("++(showFn d a)++") ("++(showFn d b)++")"
  showFn d (AllDiff _ c) = "AllDiff ("++(showFn d c)++")"
  -- Space before the parenthesised collection, as for every other constructor.
  showFn d (Sorted b c) = "Sorted "++(show b)++" ("++(showFn d c)++")"
  showFn d (BoolCond c t f) = "BoolCond ("++(showFn d c)++") ("++(showFn d t)++") ("++(showFn d f)++")"
  showFn d (Dom i c) = "Dom ("++(showFn d i)++") ("++(showFn d c)++")"

-- | Functions render as lambdas: the body is obtained by applying the
-- function to the placeholder 'ExprHole' numbered by the current depth,
-- and the rest of the body is rendered one level deeper.
instance (Show t, Show c, Show b, ShowFn e) => ShowFn (Expr t c b -> e) where
  showFn depth fn =
    let hole = ExprHole depth
        body = showFn (depth + 1) (fn hole)
    in  concat ["\\par", show depth, " -> ", body]

-- | 'show' delegates to 'showFn', numbering lambda parameters from 0.
instance (Show t, Show c, Show b) => Show (Expr t c b) where
  show expr = showFn 0 expr

instance (Show t, Show c, Show b) => Show (ColExpr t c b) where
  show col = showFn 0 col

instance (Show t, Show c, Show b) => Show (BoolExpr t c b) where
  show cond = showFn 0 cond



---------------------

--   Eq instance   --

---------------------



-- | Structural equality of integer expressions.
--
-- The 'Int' is the number of the next unused 'ExprHole'.  The functions
-- inside 'Fold' are compared by applying both to the same two fresh
-- placeholders and comparing the resulting bodies.
equalExpr :: (Eq t, Eq c, Eq b) => Int -> Expr t c b -> Expr t c b -> Bool
equalExpr _ (Term a) (Term b) = a == b
equalExpr _ (ExprHole a) (ExprHole b) = a == b
equalExpr _ (Const a) (Const b) = a == b
equalExpr l (Plus a c) (Plus b d) = equalExpr l a b && equalExpr l c d
equalExpr l (Minus a c) (Minus b d) = equalExpr l a b && equalExpr l c d
equalExpr l (Mult a c) (Mult b d) = equalExpr l a b && equalExpr l c d
-- Each binary operator must only match the same operator on the other side.
equalExpr l (Div a c) (Div b d) = equalExpr l a b && equalExpr l c d
equalExpr l (Mod a c) (Mod b d) = equalExpr l a b && equalExpr l c d
equalExpr l (Abs a) (Abs b) = equalExpr l a b
equalExpr l (At a c) (At b d) = equalExpr l c d && equalColExpr l a b
equalExpr l (ColSize a) (ColSize b) = equalColExpr l a b
equalExpr l (Fold f a c) (Fold g b d) =
  equalExpr l a b
    && equalColExpr l c d
    && equalExpr (l + 2) (f hole0 hole1) (g hole0 hole1)
  where
    -- two fresh placeholders for the binary function's parameters
    hole0 = ExprHole l
    hole1 = ExprHole (l + 1)
equalExpr l (Channel a) (Channel b) = equalBoolExpr l a b
equalExpr l (Cond c t f) (Cond d u g) = equalBoolExpr l c d && equalExpr l t u && equalExpr l f g
equalExpr _ _ _ = False



-- | Structural equality of collection expressions.  Mapping and slicing
-- functions are compared by applying both to the same fresh 'ExprHole'.
equalColExpr :: (Eq t, Eq c, Eq b) => Int -> ColExpr t c b -> ColExpr t c b -> Bool
equalColExpr depth lhs rhs = case (lhs, rhs) of
    (ColTerm x, ColTerm y)             -> x == y
    (ColList xs, ColList ys)           -> sameElems xs ys
    (ColMap f xs, ColMap g ys)         -> equalColExpr depth xs ys && sameBody f g
    (ColSlice f n xs, ColSlice g m ys) -> sameBody f g && equalExpr depth n m && equalColExpr depth xs ys
    (ColCat x1 x2, ColCat y1 y2)       -> equalColExpr depth x1 y1 && equalColExpr depth x2 y2
    (ColRange x1 x2, ColRange y1 y2)   -> equalExpr depth x1 y1 && equalExpr depth x2 y2
    _                                  -> False
  where
    -- element-wise and short-circuiting; lists of different length differ
    sameElems (x:xs) (y:ys) = equalExpr depth x y && sameElems xs ys
    sameElems [] []         = True
    sameElems _ _           = False
    -- both functions get the same placeholder; bodies compared one level deeper
    sameBody f g = equalExpr (depth + 1) (f hole) (g hole)
    hole = ExprHole depth



-- | Structural equality of boolean expressions.  Quantifier bodies are
-- compared by applying both functions to the same fresh 'ExprHole'.  The
-- 'Bool' flag of 'AllDiff' is ignored.
equalBoolExpr :: (Eq t, Eq c, Eq b) => Int -> BoolExpr t c b -> BoolExpr t c b -> Bool
equalBoolExpr depth lhs rhs = case (lhs, rhs) of
    (BoolTerm x, BoolTerm y)               -> x == y
    (BoolConst x, BoolConst y)             -> x == y
    (BoolAnd x1 x2, BoolAnd y1 y2)         -> same x1 y1 && same x2 y2
    (BoolOr x1 x2, BoolOr y1 y2)           -> same x1 y1 && same x2 y2
    (BoolEqual x1 x2, BoolEqual y1 y2)     -> same x1 y1 && same x2 y2
    (BoolNot x, BoolNot y)                 -> same x y
    (Rel x1 op1 x2, Rel y1 op2 y2)         -> op1 == op2 && equalExpr depth x1 y1 && equalExpr depth x2 y2
    (BoolAll f xs, BoolAll g ys)           -> equalColExpr depth xs ys && sameBody f g
    (BoolAny f xs, BoolAny g ys)           -> equalColExpr depth xs ys && sameBody f g
    (ColEqual x1 x2, ColEqual y1 y2)       -> equalColExpr depth x1 y1 && equalColExpr depth x2 y2
    (AllDiff _ xs, AllDiff _ ys)           -> equalColExpr depth xs ys
    (Sorted s1 xs, Sorted s2 ys)           -> s1 == s2 && equalColExpr depth xs ys
    (BoolCond c1 t1 e1, BoolCond c2 t2 e2) -> same c1 c2 && same t1 t2 && same e1 e2
    (Dom x xs, Dom y ys)                   -> equalExpr depth x y && equalColExpr depth xs ys
    _                                      -> False
  where
    same x y = equalBoolExpr depth x y
    sameBody f g = equalBoolExpr (depth + 1) (f hole) (g hole)
    hole = ExprHole depth



-- | Structural equality; placeholder numbering for function-valued fields
-- starts at 0.
instance (Eq t, Eq c, Eq b) => Eq (Expr t c b) where
  (==) = equalExpr 0

instance (Eq t, Eq c, Eq b) => Eq (ColExpr t c b) where
  (==) = equalColExpr 0

instance (Eq t, Eq c, Eq b) => Eq (BoolExpr t c b) where
  (==) = equalBoolExpr 0



-----------------------------------------------------

--   Ord instance: structural ordering over expressions   --

-----------------------------------------------------



infixr 4 <<>>

-- | Lexicographic combination of comparison results: the left result wins
-- unless it is 'EQ', in which case the right one is used.  The right
-- operand is only evaluated when the left one is 'EQ'.
(<<>>) :: Ordering -> Ordering -> Ordering
EQ  <<>> next = next
res <<>> _    = res



-- | Structural ordering of collection expressions.
--
-- Different constructors are ordered
-- 'ColList' < 'ColMap' < 'ColSlice' < 'ColCat' < 'ColRange' < 'ColTerm';
-- equal constructors compare their fields lexicographically.  Lists are
-- compared element-wise, a proper prefix ordering before the longer list.
-- Mapping/slicing functions are compared by applying both to the same
-- fresh placeholder @ExprHole l@.
compareColExpr :: (Ord s, Ord c, Ord b) => Int -> ColExpr s c b -> ColExpr s c b -> Ordering
compareColExpr _ (ColList []) (ColList []) = EQ
-- The shorter list orders first; both directions are spelled out so the
-- ordering stays antisymmetric.
compareColExpr _ (ColList []) (ColList _) = LT
compareColExpr _ (ColList _) (ColList []) = GT
compareColExpr l (ColList (a:ar)) (ColList (b:br)) = compareExpr l a b <<>> compareColExpr l (ColList ar) (ColList br)
compareColExpr _ (ColList _) _ = LT
compareColExpr _ _ (ColList _) = GT
compareColExpr l (ColMap f1 c1) (ColMap f2 c2) = compareColExpr l c1 c2 <<>> compareExpr (l+1) (f1 $ ExprHole l) (f2 $ ExprHole l)
compareColExpr _ (ColMap _ _) _ = LT
compareColExpr _ _ (ColMap _ _) = GT
compareColExpr l (ColSlice p1 l1 c1) (ColSlice p2 l2 c2) = compareExpr (l+1) (p1 $ ExprHole l) (p2 $ ExprHole l) <<>> compareExpr l l1 l2 <<>> compareColExpr l c1 c2
compareColExpr _ (ColSlice _ _ _) _ = LT
compareColExpr _ _ (ColSlice _ _ _) = GT
compareColExpr l (ColCat a1 b1) (ColCat a2 b2) = compareColExpr l a1 a2 <<>> compareColExpr l b1 b2
compareColExpr _ (ColCat _ _) _ = LT
compareColExpr _ _ (ColCat _ _) = GT
-- First bounds, then second bounds (each side compared with its counterpart).
compareColExpr l (ColRange l1 h1) (ColRange l2 h2) = compareExpr l l1 l2 <<>> compareExpr l h1 h2
compareColExpr _ (ColRange _ _) _ = LT
compareColExpr _ _ (ColRange _ _) = GT
compareColExpr _ (ColTerm t1) (ColTerm t2) = compare t1 t2



-- | Structural ordering of boolean expressions.
--
-- Expressions with different constructors are ordered by the constructor
-- rank given in @rank@ below; the same constructor compares its fields
-- lexicographically.  Quantifier bodies are compared by applying both
-- functions to the same fresh placeholder @ExprHole depth@.  The 'Bool'
-- flag of 'AllDiff' is ignored.
compareBoolExpr :: (Ord s, Ord c, Ord b) => Int -> BoolExpr s c b -> BoolExpr s c b -> Ordering
compareBoolExpr depth lhs rhs = case (lhs, rhs) of
    (BoolConst x, BoolConst y)             -> compare x y
    (BoolAnd x1 x2, BoolAnd y1 y2)         -> sub x1 y1 <<>> sub x2 y2
    (BoolOr x1 x2, BoolOr y1 y2)           -> sub x1 y1 <<>> sub x2 y2
    (BoolEqual x1 x2, BoolEqual y1 y2)     -> sub x1 y1 <<>> sub x2 y2
    (BoolNot x, BoolNot y)                 -> sub x y
    (Rel x1 r1 x2, Rel y1 r2 y2)           -> compare r1 r2 <<>> compareExpr depth x1 y1 <<>> compareExpr depth x2 y2
    (BoolAll f xs, BoolAll g ys)           -> compareColExpr depth xs ys <<>> body f g
    (BoolAny f xs, BoolAny g ys)           -> compareColExpr depth xs ys <<>> body f g
    (ColEqual x1 x2, ColEqual y1 y2)       -> compareColExpr depth x1 y1 <<>> compareColExpr depth x2 y2
    (Sorted s1 xs, Sorted s2 ys)           -> compare s1 s2 <<>> compareColExpr depth xs ys
    (AllDiff _ xs, AllDiff _ ys)           -> compareColExpr depth xs ys
    (BoolCond c1 t1 e1, BoolCond c2 t2 e2) -> sub c1 c2 <<>> sub t1 t2 <<>> sub e1 e2
    (Dom x xs, Dom y ys)                   -> compareExpr depth x y <<>> compareColExpr depth xs ys
    (BoolTerm x, BoolTerm y)               -> compare x y
    _                                      -> compare (rank lhs) (rank rhs)
  where
    sub x y = compareBoolExpr depth x y
    body f g = compareBoolExpr (depth + 1) (f hole) (g hole)
    hole = ExprHole depth
    -- precedence between different constructors (lower ranks order first)
    rank :: BoolExpr s' c' b' -> Int
    rank e = case e of
      BoolConst _      -> 0
      BoolAnd _ _      -> 1
      BoolOr _ _       -> 2
      BoolEqual _ _    -> 3
      BoolNot _        -> 4
      Rel _ _ _        -> 5
      BoolAll _ _      -> 6
      BoolAny _ _      -> 7
      ColEqual _ _     -> 8
      Sorted _ _       -> 9
      AllDiff _ _      -> 10
      BoolCond _ _ _   -> 11
      Dom _ _          -> 12
      BoolTerm _       -> 13



-- | Structural ordering of integer expressions.
--
-- Expressions with different constructors are ordered by the constructor
-- rank given in @rank@ below; the same constructor compares its fields
-- lexicographically.  The function inside 'Fold' is compared by applying
-- both sides to the same two fresh placeholders @ExprHole depth@ and
-- @ExprHole (depth + 1)@.
compareExpr :: (Ord s, Ord c, Ord b) => Int -> Expr s c b -> Expr s c b -> Ordering
compareExpr depth lhs rhs = case (lhs, rhs) of
    (Const x, Const y)           -> compare x y
    (ExprHole x, ExprHole y)     -> compare x y
    (Plus x1 x2, Plus y1 y2)     -> sub x1 y1 <<>> sub x2 y2
    (Minus x1 x2, Minus y1 y2)   -> sub x1 y1 <<>> sub x2 y2
    (Mult x1 x2, Mult y1 y2)     -> sub x1 y1 <<>> sub x2 y2
    (Div x1 x2, Div y1 y2)       -> sub x1 y1 <<>> sub x2 y2
    (Mod x1 x2, Mod y1 y2)       -> sub x1 y1 <<>> sub x2 y2
    (Abs x, Abs y)               -> sub x y
    (At xs x, At ys y)           -> sub x y <<>> compareColExpr depth xs ys
    (ColSize xs, ColSize ys)     -> compareColExpr depth xs ys
    (Fold f x xs, Fold g y ys)   ->
      sub x y
        <<>> compareColExpr depth xs ys
        <<>> compareExpr (depth + 2) (f hole0 hole1) (g hole0 hole1)
    (Channel p, Channel q)       -> compareBoolExpr depth p q
    (Cond p x1 x2, Cond q y1 y2) -> compareBoolExpr depth p q <<>> sub x1 y1 <<>> sub x2 y2
    (Term x, Term y)             -> compare x y
    _                            -> compare (rank lhs) (rank rhs)
  where
    sub x y = compareExpr depth x y
    hole0 = ExprHole depth
    hole1 = ExprHole (depth + 1)
    -- precedence between different constructors (lower ranks order first)
    rank :: Expr s' c' b' -> Int
    rank e = case e of
      Const _      -> 0
      ExprHole _   -> 1
      Plus _ _     -> 2
      Minus _ _    -> 3
      Mult _ _     -> 4
      Div _ _      -> 5
      Mod _ _      -> 6
      Abs _        -> 7
      At _ _       -> 8
      ColSize _    -> 9
      Fold _ _ _   -> 10
      Channel _    -> 11
      Cond _ _ _   -> 12
      Term _       -> 13



-- | Structural orderings; placeholder numbering for function-valued fields
-- starts at 0.
instance (Ord s, Ord c, Ord b) => Ord (Expr s c b) where
  compare x y = compareExpr 0 x y

instance (Ord s, Ord c, Ord b) => Ord (ColExpr s c b) where
  compare x y = compareColExpr 0 x y

instance (Ord s, Ord c, Ord b) => Ord (BoolExpr s c b) where
  compare x y = compareBoolExpr 0 x y