{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
module Math.Tensor
(
T(..)
,
Label
, Dimension
, RankT
, rankT
, scalarT
, zeroT
, toListT
, fromListT
, removeZerosT
, (.*)
, (.+)
, (.-)
, (.°)
, contractT
, transposeT
, transposeMultT
, relabelT
,
conRank
, covRank
, conCovRank
) where
import Math.Tensor.Safe
import Math.Tensor.Safe.TH
import Data.Kind (Type)
import Data.Singletons
( Sing, SingI (sing), Demote
, withSomeSing, withSingI
, fromSing
)
import Data.Singletons.Prelude
import Data.Singletons.Prelude.Maybe
( sIsJust
)
import Data.Singletons.Decide
( Decision(Proved, Disproved)
, (:~:) (Refl), (%~)
)
import Data.Singletons.TypeLits
( Nat
, Symbol
)
import Data.Bifunctor (first)
import Data.List.NonEmpty (NonEmpty((:|)),sort)
import Control.Monad.Except (MonadError, throwError)
-- | A tensor whose generalized rank is hidden behind the constructor.
--
-- The constructor packs a 'Tensor' of some rank @r@ together with the
-- 'SingI' dictionary for that rank, so the rank can be recovered as a
-- singleton after pattern matching.
data T (v :: Type) where
  T :: forall (r :: Rank) v. SingI r => Tensor r v -> T v

-- | Shows the wrapped tensor whenever its values can be shown.
deriving instance Show v => Show (T v)
-- | Maps a function over the values of the wrapped tensor by delegating to
-- the 'Functor' instance of the underlying 'Tensor'.
instance Functor T where
  fmap f (T t) = T (fmap f t)
-- | Term-level label type: the demotion ('Demote') of the type-level
-- 'Symbol' kind. Used in this module for vector-space and index labels.
type Label = Demote Symbol
-- | Term-level dimension type: the demotion ('Demote') of the type-level
-- 'Nat' kind. Used in this module as the dimension of a 'VSpace'.
type Dimension = Demote Nat
-- | Term-level generalized rank: the demotion ('Demote') of the 'Rank' kind.
type RankT = Demote Rank
-- | Wraps a single value as a tensor using the 'Scalar' constructor.
scalarT :: v -> T v
scalarT = T . Scalar
-- | Builds the 'ZeroTensor' of the given term-level rank.
--
-- The rank is promoted to a singleton and checked with 'sSane'; if the
-- check does not yield 'STrue', an error mentioning the offending rank is
-- thrown in the surrounding 'MonadError'.
zeroT :: MonadError String m => RankT -> m (T v)
zeroT rank =
  withSomeSing rank $ \(sRank :: Sing r) ->
    case sSane sRank %~ STrue of
      Disproved _ ->
        throwError $ "Illegal index list for zero : " ++ show rank
      Proved Refl ->
        -- 'withSingI' supplies the 'SingI' dictionary required by 'T'.
        withSingI sRank $ return $ T (ZeroTensor :: Tensor r v)
-- | Converts a length-indexed vector into an ordinary list, keeping the
-- element order.
vecToList :: Vec n a -> [a]
vecToList vec =
  case vec of
    VNil -> []
    VCons hd rest -> hd : vecToList rest
-- | Rebuilds a length-indexed vector from a list.
--
-- The singleton argument fixes the expected length. A list with fewer or
-- more elements than that length is rejected with an error.
vecFromList :: forall (n :: N) a m.
               MonadError String m => Sing n -> [a] -> m (Vec n a)
vecFromList len elems =
  case len of
    SZ ->
      case elems of
        [] -> return VNil
        _ : _ -> throwError "List provided for vector reconstruction is too long."
    SS len' ->
      case elems of
        [] -> throwError "List provided for vector reconstruction is too short."
        e : es ->
          vecFromList len' es >>= \rest ->
            return $ e `VCons` rest
-- | Applies 'removeZeros' to the wrapped tensor and re-wraps the result.
removeZerosT :: (Eq v, Num v) => T v -> T v
removeZerosT (T t) = T (removeZeros t)
-- | Tensor product of two wrapped tensors ('&*').
--
-- The ranks of both operands are merged with 'sMergeR'. If no merged rank
-- exists, an error is thrown instead of a result.
(.*) :: (Num v, MonadError String m) => T v -> T v -> m (T v)
T (t1 :: Tensor r1 v) .* T (t2 :: Tensor r2 v) =
  case sMergeR (sing :: Sing r1) (sing :: Sing r2) of
    SJust merged ->
      -- The merged rank provides the 'SingI' dictionary for the product.
      withSingI merged $ return $ T (t1 &* t2)
    SNothing ->
      throwError "Tensors have overlapping indices. Cannot multiply."
infixl 7 .*
-- | Scales a tensor by a scalar: every value @x@ of the tensor is replaced
-- by @x * s@ (the scalar is the right-hand factor).
(.°) :: Num v => v -> T v -> T v
s .° t = fmap (\x -> x * s) t
infixl 7 .°
-- | Tensor addition ('&+').
--
-- Both operands must have the same generalized rank, and that rank must
-- pass the 'sSane' check; otherwise an error is thrown.
(.+) :: (Eq v, Num v, MonadError String m) => T v -> T v -> m (T v)
T (t1 :: Tensor r1 v) .+ T (t2 :: Tensor r2 v) =
  let rank1 = sing :: Sing r1
      rank2 = sing :: Sing r2
   in case rank1 %~ rank2 of
        Disproved _ ->
          throwError "Generalized tensor ranks do not match. Cannot add."
        Proved Refl ->
          case sSane rank1 %~ STrue of
            Disproved _ -> throwError "Rank of summands is not sane."
            Proved Refl -> return (T (t1 &+ t2))
infixl 6 .+
-- | Tensor subtraction ('&-').
--
-- Both operands must have the same generalized rank, and that rank must
-- pass the 'sSane' check; otherwise an error is thrown. Declared with the
-- same fixity as '.+' so that both additive operators associate alike.
(.-) :: (Eq v, Num v, MonadError String m) => T v -> T v -> m (T v)
(.-) o1 o2 =
  case o1 of
    T (t1 :: Tensor r1 v) ->
      case o2 of
        T (t2 :: Tensor r2 v) ->
          let sr1 = sing :: Sing r1
              sr2 = sing :: Sing r2
           in case sr1 %~ sr2 of
                Proved Refl ->
                  case sSane sr1 %~ STrue of
                    Proved Refl -> return $ T (t1 &- t2)
                    Disproved _ -> throwError "Rank of operands is not sane."
                -- The message names the operation that actually failed.
                Disproved _ ->
                  throwError "Generalized tensor ranks do not match. Cannot subtract."
infixl 6 .-
-- | Applies 'contract' to the wrapped tensor.
--
-- The rank of the result is computed from the input rank with
-- 'sContractR' and supplied to the 'T' constructor via 'withSingI'.
contractT :: (Num v, Eq v) => T v -> T v
contractT (T (t :: Tensor r v)) =
  withSingI (sContractR (sing :: Sing r)) $ T (contract t)
-- | Applies 'transpose' to the wrapped tensor for the given vector space
-- and pair of indices.
--
-- The arguments are promoted to singletons and checked against the tensor
-- rank with 'sCanTranspose'. When the check yields 'SFalse', an error
-- naming the vector space and both indices is thrown.
transposeT :: MonadError String m =>
              VSpace Label Dimension -> Ix Label -> Ix Label ->
              T v -> m (T v)
transposeT v ia ib (T (t :: Tensor r v)) =
  let rank = sing :: Sing r
      failureMsg =
        concat ["Cannot transpose indices ", show v, " ", show ia, " ", show ib, "!"]
   in withSingI rank $
        withSomeSing v $ \sv ->
          withSomeSing ia $ \sia ->
            withSomeSing ib $ \sib ->
              case sCanTranspose sv sia sib rank of
                SFalse -> throwError failureMsg
                STrue -> return $ T $ transpose sv sia sib t
-- | Applies 'transposeMult' to the wrapped tensor.
--
-- Exactly one of the two pair lists must be non-empty:
--
-- * a non-empty first list builds a 'TransCon' rule,
-- * a non-empty second list builds a 'TransCov' rule.
--
-- The pairs are sorted before they are split into the two component lists
-- of the rule. Two empty lists, or two non-empty lists, are rejected with
-- an error, as is a rule for which 'sTranspositions' yields no result for
-- the tensor's rank.
transposeMultT :: MonadError String m =>
                  VSpace Label Dimension -> [(Label,Label)] -> [(Label,Label)] -> T v -> m (T v)
transposeMultT v cons covs o =
  case rule of
    Left err -> throwError err
    Right tr ->
      -- Single application path shared by both kinds of rule.
      case o of
        T (t :: Tensor r v) ->
          let sr = sing :: Sing r
           in withSingI sr $
                withSomeSing v $ \sv ->
                  withSomeSing tr $ \str ->
                    case sIsJust (sTranspositions sv str sr) %~ STrue of
                      Proved Refl -> return $ T $ transposeMult sv str t
                      Disproved _ ->
                        throwError $ "Cannot transpose indices " ++ show v ++ " " ++ show tr ++ "!"
  where
    -- Selects the transposition rule, or the reason why none can be built.
    rule =
      case (cons, covs) of
        ([], []) -> Left "Empty lists for transpositions!"
        (c : cs, []) -> Right (mkRule TransCon (c :| cs))
        ([], c : cs) -> Right (mkRule TransCov (c :| cs))
        _ -> Left "Simultaneous transposition of contravariant and covariant indices not yet supported!"

    -- Sorts the pairs, then feeds sources and targets to the rule constructor.
    mkRule ctor pairs =
      let sorted = sort pairs
       in ctor (fmap fst sorted) (fmap snd sorted)
-- | Applies 'relabel' to the wrapped tensor for the given vector space and
-- list of label pairs.
--
-- An empty list is rejected. Otherwise the pairs are sorted, the new rank
-- is computed with 'sRelabelR', and that rank is checked with 'sSane'. If
-- either step fails, an error naming the vector space and the sorted pairs
-- is thrown.
relabelT :: MonadError String m =>
            VSpace Label Dimension -> [(Label,Label)] -> T v -> m (T v)
relabelT _ [] _ = throwError "Empty list for relabelling!"
relabelT v (p : ps) (T (t :: Tensor r v)) =
  let rank = sing :: Sing r
      pairs = sort (p :| ps)
      failureMsg = concat ["Cannot relabel indices ", show v, " ", show pairs, "!"]
   in withSingI rank $
        withSomeSing v $ \sv ->
          withSomeSing pairs $ \sPairs ->
            case sRelabelR sv sPairs rank of
              SJust rank' ->
                withSingI rank' $
                  case sSane rank' %~ STrue of
                    Disproved _ -> throwError failureMsg
                    Proved Refl -> return $ T $ relabel sv sPairs t
              _ -> throwError failureMsg
-- | Returns the generalized rank of the wrapped tensor as a term-level
-- value, obtained by demoting the rank singleton with 'fromSing'.
rankT :: T v -> RankT
rankT (T (_ :: Tensor r v)) = fromSing (sing :: Sing r)
-- | Lists the entries of the wrapped tensor as pairs of an index list and
-- a value. The index vectors produced by 'toList' are converted to plain
-- lists with 'vecToList'.
toListT :: T v -> [([Int], v)]
toListT (T (t :: Tensor r v)) =
  -- The singleton from 'sLengthR' is made available as a 'SingI' instance.
  withSingI (sLengthR (sing :: Sing r)) $
    map (first vecToList) (toList t)
-- | Builds a tensor of the given term-level rank from a list of pairs of
-- an index list and a value.
--
-- The rank must pass the 'sSane' check. Every index list is converted with
-- 'vecFromList' to a vector whose length is given by 'sLengthR', so an
-- index list of the wrong length results in an error.
fromListT :: MonadError String m => RankT -> [([Int], v)] -> m (T v)
fromListT r xs =
  withSomeSing r $ \sr ->
    withSingI sr $
      case sSane sr %~ STrue of
        Disproved _ -> throwError $ "Insane tensor rank : " <> show r
        Proved Refl ->
          let len = sLengthR sr
           in fmap (T . fromList' sr) $
                mapM (\(ix, val) ->
                        vecFromList len ix >>= \ix' -> return (ix', val))
                     xs
-- | Returns the rank unchanged if it satisfies 'sane'; otherwise throws an
-- error.
saneRank :: (Ord s, Ord n, MonadError String m) => GRank s n -> m (GRank s n)
saneRank r =
  if sane r
    then pure r
    else throwError "Index lists must be strictly ascending."
-- | Builds a single-vector-space rank whose index list is wrapped in 'Con',
-- from a vector-space label, a dimension and a list of index labels.
--
-- An empty index list is rejected; the resulting rank is validated with
-- 'saneRank'.
conRank :: (MonadError String m, Integral a, Ord s, Ord n, Num n) =>
           s -> a -> [s] -> m (GRank s n)
conRank v d ixs =
  case ixs of
    [] -> throwError "Generalized rank must have non-vanishing index list!"
    i : is -> saneRank [(VSpace v (fromIntegral d), Con (i :| is))]
-- | Builds a single-vector-space rank whose index list is wrapped in 'Cov',
-- from a vector-space label, a dimension and a list of index labels.
--
-- An empty index list is rejected; the resulting rank is validated with
-- 'saneRank'.
covRank :: (MonadError String m, Integral a, Ord s, Ord n, Num n) =>
           s -> a -> [s] -> m (GRank s n)
covRank v d ixs =
  case ixs of
    [] -> throwError "Generalized rank must have non-vanishing index list!"
    i : is -> saneRank [(VSpace v (fromIntegral d), Cov (i :| is))]
-- | Builds a single-vector-space rank wrapped in 'ConCov' from a
-- vector-space label, a dimension and two lists of index labels.
--
-- Both index lists must be non-empty; the resulting rank is validated with
-- 'saneRank'.
conCovRank :: (MonadError String m, Integral a, Ord s, Ord n, Num n) =>
              s -> a -> [s] -> [s] -> m (GRank s n)
conCovRank v d cons covs =
  -- The second index list is inspected first, then the first one.
  case (covs, cons) of
    ([], _) -> throwError "Generalized rank must have non-vanishing index list!"
    (_, []) -> throwError "Generalized rank must have non-vanishing index list!"
    (j : js, i : is) ->
      saneRank [(VSpace v (fromIntegral d), ConCov (i :| is) (j :| js))]