{-# LANGUAGE PatternGuards #-}

-- | Monomials in a countably infinite set of variables x1, x2, x3, ...
module MathObj.Monomial
    ( -- * Type
      T(..)

      -- * Creating monomials
    , mkMonomial
    , constant
    , x

      -- * Utility functions
    , degree
    , pDegree
    , scaleMon

    ) where

import qualified Algebra.Additive     as Additive
import qualified Algebra.Differential as Differential
import qualified Algebra.Field        as Field
import qualified Algebra.Ring         as Ring
import qualified Algebra.ZeroTestable as ZeroTestable

import           Data.List            (intercalate, sort)
import qualified Data.Map             as M
import           Data.Ord             (comparing)

import           NumericPrelude

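-- | A monomial is a map from variable indices to integer powers,
--   paired with a (polymorphic) coefficient.  Negative integer powers
--   are allowed, so a monomial whose coefficient is a nonzero element
--   of a field has a multiplicative inverse (see the Field instance).
--   Strictly speaking monomials do not form a field, since addition is
--   only meaningful between monomials with identical exponents.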
--
--   Instances are provided for Show, Eq, Ord, ZeroTestable, Additive,
--   Ring, Differential, and Field.  Eq and Ord compare only the
--   exponent maps; coefficients are ignored.  Note that adding two
--   monomials only makes sense if they have matching variables and
--   exponents.  The Differential instance represents partial
--   differentiation with respect to x1.
--
--   The Ord instance for monomials orders them first by partition
--   degree (see 'pDegree'), then by largest variable index (largest
--   first), then by exponent (largest first).  This may seem a bit odd,
--   but in fact reflects the use of these monomials to implement cycle
--   index series, where this ordering corresponds nicely to generation
--   of integer partitions.  To make the library more general we could
--   parameterize monomials by the desired ordering.
data T a = Cons
  { coeff  :: a                      -- ^ the coefficient
  , powers :: M.Map Integer Integer  -- ^ exponent of each variable, keyed by variable index
  }

-- | Build a monomial from a coefficient and an association list of
--   (variable index, exponent) pairs.  If an index occurs more than
--   once, the exponent listed last is kept (the behaviour of
--   'M.fromList'); repeated indices are /not/ combined.
mkMonomial :: a -> [(Integer, Integer)] -> T a
mkMonomial c = Cons c . M.fromList

-- | Minus one in the coefficient ring, i.e. the additive inverse of the
--   multiplicative identity.
negOne :: Ring.C a => a
negOne = Additive.negate Ring.one

-- | Render a monomial as its coefficient followed by its variables,
--   e.g. @3 x1 x2^2@.  A coefficient of one is left implicit, minus one
--   becomes a bare leading @-@, and a zero coefficient prints as @0@.
--   Entries with exponent 0 are dropped before rendering, so a monomial
--   whose variables all have exponent 0 prints as just its coefficient.
instance (ZeroTestable.C a, Ring.C a, Eq a, Show a) => Show (T a) where
  show (Cons a pows)
    | isZero a    = "0"
    | M.null vars = show a
    | a == one    = showVars vars
    | a == negOne = "-" ++ showVars vars
    | otherwise   = show a ++ " " ++ showVars vars
    where
      -- Without this filter, x1^0 with coefficient 1 rendered as the
      -- empty string and 3 x1^0 as "3 " (trailing space): showVars
      -- skips zero exponents, but M.null still saw the entry.
      vars = M.filter (/= 0) pows

-- | Render the variable part of a monomial as space-separated factors
--   such as @x1 x2^3@.  Variables with exponent 0 are omitted, and an
--   exponent of 1 is left implicit.
showVars :: M.Map Integer Integer -> String
showVars m = intercalate " " [ factor v p | (v, p) <- M.toList m, p /= 0 ]
  where
    factor v 1 = "x" ++ show v
    factor v p = "x" ++ show v ++ "^" ++ show p

-- | The (total) degree of a monomial: the sum of all its exponents.
--   The coefficient plays no role.
degree :: T a -> Integer
degree = sum . M.elems . powers

-- | The \"partition degree\" of a monomial: the sum, over all its
--   variables, of variable index times exponent.  For example,
--   x1^3 x2^2 x4^3 has partition degree 1*3 + 2*2 + 4*3 = 19.  The
--   name comes from reading the monomial as an integer partition: the
--   example corresponds to the partition 19 = 1 + 1 + 1 + 2 + 2 + 4 +
--   4 + 4 (each index repeated as many times as its exponent).
pDegree :: T a -> Integer
pDegree = M.foldrWithKey (\v p acc -> v * p + acc) 0 . powers

-- | A monomial with the given coefficient and no variables.
constant :: a -> T a
constant c = Cons { coeff = c, powers = M.empty }

-- | The monomial @xn@: variable index @n@ raised to the first power,
--   with coefficient one.
x :: (Ring.C a) => Integer -> T a
x n = Cons { coeff = Ring.one, powers = M.singleton n 1 }

-- | Multiply every variable index by a constant, leaving the exponents
--   and the coefficient untouched; e.g. scaling x1^2 x3 by 2 gives
--   x2^2 x6.  Useful for operations like plethystic substitution or
--   Mobius inversion.  If several indices are mapped to the same new
--   index (only possible when the factor is 0), 'M.mapKeys' keeps the
--   exponent belonging to the greatest original index.
scaleMon :: Integer -> T a -> T a
scaleMon n mon = mon { powers = M.mapKeys (n *) (powers mon) }

-- | Two monomials are equal when their exponent maps coincide; the
--   coefficients are not compared.
instance Eq (T a) where
  m1 == m2 = powers m1 == powers m2

-- | Order by partition degree ('pDegree') first.  Ties are broken by
--   walking the variables from the largest index downward: at the first
--   difference, the monomial with the larger index (or, for equal
--   indices, the larger exponent) sorts first.  Coefficients are not
--   compared.
instance Ord (T a) where
  compare m1 m2 =
    case compare (pDegree m1) (pDegree m2) of
      EQ    -> comparing byVarsDescending m1 m2
      other -> other
    where
      -- 'M.toAscList' already yields distinct keys in ascending order,
      -- so reversing it lists the variables from largest index down;
      -- 'Rev' flips the element comparison so "larger" sorts first.
      byVarsDescending = map Rev . reverse . M.toAscList . powers

-- | Wrapper that inverts the ordering of the wrapped value:
--   @compare (Rev a) (Rev b) == compare b a@.
newtype Rev a = Rev a
  deriving (Eq)

instance Ord a => Ord (Rev a) where
  compare (Rev lhs) (Rev rhs) = flip compare lhs rhs

-- | A monomial is zero exactly when its coefficient is zero, whatever
--   its exponents are.
instance (ZeroTestable.C a) => ZeroTestable.C (T a) where
  isZero = isZero . coeff

-- | Monomial addition.  Adding two nonzero monomials is only meaningful
--   when their exponent maps agree; this precondition is NOT checked.
--   A monomial with a zero coefficient acts as an identity on either
--   side, so @zero + m@ and @m + zero@ both keep m's exponents (this
--   matters for folds seeded with 'zero').  A sum whose coefficient is
--   zero is normalised to the canonical zero (no variables).
instance (Additive.C a, ZeroTestable.C a) => Additive.C (T a) where
  zero = Cons zero M.empty
  negate (Cons a m) = Cons (negate a) m

  (Cons a1 m1) + (Cons a2 m2)
    | isZero s  = Cons s M.empty
    -- A zero left operand (e.g. the 'zero' seed of a fold) must not
    -- erase the right operand's variables, as using m1 here would.
    | isZero a1 = Cons s m2
    | otherwise = Cons s m1
    where s = a1 + a2

-- | Multiplication multiplies the coefficients and adds exponents
--   variable by variable; any variable whose exponent cancels to 0 is
--   dropped from the result.  Integer literals become constant
--   monomials with no variables.
instance (Ring.C a, ZeroTestable.C a) => Ring.C (T a) where
  fromInteger k = Cons (fromInteger k) M.empty
  Cons c1 e1 * Cons c2 e2 = Cons (c1 * c2) nonzeroExps
    where
      nonzeroExps = M.filter (not . isZero) (M.unionWith (+) e1 e2)

-- | Partial differentiation with respect to x1.  The exponent p of x1
--   is multiplied into the coefficient and decremented; x1 is removed
--   when p was 1.  Negative exponents follow the same rule.  Whenever
--   the derivative is zero the canonical zero monomial (zero
--   coefficient, no variables) is returned.  That covers three cases:
--   x1 is absent, the map holds an explicit x1^0 entry, or the new
--   coefficient is zero.
instance (ZeroTestable.C a, Ring.C a) => Differential.C (T a) where
  differentiate (Cons a m)
    | Just p <- M.lookup 1 m
    , p /= 0
    , let c = a * fromInteger p
    , not (isZero c)
    = Cons c (M.update powerPred 1 m)
    -- Previously an explicit x1^0 produced a zero coefficient paired
    -- with a spurious x1^-1; return the canonical zero instead.
    | otherwise = Cons zero M.empty
    where
      -- Decrement an exponent, deleting the entry when it would hit 0.
      powerPred 1 = Nothing
      powerPred q = Just (q - 1)

-- | Reciprocal of a monomial: invert the coefficient and negate every
--   exponent.  Calls 'error' when the coefficient is zero.
--   Only ZeroTestable and Field are needed on the coefficient type; the
--   former 'Eq' constraint was unused and has been dropped.
instance (ZeroTestable.C a, Field.C a) => Field.C (T a) where
  recip (Cons a pows)
    | isZero a  = error "Monomial.recip: division by zero"
    | otherwise = Cons (recip a) (M.map negate pows)