{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE Trustworthy #-}
---------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2012-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Simple matrix operations (trace, diagonal and a Frobenius-norm helper) for
-- low-dimensional primitives.
---------------------------------------------------------------------------
module Linear.Trace
  ( Trace(..)
  , frobenius
  ) where

import Control.Monad as Monad
import Linear.V0
import Linear.V1
import Linear.V2
import Linear.V3
import Linear.V4
import Linear.Plucker
import Linear.Quaternion
import Linear.V
import Linear.Vector
import Data.Complex
import Data.Distributive
import Data.Foldable as Foldable
import Data.Functor.Bind as Bind
import Data.Functor.Compose
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Lazy
import Data.IntMap (IntMap)
import Data.Map (Map)

-- $setup
-- >>> import Data.Complex
-- >>> import Debug.SimpleReflect.Vars
-- >>> import Linear.V2

-- | Square matrices stored as a nested functor @m (m a)@ that support taking
-- the trace and the diagonal.
--
-- The default method bodies are hidden from HLint via @HLINT@ (presumably
-- because it could not handle 'DefaultSignatures' — TODO confirm).
class Functor m => Trace m where
  -- | Compute the trace of a matrix
  --
  -- >>> trace (V2 (V2 a b) (V2 c d))
  -- a + d
  trace :: Num a => m (m a) -> a
#ifndef HLINT
  -- Default: add up the entries produced by 'diagonal'.
  default trace :: (Foldable m, Num a) => m (m a) -> a
  trace mat = Foldable.sum (diagonal mat)
  {-# INLINE trace #-}
#endif

  -- | Compute the diagonal of a matrix
  --
  -- >>> diagonal (V2 (V2 a b) (V2 c d))
  -- V2 a d
  diagonal :: m (m a) -> m a
#ifndef HLINT
  -- Default: monadic join (written as @>>= id@). For the vector types this
  -- selects the diagonal, as the doctest above shows for 'V2'.
  default diagonal :: Monad m => m (m a) -> m a
  diagonal mat = mat >>= id
  {-# INLINE diagonal #-}
#endif

instance Trace IntMap where
  -- 'IntMap' has no 'Monad' instance, so the class's default 'diagonal'
  -- (which requires 'Monad') is unavailable; 'Bind.join' from semigroupoids
  -- is used in its place. 'trace' keeps the class default.
  diagonal nested = Bind.join nested
  {-# INLINE diagonal #-}

instance Ord k => Trace (Map k) where
  -- 'Map' has no 'Monad' instance, so the class's default 'diagonal'
  -- (which requires 'Monad') is unavailable; 'Bind.join' from semigroupoids
  -- is used in its place. 'trace' keeps the class default.
  diagonal nested = Bind.join nested
  {-# INLINE diagonal #-}

instance (Eq k, Hashable k) => Trace (HashMap k) where
  -- 'HashMap' has no 'Monad' instance, so the class's default 'diagonal'
  -- (which requires 'Monad') is unavailable; 'Bind.join' from semigroupoids
  -- is used in its place. 'trace' keeps the class default.
  diagonal nested = Bind.join nested
  {-# INLINE diagonal #-}

-- Vector-like types whose 'trace' and 'diagonal' come entirely from the class
-- defaults ('Foldable.sum' of the diagonal, and monadic join), so no method
-- bodies are written here.
instance Trace V0 where
instance Trace V1 where
instance Trace V2 where
instance Trace V3 where
instance Trace V4 where
instance Trace Plucker where
instance Trace Quaternion where
instance Dim n => Trace (V n) where

instance Trace Complex where
  -- A @Complex (Complex a)@ is read as a 2x2 matrix whose two "rows" are the
  -- real part and the imaginary part of the outer number. The diagonal
  -- entries are therefore the real part of the first row and the imaginary
  -- part of the second. Matching on the outer ':+' forces both rows, and
  -- (strict fields) all four entries, before anything else happens.
  trace (row0 :+ row1) = realPart row0 + imagPart row1
  {-# INLINE trace #-}
  diagonal (row0 :+ row1) = realPart row0 :+ imagPart row1
  {-# INLINE diagonal #-}

instance (Trace f, Trace g) => Trace (Product f g) where
  -- A @Product f g (Product f g a)@ is a block matrix [[A, B], [C, D]]:
  -- the @f@-indexed rows hold [A | B] and the @g@-indexed rows hold [C | D].
  -- Only the square diagonal blocks A (@f (f a)@) and D (@g (g a)@) matter.
  trace (Pair top bottom) = trace upperLeft + trace lowerRight
    where
      upperLeft  = fmap (\(Pair l _) -> l) top
      lowerRight = fmap (\(Pair _ r) -> r) bottom
  {-# INLINE trace #-}
  diagonal (Pair top bottom) = Pair (diagonal upperLeft) (diagonal lowerRight)
    where
      upperLeft  = fmap (\(Pair l _) -> l) top
      lowerRight = fmap (\(Pair _ r) -> r) bottom
  {-# INLINE diagonal #-}

instance (Distributive g, Trace g, Trace f) => Trace (Compose g f) where
  -- Peeling off both 'Compose' layers exposes the entries as
  -- @g (f (g (f a)))@, indexed @[i][j][k][l]@. Mapping 'distribute' over the
  -- outer @g@ swaps the middle axes, giving @g (g (f (f a)))@ indexed
  -- @[i][k][j][l]@. The two @g@ axes and the two @f@ axes are then adjacent,
  -- so each pair can be handled by the component instances.
  trace m = trace (fmap (fmap trace . distribute) blocks)
    where
      blocks = getCompose (fmap getCompose m)
  {-# INLINE trace #-}
  diagonal m = Compose (fmap diagonal (diagonal (fmap distribute blocks)))
    where
      blocks = getCompose (fmap getCompose m)
  {-# INLINE diagonal #-}

-- | Compute the /square/ of the
-- <http://mathworld.wolfram.com/FrobeniusNorm.html Frobenius norm> of a
-- matrix, i.e. the sum of the squares of its entries.
--
-- Only 'Num' is required, so no square root is taken: apply 'sqrt' to the
-- result if the norm itself is wanted. Entries are squared with plain
-- multiplication (there is no complex conjugation).
--
-- >>> frobenius (V2 (V2 1 2) (V2 3 4))
-- 30
frobenius
  :: (Num a, Foldable f, Additive f, Additive g, Distributive g, Trace g)
  => f (g a) -> a
frobenius m = trace (fmap gramRow (distribute m))
  where
    -- 'distribute' turns @m@ inside out, yielding its columns. For a column
    -- @c@, the sum over @i@ of @c_i *^ m_i@ is one row of @transpose m * m@.
    -- The trace of that product is the sum of the squared entries of @m@.
    gramRow column = Foldable.foldl' (^+^) zero (liftI2 (*^) column m)