{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-missing-methods #-}
module Data.Semiring
(
Semiring(..)
, (+)
, (*)
, (^)
, foldMapP
, foldMapT
, sum
, product
, sum'
, product'
, Add(..)
, Mul(..)
, Ring(..)
, (-)
, minus
) where
import Control.Applicative (Applicative(..), Const(..), liftA2)
import Data.Bool (Bool(..), (||), (&&), otherwise, not)
import Data.Complex (Complex(..))
import Data.Eq (Eq(..))
import Data.Fixed (Fixed, HasResolution)
import Data.Foldable (Foldable)
import qualified Data.Foldable as Foldable
import Data.Function ((.), const, flip, id)
import Data.Functor (Functor(..))
import Data.Functor.Identity (Identity(..))
#if defined(VERSION_unordered_containers)
import Data.Hashable (Hashable)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
#endif
import Data.Int (Int, Int8, Int16, Int32, Int64)
import Data.Maybe (Maybe(..))
#if MIN_VERSION_base(4,12,0)
import Data.Monoid (Ap(..))
#endif
#if defined(VERSION_containers)
import Data.Map (Map)
import qualified Data.Map as Map
#endif
import Data.Monoid (Monoid(..),Dual(..), Product(..), Sum(..))
import Data.Ord (Ord(..), Ordering(..), compare)
#if MIN_VERSION_base(4,6,0)
import Data.Ord (Down(..))
#endif
import Data.Proxy (Proxy(..))
import Data.Ratio (Ratio)
import Data.Semigroup (Semigroup(..),Max(..), Min(..))
#if defined(VERSION_containers)
import Data.Set (Set)
import qualified Data.Set as Set
#endif
import Data.Traversable (Traversable)
import Data.Typeable (Typeable)
#if defined(VERSION_vector)
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Unboxed as UV
#endif
import Data.Word (Word, Word8, Word16, Word32, Word64)
import Foreign.C.Types
(CChar, CClock, CDouble, CFloat, CInt,
CIntMax, CIntPtr, CLLong, CLong,
CPtrdiff, CSChar, CSUSeconds, CShort,
CSigAtomic, CSize, CTime, CUChar, CUInt,
CUIntMax, CUIntPtr, CULLong, CULong,
CUSeconds, CUShort, CWchar)
import Foreign.Ptr (IntPtr, WordPtr)
import Foreign.Storable (Storable)
import GHC.Base (build)
import GHC.Enum (Enum, Bounded)
import GHC.Float (Float, Double)
#if MIN_VERSION_base(4,6,1)
import GHC.Generics (Generic,Generic1)
#endif
import GHC.IO (IO)
import GHC.Integer (Integer)
import qualified GHC.Num as Num
import GHC.Read (Read)
import GHC.Real (Integral, Fractional, Real, RealFrac, quot, even)
import GHC.Show (Show)
import Numeric.Natural (Natural)
import System.Posix.Types
(CCc, CDev, CGid, CIno, CMode, CNlink,
COff, CPid, CRLim, CSpeed, CSsize,
CTcflag, CUid, Fd)
-- Fixities of this module's operators and of the methods in backticks:
-- multiplication (level 7) binds tighter than addition/subtraction (level 6),
-- and '^' (level 8, right-associative) binds tighter than both.
infixl 7 *, `times`
infixl 6 +, `plus`, -, `minus`
infixr 8 ^
-- | Raise a semiring element to an integral power by repeated squaring,
-- using only 'times' (and 'one' for the zero exponent).
--
-- A negative exponent yields 'zero' rather than an error. Each step halves
-- the exponent, so the number of multiplications is logarithmic in it.
(^) :: (Semiring a, Integral b) => a -> b -> a
x0 ^ y0
  | y0 < 0    = zero
  | y0 == 0   = one
  | otherwise = powPos x0 y0
  where
    -- Square until the lowest set bit is reached; that power seeds the
    -- accumulator.
    powPos b e
      | even e    = powPos (b * b) (e `quot` 2)
      | e == 1    = b
      | otherwise = powAcc (b * b) (e `quot` 2) b
    -- Keep squaring @b@, multiplying it into @acc@ at every odd bit.
    powAcc b e acc
      | even e    = powAcc (b * b) (e `quot` 2) acc
      | e == 1    = b * acc
      | otherwise = powAcc (b * b) (e `quot` 2) (b * acc)
{-# INLINE (^) #-}
-- | Infix synonym for 'plus'.
(+) :: Semiring a => a -> a -> a
x + y = plus x y
{-# INLINE (+) #-}
-- | Infix synonym for 'times'.
(*) :: Semiring a => a -> a -> a
x * y = times x y
{-# INLINE (*) #-}
-- | Infix synonym for 'minus'.
(-) :: Ring a => a -> a -> a
x - y = minus x y
{-# INLINE (-) #-}
-- | Map every element into a semiring and add the results together with
-- 'plus', as a right fold ending in 'zero'.
foldMapP :: (Foldable t, Semiring s) => (a -> s) -> t a -> s
foldMapP f = Foldable.foldr (\x acc -> f x `plus` acc) zero
{-# INLINE foldMapP #-}
-- | Map every element into a semiring and multiply the results together
-- with 'times', as a right fold ending in 'one'. Elements are multiplied in
-- container order (leftmost element on the left).
foldMapT :: (Foldable t, Semiring s) => (a -> s) -> t a -> s
foldMapT f = Foldable.foldr (\x acc -> f x `times` acc) one
{-# INLINE foldMapT #-}
-- | Sum of all elements with 'plus' (lazy right fold from 'zero').
sum :: (Foldable t, Semiring a) => t a -> a
sum = foldMapP id
{-# INLINE sum #-}
-- | Product of all elements with 'times' (lazy right fold from 'one').
product :: (Foldable t, Semiring a) => t a -> a
product = foldMapT id
{-# INLINE product #-}
-- | Strict left-fold variant of 'sum'.
sum' :: (Foldable t, Semiring a) => t a -> a
sum' = Foldable.foldl' (\acc x -> acc `plus` x) zero
{-# INLINE sum' #-}
-- | Strict left-fold variant of 'product'.
product' :: (Foldable t, Semiring a) => t a -> a
product' = Foldable.foldl' (\acc x -> acc `times` x) one
{-# INLINE product' #-}
-- | Wrapper whose 'Semigroup' and 'Monoid' instances in this module combine
-- with the wrapped type's 'plus' and use its 'zero' as identity.
-- 'Semiring' and the numeric classes in the deriving list come from the
-- wrapped type (GeneralizedNewtypeDeriving is enabled for this module).
newtype Add a = Add { getAdd :: a }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Semiring
    , Show
    , Storable
    , Traversable
    , Typeable
    )
-- | Wrapper whose 'Semigroup' and 'Monoid' instances in this module combine
-- with the wrapped type's 'times' and use its 'one' as identity.
-- 'Semiring' and the numeric classes in the deriving list come from the
-- wrapped type (GeneralizedNewtypeDeriving is enabled for this module).
newtype Mul a = Mul { getMul :: a }
  deriving
    ( Bounded
    , Enum
    , Eq
    , Foldable
    , Fractional
    , Functor
#if MIN_VERSION_base(4,6,1)
    , Generic
    , Generic1
#endif
    , Num.Num
    , Ord
    , Read
    , Real
    , RealFrac
    , Semiring
    , Show
    , Storable
    , Traversable
    , Typeable
    )
-- | '<>' is the wrapped type's 'plus'.
instance Semiring a => Semigroup (Add a) where
  Add x <> Add y = Add (x + y)
  {-# INLINE (<>) #-}
-- | The identity is the wrapped type's 'zero'.
instance Semiring a => Monoid (Add a) where
  mempty = Add zero
  mappend = (<>)
  {-# INLINE mempty #-}
  {-# INLINE mappend #-}
-- | '<>' is the wrapped type's 'times'.
instance Semiring a => Semigroup (Mul a) where
  Mul x <> Mul y = Mul (x * y)
  {-# INLINE (<>) #-}
-- | The identity is the wrapped type's 'one'.
instance Semiring a => Monoid (Mul a) where
  mempty = Mul one
  mappend = (<>)
  {-# INLINE mempty #-}
  {-# INLINE mappend #-}
-- | A type with an addition ('plus', identity 'zero') and a multiplication
-- ('times', identity 'one').
--
-- The intended contract is presumably the usual semiring laws:
--
-- * 'plus' is associative and commutative with identity 'zero';
-- * 'times' is associative with identity 'one' and distributes over 'plus';
-- * 'zero' annihilates under 'times'.
--
-- Nothing in this module checks these laws.
--
-- Every method has a default for types that are also 'Num.Num' instances
-- (@0@, @1@, 'Num.+' and 'Num.*'). That is why the numeric instances in this
-- module can have empty bodies.
class Semiring a where
#if __GLASGOW_HASKELL__ >= 708
  -- NOTE(review): the MINIMAL set lists every method even though all of them
  -- have Num-based defaults. The empty numeric instances presumably stay
  -- warning-free only because of the module's -fno-warn-missing-methods;
  -- confirm.
  {-# MINIMAL plus, zero, times, one #-}
#endif
  -- | Addition.
  plus :: a -> a -> a
  -- | The identity for 'plus'.
  zero :: a
  -- | Multiplication.
  times :: a -> a -> a
  -- | The identity for 'times'.
  one :: a
  -- Num-based defaults (see the class comment).
  default zero :: Num.Num a => a
  default one :: Num.Num a => a
  default plus :: Num.Num a => a -> a -> a
  default times :: Num.Num a => a -> a -> a
  zero = 0
  one = 1
  plus = (Num.+)
  times = (Num.*)
-- | A 'Semiring' with an additive inverse, 'negate'.
--
-- Presumably intended to satisfy @x `plus` negate x == zero@.
-- NOTE(review): not every instance in this module does (see the Bool and
-- Maybe instances).
class Semiring a => Ring a where
#if __GLASGOW_HASKELL__ >= 708
  {-# MINIMAL negate #-}
#endif
  -- | Additive inverse.
  negate :: a -> a
  -- Default for 'Num.Num' types: reuse 'Num.negate'.
  default negate :: Num.Num a => a -> a
  negate = Num.negate
-- | Subtraction, defined as adding the additive inverse of the second
-- argument to the first.
minus :: Ring a => a -> a -> a
minus x y = plus x (negate y)
{-# INLINE minus #-}
-- | Functions into a semiring, combined pointwise. This goes through the
-- reader Applicative: 'pure' is 'const' and 'liftA2' applies both functions
-- to the same argument.
instance Semiring b => Semiring (a -> b) where
  zero = pure zero
  one = pure one
  plus = liftA2 plus
  times = liftA2 times
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Pointwise negation: post-compose with 'negate'.
instance Ring b => Ring (a -> b) where
  negate = fmap negate
  {-# INLINE negate #-}
-- | The trivial semiring: every operation returns @()@ without inspecting
-- its arguments.
instance Semiring () where
  zero = ()
  one = ()
  plus = const (const ())
  times = const (const ())
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | The trivial ring: negation also ignores its argument.
instance Ring () where
  negate = const ()
  {-# INLINE negate #-}
-- | 'Proxy' carries no data, so every operation returns 'Proxy' without
-- inspecting its arguments.
instance Semiring (Proxy a) where
  zero = Proxy
  one = Proxy
  plus = const (const Proxy)
  times = const (const Proxy)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | The Boolean semiring: 'plus' is disjunction (identity 'False') and
-- 'times' is conjunction (identity 'True'). Both short-circuit on their
-- first argument.
instance Semiring Bool where
  zero = False
  one = True
  plus True _ = True
  plus False b = b
  times False _ = False
  times True b = b
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | NOTE(review): with @plus = (||)@, 'not' is not an additive inverse.
-- @True `plus` not True@ is 'True', not 'zero'. This instance does not
-- satisfy the ring law.
instance Ring Bool where
  negate b = not b
  {-# INLINE negate #-}
-- | Lists as coefficient sequences: 'plus' is 'listAdd' and 'times' is
-- 'listTimes'. 'zero' is the empty list and 'one' is the singleton 'one'.
instance Semiring a => Semiring [a] where
  zero = []
  one = [one]
  plus xs ys = listAdd xs ys
  times xs ys = listTimes xs ys
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Element-wise negation.
instance Ring a => Ring [a] where
  negate xs = fmap negate xs
  {-# INLINE negate #-}
-- | 'Nothing' is a freshly adjoined 'zero':
--
-- * it is the identity for 'plus';
-- * it absorbs under 'times';
-- * 'Just' values combine with the wrapped semiring's operations.
instance Semiring a => Semiring (Maybe a) where
  zero = Nothing
  one = Just one
  plus mx my = case (mx, my) of
    (Nothing, _) -> my
    (_, Nothing) -> mx
    (Just x, Just y) -> Just (plus x y)
  times mx my = case (mx, my) of
    (Just x, Just y) -> Just (times x y)
    _ -> Nothing
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Negates under 'Just'.
-- NOTE(review): @Just x `plus` negate (Just x)@ is @Just (x `plus` negate x)@,
-- not 'Nothing' (which is 'zero' here), so this is not a true additive
-- inverse.
instance Ring a => Ring (Maybe a) where
  negate mx = fmap negate mx
  {-# INLINE negate #-}
-- | Runs both actions left to right and combines their results.
-- 'zero' and 'one' are pure actions.
instance Semiring a => Semiring (IO a) where
  zero = pure zero
  one = pure one
  plus ioX ioY = fmap plus ioX <*> ioY
  times ioX ioY = fmap times ioX <*> ioY
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Negates the action's result.
instance Ring a => Ring (IO a) where
  negate io = fmap negate io
  {-# INLINE negate #-}
-- | The wrapped semiring with both 'plus' and 'times' taking their
-- arguments in swapped order.
instance Semiring a => Semiring (Dual a) where
  zero = Dual zero
  one = Dual one
  plus (Dual x) (Dual y) = Dual (plus y x)
  times (Dual x) (Dual y) = Dual (times y x)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Negates the wrapped value.
instance Ring a => Ring (Dual a) where
  negate d = Dual (negate (getDual d))
  {-# INLINE negate #-}
-- | Operates on the wrapped value; the phantom parameter @b@ is ignored.
instance Semiring a => Semiring (Const a b) where
  zero = Const zero
  one = Const one
  plus x y = Const (getConst x `plus` getConst y)
  times x y = Const (getConst x `times` getConst y)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Negates the wrapped value.
instance Ring a => Ring (Const a b) where
  negate c = Const (negate (getConst c))
  {-# INLINE negate #-}
-- | Complex numbers over a ring.
--
-- * 'plus' is component-wise.
-- * 'times' follows @(a + bi)(c + di) = (ac - bd) + (ad + bc)i@, which needs
--   the 'Ring' constraint for the subtraction.
instance Ring a => Semiring (Complex a) where
  zero = zero :+ zero
  one = one :+ zero
  plus (re :+ im) (re' :+ im') = plus re re' :+ plus im im'
  times (re :+ im) (re' :+ im') =
    minus (times re re') (times im im') :+ plus (times re im') (times im re')
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Component-wise negation.
instance Ring a => Ring (Complex a) where
  negate z = case z of
    re :+ im -> negate re :+ negate im
  {-# INLINE negate #-}
#if MIN_VERSION_base(4,12,0)
-- | Lifts the element semiring through an arbitrary 'Applicative':
--
-- * 'zero' and 'one' are 'pure';
-- * 'plus' and 'times' combine the two effects with 'liftA2'.
instance (Semiring a, Applicative f) => Semiring (Ap f a) where
  zero = Ap (pure zero)
  one = Ap (pure one)
  plus (Ap x) (Ap y) = Ap (liftA2 plus x y)
  times (Ap x) (Ap y) = Ap (liftA2 times x y)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#endif
-- Numeric instances. Their bodies are empty, so every method comes from the
-- Num-based defaults of 'Semiring' (0, 1, Num.+ and Num.*).
instance Semiring Int
instance Semiring Int8
instance Semiring Int16
instance Semiring Int32
instance Semiring Int64
instance Semiring Integer
instance Semiring Word
instance Semiring Word8
instance Semiring Word16
instance Semiring Word32
instance Semiring Word64
instance Semiring Float
instance Semiring Double
-- Foreign.C.Types.
instance Semiring CUIntMax
instance Semiring CIntMax
instance Semiring CUIntPtr
instance Semiring CIntPtr
instance Semiring CSUSeconds
instance Semiring CUSeconds
instance Semiring CTime
instance Semiring CClock
instance Semiring CSigAtomic
instance Semiring CWchar
instance Semiring CSize
instance Semiring CPtrdiff
instance Semiring CDouble
instance Semiring CFloat
instance Semiring CULLong
instance Semiring CLLong
instance Semiring CULong
instance Semiring CLong
instance Semiring CUInt
instance Semiring CInt
instance Semiring CUShort
instance Semiring CShort
instance Semiring CUChar
instance Semiring CSChar
instance Semiring CChar
-- Foreign.Ptr.
instance Semiring IntPtr
instance Semiring WordPtr
-- System.Posix.Types.
instance Semiring Fd
instance Semiring CRLim
instance Semiring CTcflag
instance Semiring CSpeed
instance Semiring CCc
instance Semiring CUid
instance Semiring CNlink
instance Semiring CGid
instance Semiring CSsize
instance Semiring CPid
instance Semiring COff
instance Semiring CMode
instance Semiring CIno
instance Semiring CDev
instance Semiring Natural
instance Integral a => Semiring (Ratio a)
-- Newtype wrappers reuse the wrapped type's Semiring operations
-- (GeneralizedNewtypeDeriving). For Product, Max and Min this means the
-- wrapper's own Semigroup meaning plays no part: 'plus' is still the wrapped
-- type's 'plus'.
deriving instance Semiring a => Semiring (Product a)
deriving instance Semiring a => Semiring (Sum a)
deriving instance Semiring a => Semiring (Identity a)
#if MIN_VERSION_base(4,6,0)
deriving instance Semiring a => Semiring (Down a)
#endif
deriving instance Semiring a => Semiring (Max a)
deriving instance Semiring a => Semiring (Min a)
instance HasResolution a => Semiring (Fixed a)
-- Numeric Ring instances. Their bodies are empty, so 'negate' is the
-- default Num.negate.
instance Ring Int
instance Ring Int8
instance Ring Int16
instance Ring Int32
instance Ring Int64
instance Ring Integer
instance Ring Word
instance Ring Word8
instance Ring Word16
instance Ring Word32
instance Ring Word64
instance Ring Float
instance Ring Double
-- Foreign.C.Types.
instance Ring CUIntMax
instance Ring CIntMax
instance Ring CUIntPtr
instance Ring CIntPtr
instance Ring CSUSeconds
instance Ring CUSeconds
instance Ring CTime
instance Ring CClock
instance Ring CSigAtomic
instance Ring CWchar
instance Ring CSize
instance Ring CPtrdiff
instance Ring CDouble
instance Ring CFloat
instance Ring CULLong
instance Ring CLLong
instance Ring CULong
instance Ring CLong
instance Ring CUInt
instance Ring CInt
instance Ring CUShort
instance Ring CShort
instance Ring CUChar
instance Ring CSChar
instance Ring CChar
-- Foreign.Ptr.
instance Ring IntPtr
instance Ring WordPtr
-- System.Posix.Types.
instance Ring Fd
instance Ring CRLim
instance Ring CTcflag
instance Ring CSpeed
instance Ring CCc
instance Ring CUid
instance Ring CNlink
instance Ring CGid
instance Ring CSsize
instance Ring CPid
instance Ring COff
instance Ring CMode
instance Ring CIno
instance Ring CDev
-- NOTE(review): Num.negate for Natural is expected to throw for any non-zero
-- argument. 'minus' is @x `plus` negate y@, so it would then fail for every
-- non-zero subtrahend, even when the difference is a valid Natural. Confirm,
-- and consider whether this instance should exist.
instance Ring Natural
instance Integral a => Ring (Ratio a)
-- Newtype wrappers reuse the wrapped type's 'negate'
-- (GeneralizedNewtypeDeriving).
#if MIN_VERSION_base(4,6,0)
deriving instance Ring a => Ring (Down a)
#endif
deriving instance Ring a => Ring (Product a)
deriving instance Ring a => Ring (Sum a)
deriving instance Ring a => Ring (Identity a)
deriving instance Ring a => Ring (Max a)
deriving instance Ring a => Ring (Min a)
instance HasResolution a => Ring (Fixed a)
#if defined(VERSION_containers)
-- | Sets of monoid values.
--
-- * 'plus' is union and 'zero' is the empty set.
-- * 'times' collects @x `mappend` y@ for every @x@ in the left set and @y@
--   in the right set; 'one' is the singleton 'mempty'.
instance (Ord a, Monoid a) => Semiring (Set a) where
  zero = Set.empty
  one = Set.singleton mempty
  plus xs ys = Set.union xs ys
  times xs ys = Foldable.foldMap (\x -> Set.map (mappend x) ys) xs
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Maps from monoid keys to semiring values.
--
-- * 'plus' is a union that adds colliding values.
-- * 'times' pairs every entry of the left map with every entry of the right
--   map, joining keys with 'mappend' and multiplying values. Values that
--   land on the same key are added together.
-- * 'one' maps 'mempty' to 'one'.
instance (Ord k, Monoid k, Semiring v) => Semiring (Map k v) where
  zero = Map.empty
  one = Map.singleton mempty one
  plus = Map.unionWith plus
  times xs ys = Map.fromListWith plus pairs
    where
      pairs = do
        (k, v) <- Map.toList xs
        (l, u) <- Map.toList ys
        pure (mappend k l, times v u)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#endif
#if defined(VERSION_unordered_containers)
-- | Hash sets of monoid values.
--
-- * 'plus' is union and 'zero' is the empty set.
-- * 'times' collects @x `mappend` y@ for every @x@ in the left set and @y@
--   in the right set; 'one' is the singleton 'mempty'.
instance (Eq a, Hashable a, Monoid a) => Semiring (HashSet a) where
  zero = HashSet.empty
  one = HashSet.singleton mempty
  plus xs ys = HashSet.union xs ys
  times xs ys = Foldable.foldMap (\x -> HashSet.map (mappend x) ys) xs
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Hash maps from monoid keys to semiring values.
--
-- * 'plus' is a union that adds colliding values.
-- * 'times' pairs every entry of the left map with every entry of the right
--   map, joining keys with 'mappend' and multiplying values. Values that
--   land on the same key are added together.
-- * 'one' maps 'mempty' to 'one'.
instance (Eq k, Hashable k, Monoid k, Semiring v) => Semiring (HashMap k v) where
  zero = HashMap.empty
  one = HashMap.singleton mempty one
  plus = HashMap.unionWith plus
  times xs ys = HashMap.fromListWith plus pairs
    where
      pairs = do
        (k, v) <- HashMap.toList xs
        (l, u) <- HashMap.toList ys
        pure (mappend k l, times v u)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
#endif
#if defined(VERSION_primitive)
#endif
#if defined(VERSION_vector)
-- | Boxed vectors as coefficient sequences.
--
-- * 'plus' adds element-wise and keeps the tail of the longer vector.
-- * 'times' is the full discrete convolution. For non-empty inputs the
--   result has length @length signal + length kernel - 1@; if either input
--   is empty the result is empty.
instance Semiring a => Semiring (Vector a) where
  zero = Vector.empty
  -- A single 'one' coefficient: the identity for the convolution below.
  one = Vector.singleton one
  plus xs ys =
    case compare (Vector.length xs) (Vector.length ys) of
      EQ -> Vector.zipWith (+) xs ys
      -- Accumulate the shorter vector into the longer one by index. Every
      -- index of the shorter vector is in bounds for the longer one, which
      -- is what makes the unchecked accumulate safe here.
      -- NOTE(review): the accumulating function presumably receives the
      -- existing (longer-vector) element first, i.e. @y + x@ in the LT case.
      -- That only matters if 'plus' is not commutative.
      LT -> Vector.unsafeAccumulate (+) ys (Vector.indexed xs)
      GT -> Vector.unsafeAccumulate (+) xs (Vector.indexed ys)
  times signal kernel
    | Vector.null signal = Vector.empty
    | Vector.null kernel = Vector.empty
    | otherwise = Vector.generate (slen + klen - 1) f
    where
      !slen = Vector.length signal
      !klen = Vector.length kernel
      -- Output element n is the sum of signal[k] * kernel[n - k] over every
      -- k for which both indices are in bounds. It is summed left to right
      -- from 'zero'.
      f n = Foldable.foldl'
        (\a k -> a +
          Vector.unsafeIndex signal k *
          Vector.unsafeIndex kernel (n - k)
        )
        zero
        [kmin .. kmax]
        where
          -- Smallest k with n - k <= klen - 1.
          !kmin = max 0 (n - (klen - 1))
          -- Largest k with k <= slen - 1 and n - k >= 0.
          !kmax = min n (slen - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Element-wise negation.
instance Ring a => Ring (Vector a) where
  negate v = Vector.map negate v
  {-# INLINE negate #-}
-- | Unboxed vectors as coefficient sequences.
--
-- * 'plus' adds element-wise and keeps the tail of the longer vector.
-- * 'times' is the full discrete convolution. For non-empty inputs the
--   result has length @length signal + length kernel - 1@; if either input
--   is empty the result is empty.
instance (UV.Unbox a, Semiring a) => Semiring (UV.Vector a) where
  zero = UV.empty
  -- A single 'one' coefficient: the identity for the convolution below.
  one = UV.singleton one
  plus xs ys =
    case compare (UV.length xs) (UV.length ys) of
      EQ -> UV.zipWith (+) xs ys
      -- Accumulate the shorter vector into the longer one by index. Every
      -- index of the shorter vector is in bounds for the longer one, which
      -- is what makes the unchecked accumulate safe here.
      LT -> UV.unsafeAccumulate (+) ys (UV.indexed xs)
      GT -> UV.unsafeAccumulate (+) xs (UV.indexed ys)
  times signal kernel
    | UV.null signal = UV.empty
    | UV.null kernel = UV.empty
    | otherwise = UV.generate (slen + klen - 1) f
    where
      !slen = UV.length signal
      !klen = UV.length kernel
      -- Output element n is the sum of signal[k] * kernel[n - k] over every
      -- k for which both indices are in bounds. It is summed left to right
      -- from 'zero'.
      f n = Foldable.foldl'
        (\a k -> a +
          UV.unsafeIndex signal k *
          UV.unsafeIndex kernel (n - k)
        )
        zero
        [kmin .. kmax]
        where
          -- Smallest k with n - k <= klen - 1.
          !kmin = max 0 (n - (klen - 1))
          -- Largest k with k <= slen - 1 and n - k >= 0.
          !kmax = min n (slen - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Element-wise negation.
instance (UV.Unbox a, Ring a) => Ring (UV.Vector a) where
  negate v = UV.map negate v
  {-# INLINE negate #-}
-- | Storable vectors as coefficient sequences.
--
-- * 'plus' adds element-wise and keeps the tail of the longer vector.
-- * 'times' is the full discrete convolution. For non-empty inputs the
--   result has length @length signal + length kernel - 1@; if either input
--   is empty the result is empty.
instance (SV.Storable a, Semiring a) => Semiring (SV.Vector a) where
  zero = SV.empty
  -- A single 'one' coefficient: the identity for the convolution below.
  one = SV.singleton one
  plus xs ys =
    case compare lxs lys of
      EQ -> SV.zipWith (+) xs ys
      -- The index vector 0 .. n-1 (built with enumFromN) is passed
      -- separately from the values of the shorter vector. All of its indices
      -- are in bounds for the longer vector, so the unchecked accumulate is
      -- safe here.
      LT -> SV.unsafeAccumulate_ (+) ys (SV.enumFromN 0 lxs) xs
      GT -> SV.unsafeAccumulate_ (+) xs (SV.enumFromN 0 lys) ys
    where
      lxs = SV.length xs
      lys = SV.length ys
  times signal kernel
    | SV.null signal = SV.empty
    | SV.null kernel = SV.empty
    | otherwise = SV.generate (slen + klen - 1) f
    where
      !slen = SV.length signal
      !klen = SV.length kernel
      -- Output element n is the sum of signal[k] * kernel[n - k] over every
      -- k for which both indices are in bounds. It is summed left to right
      -- from 'zero'.
      f n = Foldable.foldl'
        (\a k -> a +
          SV.unsafeIndex signal k *
          SV.unsafeIndex kernel (n - k))
        zero
        [kmin .. kmax]
        where
          -- Smallest k with n - k <= klen - 1.
          !kmin = max 0 (n - (klen - 1))
          -- Largest k with k <= slen - 1 and n - k >= 0.
          !kmax = min n (slen - 1)
  {-# INLINE plus #-}
  {-# INLINE zero #-}
  {-# INLINE times #-}
  {-# INLINE one #-}
-- | Element-wise negation.
instance (SV.Storable a, Ring a) => Ring (SV.Vector a) where
  negate v = SV.map negate v
  {-# INLINE negate #-}
#endif
-- | Arithmetic on lists read as polynomials, head first: @[a0, a1, a2]@
-- stands for @a0 + a1 X + a2 X^2@.
--
-- * 'listAdd' adds coefficient-wise. Once one list runs out, the remaining
--   coefficients of the other are kept as they are.
--
-- * 'listTimes' is polynomial multiplication (discrete convolution).
--   Products are always formed as left coefficient times right coefficient.
--   For non-empty inputs of lengths @m@ and @n@ the result has length
--   @m + n - 1@; if either input is empty the result is empty. This makes
--   @[one]@ a two-sided identity and @[]@ an annihilator, as the list
--   'Semiring' instance's 'one' and 'zero' require. It also matches the
--   convolution used by the vector instances' 'times'.
listAdd, listTimes :: Semiring a => [a] -> [a] -> [a]
listAdd [] ys = ys
listAdd xs [] = xs
listAdd (x:xs) (y:ys) = (x + y) : listAdd xs ys
-- Kept from inlining until phase 0, presumably so the listAddFB/left and
-- listAddFB/right RULES get a chance to match calls first.
{-# NOINLINE [0] listAdd #-}
-- With P = x : xs and ys = y : ys', P * ys = x*y : (xs*y + P*ys').
-- The right fold walks ys; 'rest' is P times the remaining coefficients.
-- The result is produced lazily, one coefficient at a time.
listTimes [] _ = []
listTimes (x:xs) ys = Foldable.foldr step [] ys
  where
    step y rest = (x * y) : listAdd (fmap (* y) xs) rest
{-# NOINLINE [0] listTimes #-}
-- | The argument type of 'build': a list abstracted over its cons and nil.
type ListBuilder a = forall b. (a -> b -> b) -> b -> b
-- Fusion rules. When either argument of 'listAdd' is produced by 'build',
-- the call is rewritten to 'listAddFBL' or 'listAddFBR'. Those run the
-- builder directly, so the intermediate list for that argument is never
-- constructed.
{-# RULES
"listAddFB/left" forall (g :: ListBuilder a). listAdd (build g) = listAddFBL g
"listAddFB/right" forall xs (g :: ListBuilder a). listAdd xs (build g) = listAddFBR xs g
#-}
-- | 'listAdd' whose left argument is given as a 'build'-style producer.
-- This is the target of the @listAddFB/left@ rule.
--
-- The producer is run with a continuation that carries the not-yet-consumed
-- part of the right list:
--
-- * each produced element is added (on the left) to the next right element;
-- * once the right list is exhausted, produced elements pass through as-is;
-- * leftover right elements are returned by the final 'id'.
listAddFBL :: Semiring a => ListBuilder a -> [a] -> [a]
listAddFBL produce = produce step id
  where
    step x continue rights = case rights of
      [] -> x : continue []
      y : rights' -> (x + y) : continue rights'
-- | 'listAdd' whose right argument is given as a 'build'-style producer.
-- This is the target of the @listAddFB/right@ rule.
--
-- The producer is run with a continuation that carries the not-yet-consumed
-- part of the left list:
--
-- * each produced element is added (on the right) to the next left element,
--   which keeps @x + y@ order;
-- * leftover left elements are returned by the final 'id'.
listAddFBR :: Semiring a => [a] -> ListBuilder a -> [a]
listAddFBR lefts produce = produce step id lefts
  where
    step y continue pending = case pending of
      [] -> y : continue []
      x : pending' -> (x + y) : continue pending'