{-# LANGUAGE QuantifiedConstraints, RankNTypes, UnicodeSyntax, UndecidableInstances #-}
{-|
   Definition of 'Pattern' for use in equality matching
   ('Data.Equality.Matching'), where patterns are matched against the e-graph
 -}
module Data.Equality.Matching.Pattern where

import Data.String

import Data.Equality.Utils
import Data.Equality.Matching.Database

-- | A pattern is either a variable or a non-variable expression whose
-- children are themselves patterns.
--
-- A 'NonVariablePattern' only matches an expression if its @lang@ constructor
-- matches the expression's constructor and every child pattern matches the
-- corresponding child of the expression.
-- A 'VariablePattern' matches any expression.
--
-- === Example
--
-- The expression
--
-- @
-- expr :: Fix Sym
-- expr = BinOp Add (Sym "x") (Const 2.0) -- i.e. x + 2
-- @
--
-- would be matched by each of the following patterns
--
-- @
-- pat1 :: Pattern Sym
-- pat1 = VariablePattern 1
--
-- pat2 :: Pattern Sym
-- pat2 = NonVariablePattern (BinOp Add (VariablePattern 1) (VariablePattern 2))
--
-- pat3 :: Pattern Sym
-- pat3 = NonVariablePattern (BinOp Add (VariablePattern 1) (NonVariablePattern (Const 2)))
-- @
--
-- but would not be matched by any of the following patterns
--
-- @
-- pat4 :: Pattern Sym
-- pat4 = NonVariablePattern (Const 5)
--
-- pat5 :: Pattern Sym
-- pat5 = NonVariablePattern (BinOp Add (NonVariablePattern (Sym "y")) (NonVariablePattern (Const 2)))
--
-- pat6 :: Pattern Sym
-- pat6 = NonVariablePattern (BinOp Add (NonVariablePattern (Sym "x")) (NonVariablePattern (Const 3)))
-- @
--
-- === IsString
-- 'Pattern' has an 'IsString' instance, so with @OverloadedStrings@ enabled a
-- variable pattern can be written simply as a string literal.
--
-- The string is converted into the integer identifier of a 'VariablePattern'
-- with 'Data.Equality.Utils.hashString'. Equal strings always yield the same
-- variable. Distinct strings are assumed not to collide (NOTE(review): this
-- depends on 'hashString').
--
-- For example, the following pattern matches @a+a@ and @b+b@ but not @a+b@,
-- because both children must bind the same variable @"x"@:
--
-- @
-- pat7 :: Pattern Sym
-- pat7 = 'pat' (BinOp Add "x" "x")
-- @
data Pattern lang
    = NonVariablePattern (lang (Pattern lang))
    | VariablePattern Var -- ^ Should be a positive (> 0) number

-- | Synonym for 'NonVariablePattern'.
--
-- It wraps one layer of the language, whose children are already patterns,
-- into a 'Pattern'.
--
-- Example
--
-- @
-- pat8 :: Pattern Sym
-- pat8 = pat (BinOp Mul "y" (pat (Const 2))) -- matches any product of an expression by 2
-- @
pat :: lang (Pattern lang) -> Pattern lang
pat = NonVariablePattern

-- | Structural equality on patterns.
--
-- Two variable patterns are equal when their identifiers are equal. Two
-- non-variable patterns are equal when their @l@ layers are equal, which in
-- turn compares the child patterns recursively. A variable pattern never
-- equals a non-variable one.
--
-- The quantified constraint supplies @Eq (l a)@ for every @a@ with @Eq a@,
-- which is what the recursive comparison of @l (Pattern l)@ needs.
instance (forall a. Eq a => Eq (l a)) => (Eq (Pattern l)) where
    (==) (NonVariablePattern a) (NonVariablePattern b) = (==) a b
    (==) (VariablePattern a) (VariablePattern b) = a == b
    (==) _ _ = False

-- | Total order on patterns, consistent with the 'Eq' instance.
--
-- Every 'VariablePattern' sorts before every 'NonVariablePattern'. Variable
-- patterns are ordered by identifier. Non-variable patterns are ordered by
-- the @Ord@ instance of their @l@ layer, which recurses into child patterns.
instance (forall a. Eq a => Eq (l a), forall a. (Ord a) => Ord (l a)) => (Ord (Pattern l)) where
    compare (VariablePattern _) (NonVariablePattern _) = LT
    compare (NonVariablePattern _) (VariablePattern _) = GT
    compare (VariablePattern a) (VariablePattern b) = compare a b
    compare (NonVariablePattern a) (NonVariablePattern b) = compare a b

-- | Renders a variable pattern as its bare identifier (via 'show'). Renders a
-- non-variable pattern with the @lang@ layer's own 'showsPrec', passing the
-- precedence through.
instance (forall a. Show a => Show (lang a)) => Show (Pattern lang) where
    -- The precedence is currently ignored for variables.
    showsPrec _ (VariablePattern s) = showString (show s) -- ROMES:TODO don't ignore prec?
    showsPrec d (NonVariablePattern x) = showsPrec d x

-- | Lets a string literal stand for a variable pattern. This requires
-- @OverloadedStrings@ at the use site.
--
-- The string is mapped to a variable identifier with 'hashString', so the
-- same string always denotes the same variable.
instance IsString (Pattern lang) where
    fromString = VariablePattern . hashString