{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
#if MIN_VERSION_base(4,14,0)
{-# LANGUAGE StandaloneKindSignatures #-}
#endif
module Math.Tensor.Safe.TH where
import Data.Singletons.Prelude
import Data.Singletons.Prelude.Enum
import Data.Singletons.Prelude.List.NonEmpty hiding (sLength)
import Data.Singletons.Prelude.Ord
import Data.Singletons.TH
import Data.Singletons.TypeLits
import Data.List.NonEmpty (NonEmpty((:|)),sort,sortBy,(<|))
$(singletons [d|
data N where
Z :: N
S :: N -> N
fromNat :: Nat -> N
fromNat n = case n == 0 of
True -> Z
False -> S $ fromNat (pred n)
deriving instance Show N
deriving instance Eq N
instance Ord N where
Z <= _ = True
(S _) <= Z = False
(S n) <= (S m) = n <= m
instance Num N where
Z + n = n
(S n) + m = S $ n + m
n - Z = n
Z - S _ = error "cannot subtract (S n) from Z!"
(S n) - (S m) = n - m
negate Z = Z
negate _ = error "cannot negate (S n)!"
Z * _ = Z
(S n) * m = m + n * m
abs n = n
signum n = n
fromInteger n
| n == 0 = Z
| otherwise = S $ fromInteger (n-1)
data VSpace a b = VSpace { vId :: a,
vDim :: b }
deriving (Show, Ord, Eq)
data Ix a = ICon a | ICov a deriving (Show, Ord, Eq)
ixCompare :: Ord a => Ix a -> Ix a -> Ordering
ixCompare (ICon a) (ICon b) = compare a b
ixCompare (ICon a) (ICov b) = case compare a b of
LT -> LT
EQ -> LT
GT -> GT
ixCompare (ICov a) (ICon b) = case compare a b of
LT -> LT
EQ -> GT
GT -> GT
ixCompare (ICov a) (ICov b) = compare a b
data IList a = ConCov (NonEmpty a) (NonEmpty a) |
Cov (NonEmpty a) |
Con (NonEmpty a)
deriving (Show, Ord, Eq)
type GRank s n = [(VSpace s n, IList s)]
type Rank = GRank Symbol Nat
isAscending :: Ord a => [a] -> Bool
isAscending [] = True
isAscending [_] = True
isAscending (x:y:xs) = x < y && isAscending (y:xs)
isAscendingNE :: Ord a => NonEmpty a -> Bool
isAscendingNE (x :| xs) = isAscending (x:xs)
isAscendingI :: Ord a => IList a -> Bool
isAscendingI (ConCov x y) = isAscendingNE x && isAscendingNE y
isAscendingI (Con x) = isAscendingNE x
isAscendingI (Cov y) = isAscendingNE y
isLengthNE :: NonEmpty a -> Nat -> Bool
isLengthNE (_ :| []) l = l == 1
isLengthNE (_ :| (x:xs)) l = isLengthNE (x :| xs) (pred l)
lengthNE :: NonEmpty a -> N
lengthNE (_ :| []) = S Z
lengthNE (_ :| (x:xs)) = S Z + lengthNE (x :| xs)
lengthIL :: IList a -> N
lengthIL (ConCov xs ys) = lengthNE xs + lengthNE ys
lengthIL (Con xs) = lengthNE xs
lengthIL (Cov ys) = lengthNE ys
lengthR :: GRank s n -> N
lengthR [] = Z
lengthR ((_,x):xs) = lengthIL x + lengthR xs
sane :: (Ord a, Ord b) => [(VSpace a b, IList a)] -> Bool
sane [] = True
sane [(_, is)] = isAscendingI is
sane ((v, is):(v', is'):xs) =
v < v' && isAscendingI is && sane ((v',is'):xs)
headR :: Ord s => GRank s n -> (VSpace s n, Ix s)
headR ((v, l):_) = (v, case l of
ConCov (a :| _) (b :| _) ->
case compare a b of
LT -> ICon a
EQ -> ICon a
GT -> ICov b
Con (a :| _) -> ICon a
Cov (a :| _) -> ICov a)
headR [] = error "headR of empty list"
tailR :: Ord s => GRank s n -> GRank s n
tailR ((v, l):ls) =
let l' = case l of
ConCov (a :| []) (b :| []) ->
case compare a b of
LT -> Just $ Cov (b :| [])
EQ -> Just $ Cov (b :| [])
GT -> Just $ Con (a :| [])
ConCov (a :| (a':as)) (b :| []) ->
case compare a b of
LT -> Just $ ConCov (a' :| as) (b :| [])
EQ -> Just $ ConCov (a' :| as) (b :| [])
GT -> Just $ Con (a :| (a':as))
ConCov (a :| []) (b :| (b':bs)) ->
case compare a b of
LT -> Just $ Cov (b :| (b':bs))
EQ -> Just $ Cov (b :| (b':bs))
GT -> Just $ ConCov (a :| []) (b' :| bs)
ConCov (a :| (a':as)) (b :| (b':bs)) ->
case compare a b of
LT -> Just $ ConCov (a' :| as) (b :| (b':bs))
EQ -> Just $ ConCov (a' :| as) (b :| (b':bs))
GT -> Just $ ConCov (a :| (a':as)) (b' :| bs)
Con (_ :| []) -> Nothing
Con (_ :| (a':as)) -> Just $ Con (a' :| as)
Cov (_ :| []) -> Nothing
Cov (_ :| (a':as)) -> Just $ Cov (a' :| as)
in case l' of
Just l'' -> (v, l''):ls
Nothing -> ls
tailR [] = error "tailR of empty list"
mergeR :: (Ord s, Ord n) => GRank s n -> GRank s n -> Maybe (GRank s n)
mergeR [] ys = Just ys
mergeR xs [] = Just xs
mergeR ((xv,xl):xs) ((yv,yl):ys) =
case compare xv yv of
LT -> ((xv,xl) :) <$> mergeR xs ((yv,yl):ys)
EQ -> do
xl' <- mergeIL xl yl
xs' <- mergeR xs ys
Just $ (xv, xl') : xs'
GT -> ((yv,yl) :) <$> mergeR ((xv,xl):xs) ys
mergeIL :: Ord a => IList a -> IList a -> Maybe (IList a)
mergeIL (ConCov xs ys) (ConCov xs' ys') = do
xs'' <- mergeNE xs xs'
ys'' <- mergeNE ys ys'
Just $ ConCov xs'' ys''
mergeIL (ConCov xs ys) (Con xs') = do
xs'' <- mergeNE xs xs'
Just $ ConCov xs'' ys
mergeIL (ConCov xs ys) (Cov ys') = do
ys'' <- mergeNE ys ys'
Just $ ConCov xs ys''
mergeIL (Con xs) (ConCov xs' ys) = do
xs'' <- mergeNE xs xs'
Just $ ConCov xs'' ys
mergeIL (Con xs) (Con xs') = Con <$> mergeNE xs xs'
mergeIL (Con xs) (Cov ys) = Just $ ConCov xs ys
mergeIL (Cov ys) (ConCov xs ys') = do
ys'' <- mergeNE ys ys'
Just $ ConCov xs ys''
mergeIL (Cov ys) (Con xs) = Just $ ConCov xs ys
mergeIL (Cov ys) (Cov ys') = Cov <$> mergeNE ys ys'
merge :: Ord a => [a] -> [a] -> Maybe [a]
merge [] ys = Just ys
merge xs [] = Just xs
merge (x:xs) (y:ys) =
case compare x y of
LT -> (x :) <$> merge xs (y:ys)
EQ -> Nothing
GT -> (y :) <$> merge (x:xs) ys
mergeNE :: Ord a => NonEmpty a -> NonEmpty a -> Maybe (NonEmpty a)
mergeNE (x :| xs) (y :| ys) =
case compare x y of
LT -> (x :|) <$> merge xs (y:ys)
EQ -> Nothing
GT -> (y :|) <$> merge (x:xs) ys
contractR :: Ord s => GRank s n -> GRank s n
contractR [] = []
contractR ((v, is):xs) = case contractI is of
Nothing -> contractR xs
Just is' -> (v, is') : contractR xs
prepICon :: a -> IList a -> IList a
prepICon a (ConCov (x:|xs) y) = ConCov (a:|(x:xs)) y
prepICon a (Con (x:|xs)) = Con (a:|(x:xs))
prepICon a (Cov y) = ConCov (a:|[]) y
prepICov :: a -> IList a -> IList a
prepICov a (ConCov x (y:|ys)) = ConCov x (a:|(y:ys))
prepICov a (Con x) = ConCov x (a:|[])
prepICov a (Cov (y:|ys)) = Cov (a:|(y:ys))
contractI :: Ord a => IList a -> Maybe (IList a)
contractI (ConCov (x:|xs) (y:|ys)) =
case compare x y of
EQ -> case xs of
[] -> case ys of
[] -> Nothing
(y':ys') -> Just $ Cov (y' :| ys')
(x':xs') -> case ys of
[] -> Just $ Con (x' :| xs')
(y':ys') -> contractI $ ConCov (x':|xs') (y':|ys')
LT -> case xs of
[] -> Just $ ConCov (x:|xs) (y:|ys)
(x':xs') -> case contractI $ ConCov (x':|xs') (y:|ys) of
Nothing -> Just $ Con (x:|[])
Just i -> Just $ prepICon x i
GT -> case ys of
[] -> Just $ ConCov (x:|xs) (y:|ys)
(y':ys') -> case contractI $ ConCov (x:|xs) (y':|ys') of
Nothing -> Just $ Cov (y:|[])
Just i -> Just $ prepICov y i
contractI (Con x) = Just $ Con x
contractI (Cov x) = Just $ Cov x
subsetNE :: Ord a => NonEmpty a -> NonEmpty a -> Bool
subsetNE (x :| []) ys = x `elemNE` ys
subsetNE (x :| (x':xs)) ys = (x `elemNE` ys) && ((x' :| xs) `subsetNE` ys)
elemNE :: Ord a => a -> NonEmpty a -> Bool
elemNE a (x :| []) = a == x
elemNE a (x :| (x':xs)) = case compare a x of
LT -> False
EQ -> True
GT -> elemNE a (x' :| xs)
canTransposeCon :: (Ord s, Ord n) => VSpace s n -> s -> s -> GRank s n -> Bool
canTransposeCon _ _ _ [] = False
canTransposeCon v a b ((v',il):r) =
case compare v v' of
LT -> False
GT -> canTransposeCon v a b r
EQ -> case il of
Cov _ -> canTransposeCon v a b r
Con cs -> case elemNE a cs of
True -> case elemNE b cs of
True -> True
False -> False
False -> case elemNE b cs of
True -> False
False -> canTransposeCon v a b r
ConCov cs _ -> case elemNE a cs of
True -> case elemNE b cs of
True -> True
False -> False
False -> case elemNE b cs of
True -> False
False -> canTransposeCon v a b r
canTransposeCov :: (Ord s, Ord n) => VSpace s n -> s -> s -> GRank s n -> Bool
canTransposeCov _ _ _ [] = False
canTransposeCov v a b ((v',il):r) =
case compare v v' of
LT -> False
GT -> canTransposeCov v a b r
EQ -> case il of
Con _ -> canTransposeCov v a b r
Cov cs -> case elemNE a cs of
True -> case elemNE b cs of
True -> True
False -> False
False -> case elemNE b cs of
True -> False
False -> canTransposeCov v a b r
ConCov _ cs -> case elemNE a cs of
True -> case elemNE b cs of
True -> True
False -> False
False -> case elemNE b cs of
True -> False
False -> canTransposeCov v a b r
canTranspose :: (Ord s, Ord n) => VSpace s n -> Ix s -> Ix s -> GRank s n -> Bool
canTranspose v (ICon a) (ICon b) r = case compare a b of
LT -> canTransposeCon v a b r
EQ -> True
GT -> canTransposeCov v b a r
canTranspose v (ICov a) (ICov b) r = case compare a b of
LT -> canTransposeCov v a b r
EQ -> True
GT -> canTransposeCov v b a r
canTranspose _ (ICov _) (ICon _) _ = False
canTranspose _ (ICon _) (ICov _) _ = False
removeUntil :: Ord s => Ix s -> GRank s n -> GRank s n
removeUntil i r = go i r
where
go i' r'
| snd (headR r') == i' = tailR r'
| otherwise = go i $ tailR r'
data TransRule a = TransCon (NonEmpty a) (NonEmpty a) |
TransCov (NonEmpty a) (NonEmpty a)
deriving (Show, Eq)
saneTransRule :: Ord a => TransRule a -> Bool
saneTransRule tl =
case tl of
TransCon sources targets -> isAscendingNE sources &&
sort targets == sources
TransCov sources targets -> isAscendingNE sources &&
sort targets == sources
canTransposeMult :: (Ord s, Ord n) => VSpace s n -> TransRule s -> GRank s n -> Bool
canTransposeMult vs tl r = case transpositions vs tl r of
Just _ -> True
Nothing -> False
transpositions :: (Ord s, Ord n) => VSpace s n -> TransRule s -> GRank s n -> Maybe [(N,N)]
transpositions _ _ [] = Nothing
transpositions vs tl ((vs',il):r) =
case saneTransRule tl of
False -> Nothing
True ->
case compare vs vs' of
LT -> Nothing
GT -> transpositions vs tl r
EQ ->
case il of
Con xs ->
case tl of
TransCon sources targets -> transpositions' sources targets (fmap Just xs)
TransCov _ _ -> Nothing
Cov xs ->
case tl of
TransCov sources targets -> transpositions' sources targets (fmap Just xs)
TransCon _ _ -> Nothing
ConCov xsCon xsCov ->
case tl of
TransCon sources targets -> transpositions' sources targets (xsCon `zipCon` xsCov)
TransCov sources targets -> transpositions' sources targets (xsCon `zipCov` xsCov)
zipCon :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty (Maybe a)
zipCon (x :| xs) (y :| ys) =
case ICon x `ixCompare` ICov y of
LT -> case xs of
[] -> Just x :| fmap (const Nothing) (y:ys)
(x':xs') -> Just x <| zipCon (x' :| xs') (y :| ys)
GT -> case ys of
[] -> Nothing :| fmap Just (x : xs)
(y':ys') -> Nothing <| zipCon (x :| xs) (y' :| ys')
zipCov :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty (Maybe a)
zipCov (x :| xs) (y :| ys) =
case ICon x `ixCompare` ICov y of
LT -> case xs of
[] -> Nothing :| fmap Just (y:ys)
(x':xs') -> Nothing <| zipCov (x' :| xs') (y :| ys)
GT -> case ys of
[] -> Just y :| fmap (const Nothing) (x : xs)
(y':ys') -> Just y <| zipCov (x :| xs) (y' :| ys')
transpositions' :: Eq a => NonEmpty a -> NonEmpty a -> NonEmpty (Maybe a) -> Maybe [(N,N)]
transpositions' sources targets xs =
do
ss <- mapM (`find` xs') sources
ts <- mapM (`find` xs') targets
zip' ss ts
where
xs' = go' Z xs
go' :: N -> NonEmpty a -> NonEmpty (N,a)
go' n (y :| []) = (n,y) :| []
go' n (y :| (y':ys')) = (n,y) <| go' (S n) (y' :| ys')
find :: Eq a => a -> NonEmpty (N, Maybe a) -> Maybe N
find _ ((_,Nothing) :| []) = Nothing
find a ((_,Nothing) :| (y':ys')) = find a (y' :| ys')
find a ((n,Just y) :| ys) =
case a == y of
True -> Just n
False ->
case ys of
[] -> Nothing
y':ys' -> find a (y' :| ys')
zip' :: NonEmpty a -> NonEmpty b -> Maybe [(a,b)]
zip' (a :| []) (b :| []) = Just [(a,b)]
zip' (_ :| (_:_)) (_ :| []) = Nothing
zip' (_ :| []) (_ :| (_:_)) = Nothing
zip' (y:|(y':ys')) (z:|(z':zs')) = ((y,z):) <$> zip' (y':|ys') (z':|zs')
type RelabelRule s = NonEmpty (s,s)
saneRelabelRule :: Ord a => NonEmpty (a,a) -> Bool
saneRelabelRule xs = isAscendingNE xs &&
isAscendingNE xs'
where
xs' = sort $ fmap (\(a,b) -> (b,a)) xs
relabelNE :: Ord a => NonEmpty (a,a) -> NonEmpty a -> Maybe (NonEmpty (a,a))
relabelNE = go
where
go :: Ord a => NonEmpty (a,a) -> NonEmpty a -> Maybe (NonEmpty (a,a))
go ((source,target) :| ms) (x :| xs) =
case source `compare` x of
LT ->
case ms of
[] -> Just $ (\a -> (a,a)) <$> (x :| xs)
m':ms' -> go (m' :| ms') (x :| xs)
EQ ->
case ms of
[] -> Just $ ((target,source) :|) $ fmap (\a -> (a,a)) xs
m':ms' ->
case xs of
[] -> Just $ (target,source) :| []
x':xs' -> ((target,source) <|) <$> go (m' :| ms') (x' :| xs')
GT ->
case xs of
[] -> Just $ (x,x) :| []
x':xs' -> ((x,x) <|) <$> go ((source,target) :| ms) (x' :| xs')
relabelR :: (Ord s, Ord n) => VSpace s n -> RelabelRule s -> GRank s n -> Maybe (GRank s n)
relabelR _ _ [] = Nothing
relabelR vs rls ((vs',il):r) =
case vs `compare` vs' of
LT -> Nothing
EQ -> (\il' -> (vs',il'):r) <$> relabelIL rls il
GT -> ((vs',il) :) <$> relabelR vs rls r
relabelIL :: Ord a => NonEmpty (a,a) -> IList a -> Maybe (IList a)
relabelIL rl is = case relabelIL' rl is of
Nothing -> Nothing
Just (Con is') -> Just $ Con $ fst <$> is'
Just (Cov is') -> Just $ Cov $ fst <$> is'
Just (ConCov is' js') -> Just $ ConCov (fst <$> is') (fst <$> js')
relabelIL' :: Ord a => NonEmpty (a,a) -> IList a -> Maybe (IList (a,a))
relabelIL' rl (Con is) =
do
is' <- Con . sort <$> relabelNE rl is
case isAscendingI is' of
True -> return is'
False -> Nothing
relabelIL' rl (Cov is) =
do
is' <- Cov . sort <$> relabelNE rl is
case isAscendingI is' of
True -> return is'
False -> Nothing
relabelIL' rl (ConCov is js) =
do
is' <- sort <$> relabelNE rl is
js' <- sort <$> relabelNE rl js
let l' = ConCov is' js'
case isAscendingI l' of
True -> return l'
False -> Nothing
relabelTranspositions :: Ord a => NonEmpty (a,a) -> IList a -> Maybe [(N,N)]
relabelTranspositions rl is =
case relabelIL' rl is of
Nothing -> Nothing
Just (Con is') -> Just $ relabelTranspositions' is'
Just (Cov is') -> Just $ relabelTranspositions' is'
Just (ConCov is' js') -> Just $ relabelTranspositions' $ is' `zipConCov` js'
zipConCov :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty a
zipConCov = go
where
go :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty a
go (i :| is) (j :| js) =
case i `compare` j of
LT -> case is of
[] -> i <| (j:|js)
i':is' -> i <| go (i' :| is') (j :| js)
EQ -> case is of
[] -> i <| (j:|js)
i':is' -> i <| go (i' :| is') (j :| js)
GT -> case js of
[] -> j <| (i:|is)
j':js' -> j <| go (i :| is) (j' :| js')
relabelTranspositions' :: Ord a => NonEmpty (a,a) -> [(N,N)]
relabelTranspositions' is = go'' is'''
where
is' = go Z is
is'' = sortBy (\a b -> snd a `compare` snd b) is'
is''' = go' Z is''
go :: N -> NonEmpty (a,b) -> NonEmpty (N,b)
go n ((_,y) :| []) = (n,y) :| []
go n ((_,y) :| (j:js)) = (n,y) <| go (S n) (j :| js)
go' :: N -> NonEmpty (a,b) -> NonEmpty (a,N)
go' n ((x,_) :| []) = (x,n) :| []
go' n ((x,_) :| (j:js)) = (x,n) <| go' (S n) (j :| js)
go'' :: NonEmpty (a,a) -> [(a,a)]
go'' ((x1,x2) :| []) = [(x2,x1)]
go'' ((x1,x2) :| (y:ys)) = (x2,x1) : go'' (y :| ys)
|])
toInt :: N -> Int
toInt Z = 0
toInt (S n) = 1 + toInt n