module Data.Singletons.Deriving.Ord ( mkOrdInstance ) where
import Language.Haskell.TH.Desugar
import Data.Singletons.Names
import Data.Singletons.Util
import Language.Haskell.TH.Syntax
import Data.Singletons.Deriving.Infer
import Data.Singletons.Deriving.Util
import Data.Singletons.Syntax
-- | Derive an 'Ord' instance (a 'UInstDecl') for the data type described by
-- the 'DataDecl'.
--
-- The instance defines a single method, 'compareName':
--
--   * With no constructors, one clause ignores both arguments and returns EQ.
--   * Otherwise there is one clause per constructor comparing two values built
--     with that constructor field by field. It is followed by one clause per
--     ordered pair of constructors with distinct names, which orders them by
--     their index in the constructor list.
--
-- The instance context is whatever 'inferConstraintsDef' returns. @mb_ctxt@
-- is passed straight through to it.
-- NOTE(review): presumably @mb_ctxt@ is an explicit (standalone-deriving)
-- context that overrides inference when present -- confirm.
mkOrdInstance :: DsMonad q => DerivDesc q
mkOrdInstance mb_ctxt ty (DataDecl _ _ cons) = do
  inst_cxt <- inferConstraintsDef mb_ctxt (DConT ordName) ty cons
  same_con_clauses <- traverse mk_equal_clause cons
  -- Pair each constructor with its 1-based index in 'cons'.
  let indexed :: [(DCon, Int)]
      indexed = zip cons [1 ..]

      -- One clause per ordered pair of constructors whose names differ.
      -- The left-hand constructor is the outer loop.
      cross_con_clauses :: [DClause]
      cross_con_clauses = do
        lhs <- indexed
        rhs <- indexed
        if extractName (fst lhs) == extractName (fst rhs)
          then []
          else [mk_nonequal_clause lhs rhs]

      -- An empty data type gets the special always-EQ clause.
      compare_clauses :: [DClause]
      compare_clauses = case cons of
        [] -> [mk_empty_clause]
        _  -> same_con_clauses ++ cross_con_clauses

  pure InstDecl { id_cxt     = inst_cxt
                , id_name    = ordName
                , id_arg_tys = [ty]
                , id_sigs    = mempty
                , id_meths   = [(compareName, UFunction compare_clauses)]
                }
-- | Build the 'compare' clause used when both arguments were built with the
-- same constructor.
--
-- Both arguments are matched against the constructor, binding a fresh
-- variable to each field: prefix @a@ for the left argument and prefix @b@
-- for the right. All @a@ names are generated before any @b@ names.
--
-- The body applies the function named by 'foldlName' to three things:
-- 'thenCmpName', EQ (via 'cmpEQName') and the list of pairwise
-- @compare aI bI@ applications.
-- NOTE(review): presumably this is @foldl thenCmp EQ [...]@, so earlier
-- fields take priority -- confirm against the definitions in Names.
mk_equal_clause :: Quasi q => DCon -> q DClause
mk_equal_clause (DCon _tvbs _cxt con_name con_fields _rty) = do
  let field_tys = tysOfConFields con_fields
  left_vars  <- traverse (\_ -> newUniqueName "a") field_tys
  right_vars <- traverse (\_ -> newUniqueName "b") field_tys
  let -- Match the constructor, binding each field to the given variable.
      match_con vars = DConP con_name [DVarP v | v <- vars]
      -- @compare l r@ as a desugared expression.
      compare_pair l r = (DVarE compareName `DAppE` DVarE l) `DAppE` DVarE r
      field_cmps = mkListE (zipWith compare_pair left_vars right_vars)
      body = ((DVarE foldlName `DAppE` DVarE thenCmpName)
                `DAppE` DConE cmpEQName)
                `DAppE` field_cmps
  pure (DClause [match_con left_vars, match_con right_vars] body)
-- | Build the 'compare' clause for two arguments built with different
-- constructors.
--
-- Each constructor comes paired with its index in the constructor list
-- (mkOrdInstance numbers them from 1). Both arguments are matched with every
-- field wildcarded. The result is the Ordering constructor (LT, EQ or GT)
-- obtained by comparing the two indices.
--
-- mkOrdInstance only passes pairs of constructors with distinct names. Those
-- always have distinct indices, so the EQ case is not reached from there.
mk_nonequal_clause :: (DCon, Int) -> (DCon, Int) -> DClause
mk_nonequal_clause (DCon _tvbs1 _cxt1 lhs_name lhs_fields _rty1, lhs_ix)
                   (DCon _tvbs2 _cxt2 rhs_name rhs_fields _rty2, rhs_ix) =
  DClause [ignore_fields lhs_name lhs_fields, ignore_fields rhs_name rhs_fields]
          (DConE (ordering_con (compare lhs_ix rhs_ix)))
  where
    -- Match the constructor but discard every one of its fields.
    ignore_fields :: Name -> DConFields -> DPat
    ignore_fields name fields = DConP name [DWildP | _ <- tysOfConFields fields]

    -- Name of the constructor to return for a given comparison result.
    ordering_con :: Ordering -> Name
    ordering_con LT = cmpLTName
    ordering_con EQ = cmpEQName
    ordering_con GT = cmpGTName
-- | The only 'compare' clause for a data type with no constructors. Both
-- arguments are wildcards (never inspected) and the result is EQ.
mk_empty_clause :: DClause
mk_empty_clause = DClause both_wild (DConE cmpEQName)
  where
    both_wild :: [DPat]
    both_wild = replicate 2 DWildP