{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE DefaultSignatures #-}
#define USE_GHC_GENERICS
#endif
module Linear.Vector
( Additive(..)
, E(..)
, negated
, (^*)
, (*^)
, (^/)
, sumV
, basis
, basisFor
, scaled
, outer
, unit
) where
import Control.Applicative
import Control.Lens
import Data.Complex
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable as Foldable (Foldable, forM_, foldl')
#else
import Data.Foldable as Foldable (forM_, foldl')
#endif
import Data.HashMap.Lazy as HashMap
import Data.Hashable
import Data.IntMap as IntMap
import Data.Map as Map
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (mempty)
#endif
import Data.Vector as Vector
import Data.Vector.Mutable as Mutable
#ifdef USE_GHC_GENERICS
import GHC.Generics
#endif
import Linear.Instances ()
{-# ANN module "HLint: ignore Redundant lambda" #-}
-- | Wraps a lens into a container @t@ that is polymorphic in the element
-- type: the same 'E' value focuses one @x@ inside @t x@ for every @x@.
-- The rank-2 field ('forall x.') is what makes it reusable across element
-- types.
newtype E t = E { el :: forall x. Lens' (t x) x }
-- Fixities follow the Prelude arithmetic operators: the additive operators
-- are infixl 6 (like (+) and (-)), the scalar operators infixl 7 (like (*)
-- and (/)). Hence @a *^ u ^+^ b *^ v@ parses as @(a *^ u) ^+^ (b *^ v)@.
infixl 6 ^+^, ^-^
infixl 7 ^*, *^, ^/
#ifdef USE_GHC_GENERICS
-- | 'Additive'-like operations on 'GHC.Generics' representation types.
-- 'gzero' backs the generic default for 'zero' in 'Additive'.
class GAdditive g where
  -- | The structure filled with zeros.
  gzero :: Num e => g e
  -- | Element-wise combination with a homogeneous operator.
  gliftU2 :: (e -> e -> e) -> g e -> g e -> g e
  -- | Element-wise combination with a heterogeneous operator.
  gliftI2 :: (x -> y -> z) -> g x -> g y -> g z
-- | Field-less constructors. Each binary operation matches both inputs
-- against 'U1', which forces them, and then returns 'U1'.
instance GAdditive U1 where
  gliftI2 _ U1 U1 = U1
  {-# INLINE gliftI2 #-}
  gliftU2 _ U1 U1 = U1
  {-# INLINE gliftU2 #-}
  gzero = U1
  {-# INLINE gzero #-}
-- | Products: handle the left and right factors independently and pair the
-- results back up.
instance (GAdditive f, GAdditive g) => GAdditive (f :*: g) where
  gzero = gzero :*: gzero
  {-# INLINE gzero #-}
  gliftU2 op (xl :*: xr) (yl :*: yr) = left :*: right
    where
      left  = gliftU2 op xl yl
      right = gliftU2 op xr yr
  {-# INLINE gliftU2 #-}
  gliftI2 op (xl :*: xr) (yl :*: yr) = left :*: right
    where
      left  = gliftI2 op xl yl
      right = gliftI2 op xr yr
  {-# INLINE gliftI2 #-}
-- | Recursive occurrences of an 'Additive' functor: unwrap the field,
-- delegate to that functor's 'Additive' instance, and rewrap.
instance Additive f => GAdditive (Rec1 f) where
  gzero = Rec1 zero
  {-# INLINE gzero #-}
  gliftU2 op x y = Rec1 (liftU2 op (unRec1 x) (unRec1 y))
  {-# INLINE gliftU2 #-}
  gliftI2 op x y = Rec1 (liftI2 op (unRec1 x) (unRec1 y))
  {-# INLINE gliftI2 #-}
-- | Metadata wrappers (datatype, constructor, selector): unwrap, recurse
-- into the wrapped representation, and rewrap.
instance GAdditive f => GAdditive (M1 i c f) where
  gzero = M1 gzero
  {-# INLINE gzero #-}
  gliftU2 op x y = M1 (gliftU2 op (unM1 x) (unM1 y))
  {-# INLINE gliftU2 #-}
  gliftI2 op x y = M1 (gliftI2 op (unM1 x) (unM1 y))
  {-# INLINE gliftI2 #-}
-- | The parameter position itself: operations act directly on the wrapped
-- element. 'gzero' uses the literal @0@, which is why it needs 'Num'.
instance GAdditive Par1 where
  gzero = Par1 0
  -- Added for consistency: every other GAdditive method carries INLINE.
  {-# INLINE gzero #-}
  gliftU2 f (Par1 a) (Par1 b) = Par1 (f a b)
  {-# INLINE gliftU2 #-}
  gliftI2 f (Par1 a) (Par1 b) = Par1 (f a b)
  {-# INLINE gliftI2 #-}
#endif
-- | Functors whose elements can be combined point-wise, giving vector-style
-- addition.
--
-- Instances supply two ways of combining structures that may differ in
-- shape: 'liftU2' (union: elements present in only one argument are kept)
-- and 'liftI2' (intersection: only positions present in both). The
-- arithmetic operators below are defined in terms of these.
class Functor f => Additive f where
  -- | The zero vector.
  zero :: Num a => f a
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  -- Generic default: build 'zero' from the Generic1 representation of f.
  default zero :: (GAdditive (Rep1 f), Generic1 f, Num a) => f a
  zero = to1 gzero
#endif
#endif

  -- | Element-wise sum. Built on 'liftU2', so an element present in only
  -- one argument is carried through unchanged.
  (^+^) :: Num a => f a -> f a -> f a
  (^+^) = liftU2 (+)
  {-# INLINE (^+^) #-}

  -- | Element-wise difference, computed as @x ^+^ negated y@ rather than
  -- @liftU2 (-)@: under union semantics an element present only in @y@
  -- must come out negated, whereas @liftU2 (-)@ would pass it through
  -- unchanged.
  (^-^) :: Num a => f a -> f a -> f a
  x ^-^ y = x ^+^ negated y
  {-# INLINE (^-^) #-}

  -- | Linear interpolation: @alpha@ weights @u@ and @(1 - alpha)@ weights
  -- @v@, i.e. @alpha *^ u ^+^ (1 - alpha) *^ v@.
  lerp :: Num a => a -> f a -> f a -> f a
  lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v
  {-# INLINE lerp #-}

  -- | Union-style combination: where only one argument has an element, the
  -- instances here keep it unchanged (e.g. the 'Map' instance uses
  -- 'Map.unionWith').
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default liftU2 :: Applicative f => (a -> a -> a) -> f a -> f a -> f a
  liftU2 = liftA2
  {-# INLINE liftU2 #-}
#endif
#endif

  -- | Intersection-style combination: the result only has positions present
  -- in both arguments (e.g. the 'Map' instance uses 'Map.intersectionWith').
  liftI2 :: (a -> b -> c) -> f a -> f b -> f c
#ifdef USE_GHC_GENERICS
#ifndef HLINT
  default liftI2 :: Applicative f => (a -> b -> c) -> f a -> f b -> f c
  liftI2 = liftA2
  {-# INLINE liftI2 #-}
#endif
#endif
-- | Zip lists. 'zero' and 'liftU2' reuse the list instance (union keeps the
-- longer tail); 'liftI2' is the zipping 'liftA2', which truncates.
instance Additive ZipList where
  zero = ZipList zero
  {-# INLINE zero #-}
  liftU2 f xs ys = ZipList (liftU2 f (getZipList xs) (getZipList ys))
  {-# INLINE liftU2 #-}
  liftI2 = liftA2
  {-# INLINE liftI2 #-}
-- | Boxed vectors. 'liftU2' keeps the tail of the longer input; 'liftI2'
-- truncates to the shorter one.
instance Additive Vector where
  zero = mempty
  {-# INLINE zero #-}
  liftU2 f u v
    | lu == lv  = Vector.zipWith f u v
    | lu < lv   = if lu == 0 then v else overwritePrefix lu v
    | otherwise = if lv == 0 then u else overwritePrefix lv u
    where
      lu = Vector.length u
      lv = Vector.length v
      -- Starting from 'base' (the longer input), replace the first n
      -- elements with f applied to the matching elements of u and v.
      -- n is the shorter length, so the unchecked reads/writes stay in
      -- bounds.
      overwritePrefix n base =
        Vector.modify
          (\w -> Foldable.forM_ [0 .. n - 1] $ \i ->
             unsafeWrite w i (f (unsafeIndex u i) (unsafeIndex v i)))
          base
  {-# INLINE liftU2 #-}
  liftI2 = Vector.zipWith
  {-# INLINE liftI2 #-}
-- | 'Nothing' is the zero vector. In 'liftU2' a 'Nothing' on either side
-- yields the other side unchanged; 'liftI2' needs both to be 'Just'.
instance Additive Maybe where
  zero = Nothing
  {-# INLINE zero #-}
  liftU2 f mx my = case (mx, my) of
    (Just x, Just y) -> Just (f x y)
    (Nothing, _)     -> my
    (_, Nothing)     -> mx
  {-# INLINE liftU2 #-}
  liftI2 = liftA2
  {-# INLINE liftI2 #-}
-- | Lists as vectors of varying length. 'liftU2' combines the common prefix
-- and appends the leftover tail of the longer list; 'liftI2' zips, which
-- truncates to the shorter list.
instance Additive [] where
  zero = []
  {-# INLINE zero #-}
  liftU2 f = loop
    where
      -- The left list is inspected first; the right one only when the left
      -- is non-empty.
      loop ls rs = case ls of
        [] -> rs
        l : ls' -> case rs of
          [] -> ls
          r : rs' -> f l r : loop ls' rs'
  {-# INLINE liftU2 #-}
  liftI2 = Prelude.zipWith
  {-# INLINE liftI2 #-}
-- | Sparse 'Int'-indexed vectors: 'liftU2' is a key-wise union, 'liftI2'
-- a key-wise intersection, and 'zero' the empty map.
instance Additive IntMap where
  liftI2 = IntMap.intersectionWith
  {-# INLINE liftI2 #-}
  liftU2 = IntMap.unionWith
  {-# INLINE liftU2 #-}
  zero = mempty
  {-# INLINE zero #-}
-- | Sparse vectors over an ordered key type: 'liftU2' is a key-wise union,
-- 'liftI2' a key-wise intersection, and 'zero' the empty map.
instance Ord k => Additive (Map k) where
  liftI2 = Map.intersectionWith
  {-# INLINE liftI2 #-}
  liftU2 = Map.unionWith
  {-# INLINE liftU2 #-}
  zero = mempty
  {-# INLINE zero #-}
-- | Sparse vectors over a hashable key type: 'liftU2' is a key-wise union,
-- 'liftI2' a key-wise intersection, and 'zero' the empty map.
instance (Eq k, Hashable k) => Additive (HashMap k) where
  liftI2 = HashMap.intersectionWith
  {-# INLINE liftI2 #-}
  liftU2 = HashMap.unionWith
  {-# INLINE liftU2 #-}
  zero = mempty
  {-# INLINE zero #-}
-- | Functions as vectors indexed by their argument: operations are applied
-- point-wise at each input @x@, and 'zero' is constantly @0@.
instance Additive ((->) b) where
  zero = \_ -> 0
  {-# INLINE zero #-}
  liftU2 = \op g h x -> op (g x) (h x)
  {-# INLINE liftU2 #-}
  liftI2 = \op g h x -> op (g x) (h x)
  {-# INLINE liftI2 #-}
-- | Complex numbers as 2-vectors of (real, imaginary) parts. Both parts are
-- always present, so union and intersection coincide.
instance Additive Complex where
  zero = 0 :+ 0
  {-# INLINE zero #-}
  liftI2 f (xr :+ xi) (yr :+ yi) = f xr yr :+ f xi yi
  {-# INLINE liftI2 #-}
  liftU2 = liftI2
  {-# INLINE liftU2 #-}
-- | One-dimensional vectors: unwrap both sides, combine, rewrap.
instance Additive Identity where
  zero = pure 0
  {-# INLINE zero #-}
  liftU2 = \f (Identity x) (Identity y) -> Identity (f x y)
  {-# INLINE liftU2 #-}
  liftI2 = \f (Identity x) (Identity y) -> Identity (f x y)
  {-# INLINE liftI2 #-}
-- | Flip the sign of every element of a vector.
negated :: (Functor f, Num a) => f a -> f a
negated = (negate <$>)
{-# INLINE negated #-}
-- | Add up a 'Foldable' collection of vectors, starting from 'zero'. A
-- strict left fold forces the running total (to WHNF) at every step.
sumV :: (Foldable f, Additive v, Num a) => f (v a) -> v a
sumV = Foldable.foldl' (\total v -> total ^+^ v) zero
{-# INLINE sumV #-}
-- | Scale a vector by a scalar on the left: each element @x@ becomes
-- @s * x@.
(*^) :: (Functor f, Num a) => a -> f a -> f a
(*^) s = fmap (\x -> s * x)
{-# INLINE (*^) #-}
-- | Scale a vector by a scalar on the right: each element @x@ becomes
-- @x * s@.
(^*) :: (Functor f, Num a) => f a -> a -> f a
(^*) v s = fmap (\x -> x * s) v
{-# INLINE (^*) #-}
-- | Divide every element of a vector by a scalar: @x@ becomes @x / s@.
(^/) :: (Functor f, Fractional a) => f a -> a -> f a
(^/) v s = fmap (\x -> x / s) v
{-# INLINE (^/) #-}
-- | The standard basis for the shape of 'zero': one vector per position,
-- with @1@ at that position and @0@ everywhere else. The template's element
-- type is irrelevant. It is pinned to 'Int' only to avoid an ambiguous type.
basis :: (Additive t, Traversable t, Num a) => [t a]
basis = basisFor shape
  where
    shape :: Additive v => v Int
    shape = zero
-- | The standard basis shaped like the given container. For each position
-- @i@ of @t@ (in traversal order) it emits a copy of @t@ with @1@ at @i@
-- and @0@ elsewhere. The input's elements are ignored; only its shape is
-- used.
basisFor :: (Traversable t, Num a) => t b -> [t a]
basisFor = \t ->
  let indicator i = iover traversed (\j _ -> if i == j then 1 else 0) t
  in ifoldMapOf traversed (\i _ -> [indicator i]) t
{-# INLINABLE basisFor #-}
-- | Diagonal matrix from a vector: row @i@ is the input's shape with the
-- @i@-th element kept and every other element replaced by @0@.
scaled :: (Traversable t, Num a) => t a -> t (t a)
scaled = \t ->
  let diagonalRow i x = iover traversed (\j _ -> if i == j then x else 0) t
  in iover traversed diagonalRow t
{-# INLINE scaled #-}
-- | Start from 'zero' and set the component focused by the given setter
-- to @1@.
unit :: (Additive t, Num a) => ASetter' (t a) a -> t a
unit setter = zero & setter .~ 1
-- | Outer product: for each element @x@ of the first vector, a row holding
-- @y * x@ for every element @y@ of the second.
outer :: (Functor f, Functor g, Num a) => f a -> g a -> f (g a)
outer a b = fmap row a
  where
    row x = fmap (* x) b