{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE InstanceSigs #-}
module Data.Tensor.Static (
IsTensor(..)
, Tensor(..)
, TensorConstructor
, PositiveDims
, fill
, zero
, enumFromN
, EnumFromN
, enumFromStepN
, EnumFromStepN
, generate
, Generate
, dimensions
, elemsNumber
, subtensorsElemsNumbers
, ElemsNumber
, SubtensorsElemsNumbers
, FlattenIndex
, AllIndexes
, NatsFromTo
, NormalizeDims
, withTensor
, add
, Add
, diff
, Diff
, scale
, Scale
, cons
, Cons
, ConsSubtensorDims
, DimsAfterCons
, snoc
, Snoc
, SnocSubtensorDims
, DimsAfterSnoc
, append
, Append
, DimsAfterAppend
, remove
, Remove
, DimsAfterRemove
, NestedList
, toNestedList
, ToNestedList
, tensorElem
, TensorElem
, Subtensor
, SubtensorStartIndex
, SubtensorDims
, subtensor
, SubtensorCtx
, getSubtensor
, GetSubtensor
, getSubtensorElems
, GetSubtensorElems
, setSubtensor
, SetSubtensor
, setSubtensorElems
, SetSubtensorElems
, mapSubtensorElems
, MapSubtensorElems
, SliceEndIndex
, ElemsInSlice
, slice
, Slice
, getSlice
, GetSlice
, getSliceElems
, GetSliceElems
, setSlice
, SetSlice
, setSliceElems
, SetSliceElems
, mapSliceElems
, MapSliceElems
, MonoFunctorCtx
, MonoFoldableCtx
, MonoTraversableCtx
, MonoZipCtx
, unsafeWithTensorPtr
) where
import Control.Lens (Lens', lens, Each(..), traversed)
import Data.Containers (MonoZip(..))
import Data.Function.NAry (NAry, ApplyNAry(..))
import Data.Kind (Type, Constraint)
import Data.List (intersperse)
import Data.List.Split (chunksOf)
import Data.MonoTraversable (MonoFunctor(..), MonoFoldable(..), MonoTraversable(..), Element)
import Data.Proxy (Proxy(..))
import Data.Singletons.Prelude.List (Tail, Product, Length)
import Data.Type.Equality (type (==))
import Data.Type.Bool (If, type (&&))
import Foreign.Storable (Storable(..))
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Marshal.Utils (with)
import GHC.TypeLits (Nat, KnownNat, natVal, type (+), type (-), type (<=?), type (*), TypeError, ErrorMessage(..))
import Type.List (MkCtx, DemoteWith(..), KnownNats(..))
import qualified Data.List.Unrolled as U
-- | Reflect a type-level natural number as an 'Int'.
--
-- Meant to be used with a type application, e.g. @natVal' \@3@.
natVal' :: forall (n :: Nat). (KnownNat n) => Int
natVal' = fromIntegral (natVal (Proxy :: Proxy n))
{-# INLINE natVal' #-}
-- | Constraint satisfied only when every dimension in @dims@ is at least 1;
-- a zero dimension is reported as a custom type error.
type family PositiveDims (dims :: [Nat]) :: Constraint where
  PositiveDims '[] = ()
  PositiveDims (dim ': rest) = PositiveDims' (1 <=? dim) rest

-- | Helper for 'PositiveDims': the 'Bool' says whether the dimension just
-- inspected was positive; if so, the remaining dimensions are checked.
type family PositiveDims' (b :: Bool) (dims :: [Nat]) :: Constraint where
  PositiveDims' 'False _ = TypeError ('Text "Tensor must have positive dimensions.")
  PositiveDims' 'True rest = PositiveDims rest
-- | Position of @index@ in the flat, row-major element list of a tensor of
-- shape @dims@: the dot product of the index with the per-axis strides.
type FlattenIndex (index :: [Nat]) (dims :: [Nat]) = FlattenIndex' index (SubtensorsElemsNumbers dims)

-- | Dot product of an index with a stride list; both must have equal length.
type family FlattenIndex' (index :: [Nat]) (elemNumbers :: [Nat]) :: Nat where
  FlattenIndex' '[] '[] = 0
  FlattenIndex' (_ ': _) '[] = TypeError ('Text "FlattenIndex: Too many dimensions in the index for subtensor.")
  FlattenIndex' '[] (_ ': _) = TypeError ('Text "FlattenIndex: Not enough dimensions in the index for subtensor.")
  FlattenIndex' (i ': is) (stride ': strides) = i * stride + FlattenIndex' is strides
-- | Drop every dimension equal to 1 from a shape,
-- e.g. @'[2, 1, 3]@ becomes @'[2, 3]@ and @'[1, 1]@ becomes @'[]@.
type family NormalizeDims (dims :: [Nat]) :: [Nat] where
  NormalizeDims '[] = '[]
  -- Must precede the general case: closed-family equations match in order.
  NormalizeDims (1 ': rest) = NormalizeDims rest
  NormalizeDims (dim ': rest) = dim ': NormalizeDims rest
-- | Every valid index of a tensor of shape @dims@, ordered so that the last
-- axis varies fastest (row-major order).
type AllIndexes (dims :: [Nat]) = Sequence (IndexesRanges dims)

-- | Cartesian product of a list of lists, analogous to 'sequence' on lists:
-- @Sequence '[ '[0, 1], '[0, 1, 2] ]@ yields all six index pairs.
type family Sequence (xss :: [[k]]) :: [[k]] where
  Sequence '[] = '[ '[] ]
  Sequence (choices ': rest) = Sequence' choices (Sequence rest)

-- | Prefix every list in @yss@ with each element of @xs@ in turn.
type family Sequence' (xs :: [k]) (yss :: [[k]]) :: [[k]] where
  Sequence' '[] _ = '[]
  Sequence' (h ': hs) tails = Sequence'' h tails ++ Sequence' hs tails

-- | Prefix every list in @yss@ with @x@.
type family Sequence'' (x :: k) (yss :: [[k]]) :: [[k]] where
  Sequence'' _ '[] = '[]
  Sequence'' h (t ': ts) = (h ': t) ': Sequence'' h ts

-- Right-associative, like the term-level list append.
infixr 5 ++

-- | Type-level list concatenation.
type family (++) (xs :: [k]) (ys :: [k]) :: [k] where
  '[] ++ rest = rest
  (h ': t) ++ rest = h ': (t ++ rest)
-- | For every dimension @d@ of the shape, the list of its valid indexes
-- @[0 .. d - 1]@. A dimension smaller than 1 is a type error.
type family IndexesRanges (dims :: [Nat]) :: [[Nat]] where
  IndexesRanges '[] = '[]
  IndexesRanges (dim ': rest) = IndexesRanges' (dim ': rest) (1 <=? dim)

-- | Helper for 'IndexesRanges'; the 'Bool' says whether the head dimension is positive.
type family IndexesRanges' (dims :: [Nat]) (dimPositive :: Bool) :: [[Nat]] where
  IndexesRanges' (dim ': rest) 'True = NatsFromTo 0 (dim - 1) ': IndexesRanges rest
  IndexesRanges' (dim ': _) 'False =
    TypeError ('Text "IndexesRanges: Tensor has non-positive dimension: " ':<>: 'ShowType dim)

-- | Inclusive range of naturals @[from .. to]@; empty when @from > to@.
type NatsFromTo (from :: Nat) (to :: Nat) = NatsFromTo' from to (from <=? to)

-- | Worker for 'NatsFromTo'; the 'Bool' caches @from <=? to@.
type family NatsFromTo' (from :: Nat) (to :: Nat) (fromLTEto :: Bool) :: [Nat] where
  NatsFromTo' _ _ 'False = '[]
  NatsFromTo' lo hi 'True = lo ': NatsFromTo' (lo + 1) hi (lo + 1 <=? hi)
-- | Tensors with a statically known shape @dims@ and element type @e@.
--
-- Each instance picks its own representation through the associated data
-- family 'Tensor'. Elements cross the class boundary as a flat list; the rest
-- of this module treats that list as row-major (last axis fastest).
class (PositiveDims dims, KnownNats dims) => IsTensor (dims :: [Nat]) e where
  {-# MINIMAL tensor, unsafeFromList, toList #-}
  -- | Concrete representation of a tensor of this shape.
  data Tensor dims e :: Type
  -- | Build a tensor from exactly @ElemsNumber dims@ arguments.
  tensor :: TensorConstructor dims e
  -- | Build a tensor from a flat list; may fail when the list is too short.
  unsafeFromList :: [e] -> Tensor dims e
  -- | All elements of the tensor as a flat list.
  toList :: Tensor dims e -> [e]
-- | Type of 'tensor', built with 'NAry' from "Data.Function.NAry" —
-- presumably a curried function of @ElemsNumber dims@ arguments of type @e@
-- returning the tensor.
type TensorConstructor (dims :: [Nat]) (e :: Type) = NAry (ElemsNumber dims) e (Tensor dims e)

-- | Total number of elements in a tensor of shape @dims@ (1 for a scalar).
type family ElemsNumber (dims :: [Nat]) :: Nat where
  ElemsNumber '[] = 1
  ElemsNumber (dim ': rest) = dim * ElemsNumber rest
-- | Scalars: rank-0 tensors holding exactly one element.
instance IsTensor '[] e where
  newtype Tensor '[] e = Scalar e

  tensor = Scalar
  {-# INLINE tensor #-}

  -- Only the first element is used; an empty list is an error.
  unsafeFromList elems = case elems of
    x : _ -> Scalar x
    [] -> error "Not enough elements to build a Tensor of shape []."
  {-# INLINE unsafeFromList #-}

  toList (Scalar x) = [x]
  {-# INLINE toList #-}
-- | Shows a tensor as @Tensor'@, its dimensions separated by apostrophes, a
-- space, and its elements as a nested list, e.g. @Tensor'2'3 [[1,2,3],[4,5,6]]@.
instance {-# OVERLAPPABLE #-}
    ( Show (NestedList (Length dims) e)
    , IsTensor dims e
    , ToNestedListWrk dims e
    , KnownNats dims
    ) =>
    Show (Tensor dims e)
  where
    show t = concat ["Tensor'", shape, " ", show (toNestedList t)]
      where
        shape = concat (intersperse "\'" (show <$> dimensions @dims))
    {-# INLINE show #-}

-- | Scalars are shown as @Scalar x@ instead of as a nested list.
instance {-# OVERLAPPING #-} (Show e) => Show (Tensor '[] e) where
  show (Scalar x) = showString "Scalar " (show x)
  {-# INLINE show #-}
-- | Pass the elements of a tensor, in 'toList' order, as the individual
-- arguments of an N-ary function (N = @ElemsNumber dims@) and return its result.
withTensor :: forall dims e r.
    ( IsTensor dims e
    , ApplyNAry (ElemsNumber dims) e r
    )
    => Tensor dims e
    -> (NAry (ElemsNumber dims) e r)
    -> r
withTensor t f = applyNAry @(ElemsNumber dims) @e @r f elems
  where
    elems = toList t
{-# INLINE withTensor #-}
-- | Dimensions of the shape @dims@ as 'Int's, obtained through 'natsVal'
-- from "Type.List".
dimensions :: forall (dims :: [Nat]). (KnownNats dims) => [Int]
dimensions = natsVal @dims
{-# INLINE dimensions #-}

-- | Total number of elements ('ElemsNumber') in a tensor of shape @dims@.
elemsNumber :: forall (dims :: [Nat]). (KnownNat (ElemsNumber dims)) => Int
elemsNumber = natVal' @(ElemsNumber dims)
{-# INLINE elemsNumber #-}

-- | Value-level 'SubtensorsElemsNumbers': for each axis, the number of
-- elements spanned by a single step along that axis (the row-major stride).
subtensorsElemsNumbers :: forall (dims :: [Nat]). (KnownNats (SubtensorsElemsNumbers dims)) => [Int]
subtensorsElemsNumbers = natsVal @(SubtensorsElemsNumbers dims)
{-# INLINE subtensorsElemsNumbers #-}
-- | Element-wise sum of two tensors of the same shape.
add :: (Add dims e) => Tensor dims e -> Tensor dims e -> Tensor dims e
add t1 t2 = ozipWith (+) t1 t2
{-# INLINE add #-}

-- | Constraints required by 'add': 'Num' plus the 'MonoZip' instance context.
type Add (dims :: [Nat]) e =
  ( IsTensor dims e
  , Num e
  , U.ZipWith (ElemsNumber dims)
  , U.Zip (ElemsNumber dims)
  , U.Unzip (ElemsNumber dims)
  , U.Map (ElemsNumber dims)
  )
-- | Element-wise difference @t1 - t2@ of two tensors of the same shape.
diff :: (Diff dims e) => Tensor dims e -> Tensor dims e -> Tensor dims e
diff t1 t2 = ozipWith (-) t1 t2
{-# INLINE diff #-}

-- | Constraints required by 'diff': 'Num' plus the 'MonoZip' instance context.
type Diff (dims :: [Nat]) e =
  ( IsTensor dims e
  , Num e
  , U.ZipWith (ElemsNumber dims)
  , U.Zip (ElemsNumber dims)
  , U.Unzip (ElemsNumber dims)
  , U.Map (ElemsNumber dims)
  )
-- | Multiply every element of the tensor by @k@ (on the right).
scale :: (Scale dims e) => Tensor dims e -> e -> Tensor dims e
scale t k = omap (\x -> x * k) t
{-# INLINE scale #-}

-- | Constraints required by 'scale'.
type Scale (dims :: [Nat]) e =
  ( IsTensor dims e
  , Num e
  , U.Map (ElemsNumber dims)
  )
-- | Tensor in which every element is the given value.
fill :: forall (dims :: [Nat]) e. (Fill dims e) => e -> Tensor dims e
fill x = unsafeFromList (U.replicate @(ElemsNumber dims) x)
{-# INLINE fill #-}

-- | Constraints required by 'fill'.
type Fill (dims :: [Nat]) e = (IsTensor dims e, U.Replicate (ElemsNumber dims))

-- | Tensor filled with zeros.
zero :: (Fill dims e, Num e) => Tensor dims e
zero = fill 0
{-# INLINE zero #-}
-- | Tensor whose elements, in 'toList' order, come from
-- 'U.enumFromN' starting at @x@ — presumably @x, x + 1, ...@
-- (see "Data.List.Unrolled").
enumFromN :: forall (dims :: [Nat]) e.
    (EnumFromN dims e)
    => e
    -> Tensor dims e
enumFromN x = unsafeFromList (U.enumFromN @(ElemsNumber dims) x)
{-# INLINE enumFromN #-}

-- | Constraints required by 'enumFromN'.
type EnumFromN (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.EnumFromN (ElemsNumber dims)
  , Num e
  )

-- | Tensor whose elements, in 'toList' order, come from
-- 'U.enumFromStepN' with start @x@ and step @step@ — presumably
-- @x, x + step, ...@ (see "Data.List.Unrolled").
enumFromStepN :: forall (dims :: [Nat]) e.
    (EnumFromStepN dims e)
    => e
    -> e
    -> Tensor dims e
enumFromStepN x step = unsafeFromList (U.enumFromStepN @(ElemsNumber dims) x step)
{-# INLINE enumFromStepN #-}

-- | Constraints required by 'enumFromStepN'.
type EnumFromStepN (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.EnumFromStepN (ElemsNumber dims)
  , Num e
  )
-- | Build a tensor from a function of the element index. The index is passed
-- as a type-level 'Proxy'; the indexes are those of 'AllIndexes' (row-major),
-- demoted through 'demoteWith'. @kctx@/@ctx@ choose the constraint made
-- available for each index via 'MkCtx'.
generate :: forall (dims :: [Nat]) (e :: Type) (kctx :: Type) (ctx :: kctx).
    (Generate dims e kctx ctx)
    => (forall (index :: [Nat]).
          (MkCtx [Nat] kctx ctx index)
          => Proxy index
          -> e
       )
    -> Tensor dims e
generate f = unsafeFromList elems
  where
    elems = demoteWith @[Nat] @kctx @ctx @(AllIndexes dims) f
{-# INLINE generate #-}

-- | Constraints required by 'generate'.
type Generate (dims :: [Nat]) (e :: Type) (kctx :: Type) (ctx :: kctx) =
  ( IsTensor dims e
  , DemoteWith [Nat] kctx ctx (AllIndexes dims)
  )
-- | @depth@ levels of list nesting around @e@:
-- @NestedList 0 e = e@, @NestedList 2 e = [[e]]@.
type family NestedList (depth :: Nat) (e :: Type) :: Type where
  NestedList 0 e = e
  NestedList depth e = [NestedList (depth - 1) e]
-- | Elements of the tensor as nested lists, one nesting level per dimension
-- (a scalar yields its bare element).
toNestedList :: forall dims e. (ToNestedList dims e)
    => Tensor dims e
    -> NestedList (Length dims) e
toNestedList t = toNestedListWrk @dims @e (toList t)
{-# INLINE toNestedList #-}

-- | Constraints required by 'toNestedList'.
type ToNestedList (dims :: [Nat]) e = (IsTensor dims e, ToNestedListWrk dims e)
-- | Worker for 'toNestedList': reshape a flat row-major element list into
-- nested lists following the shape @dims@.
class ToNestedListWrk (dims :: [Nat]) e where
  toNestedListWrk :: [e] -> NestedList (Length dims) e

-- Scalar: the single element itself.
instance ToNestedListWrk '[] e where
  toNestedListWrk elems = head elems

-- Vector: the flat list already has the right shape.
instance ToNestedListWrk '[x] e where
  toNestedListWrk elems = elems
  {-# INLINE toNestedListWrk #-}

-- Rank >= 2: cut the list into rows of @Product (xx ': xs)@ elements and
-- reshape each row recursively.
instance ( ToNestedListWrk (xx ': xs) e
         , KnownNat (Product (xx ': xs))
         , NestedList (Length (x ': xx ': xs)) e ~ [NestedList (Length (xx ': xs)) e]
         ) =>
         ToNestedListWrk (x ': xx ': xs) e
  where
    toNestedListWrk elems =
      [toNestedListWrk @(xx ': xs) row | row <- chunksOf rowSize elems]
      where
        rowSize = natVal' @(Product (xx ': xs))
    {-# INLINABLE toNestedListWrk #-}
-- | Type of the subtensor selected by @index@ from a tensor of shape @dims@:
-- indexed axes collapse to size 1 and every size-1 axis is then dropped.
type Subtensor index dims e = Tensor (NormalizeDims (SubtensorDims index dims)) e

-- | Row-major strides of @dims@: element @i@ is the number of elements in a
-- subtensor obtained by fixing the first @i + 1@ indexes.
type SubtensorsElemsNumbers (dims :: [Nat]) = Tail (SubtensorsElemsNumbers' dims)

-- | Suffix products of @dims@ ending in 1, e.g. @'[2, 3, 4]@ gives
-- @'[24, 12, 4, 1]@.
type family SubtensorsElemsNumbers' (dims :: [Nat]) :: [Nat] where
  SubtensorsElemsNumbers' '[] = '[1]
  SubtensorsElemsNumbers' (dim ': rest) = SubtensorsElemsNumbers'' dim (SubtensorsElemsNumbers' rest)

-- | Prepend @dim * q@, where @q@ is the current head of the accumulated products.
type family SubtensorsElemsNumbers'' (dim :: Nat) (dims :: [Nat]) :: [Nat] where
  SubtensorsElemsNumbers'' dim (q ': qs) = dim * q ': q ': qs
-- | Shape of the slice selected by a (possibly partial) index: each indexed
-- axis becomes 1, and the axes not covered by the index keep their size.
-- An index longer than @dims@, or a component outside @[0 .. d - 1]@, is a
-- type error.
type family SubtensorDims (index :: [Nat]) (dims :: [Nat]) :: [Nat] where
  SubtensorDims '[] ds = ds
  SubtensorDims (_ ': _ ) '[] = TypeError ('Text "SubtensorDims: Too many dimensions in the index for subtensor.")
  SubtensorDims (i ': is) (d ': ds) =
    -- Bounds check of this index component against its dimension.
    If (i <=? d - 1)
      (1 ': SubtensorDims is ds)
      (TypeError
        ('Text "SubtensorDims: Index "
        ':<>: 'ShowType i
        ':<>: 'Text " is outside of the range of dimension [0.."
        ':<>: 'ShowType (d - 1)
        ':<>: 'Text "]."))

-- | Starting index of the slice selected by a (possibly partial) index: the
-- given components, padded with 0 for every axis the index does not cover.
-- Uses the same length and bounds checks as 'SubtensorDims'.
type family SubtensorStartIndex (index :: [Nat]) (dims :: [Nat]) :: [Nat] where
  SubtensorStartIndex '[] '[] = '[]
  SubtensorStartIndex (i ': is) '[] = TypeError ('Text "SubtensorStartIndex: Too many dimensions in the index for subtensor.")
  SubtensorStartIndex '[] (d ': ds) = 0 ': SubtensorStartIndex '[] ds
  SubtensorStartIndex (i ': is) (d ': ds) =
    If (i <=? d - 1)
      (i ': SubtensorStartIndex is ds)
      (TypeError
        ('Text "SubtensorStartIndex: Index "
        ':<>: 'ShowType i
        ':<>: 'Text " is outside of the range of dimension [0.."
        ':<>: 'ShowType (d - 1)
        ':<>: 'Text "]."))
-- | Extract the subtensor selected by @index@ (see 'Subtensor' for its shape).
getSubtensor :: forall (index :: [Nat]) (dims :: [Nat]) e.
    (GetSubtensor index dims e)
    => Tensor dims e
    -> Subtensor index dims e
getSubtensor t = getSlice @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t
{-# INLINE getSubtensor #-}

-- | Constraints required by 'getSubtensor'.
type GetSubtensor index dims e =
  GetSlice (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e

-- | Elements of the subtensor selected by @index@, in 'toList' order.
getSubtensorElems :: forall (index :: [Nat]) (dims :: [Nat]) e.
    (GetSubtensorElems index dims e)
    => Tensor dims e
    -> [e]
getSubtensorElems t = getSliceElems @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t
{-# INLINE getSubtensorElems #-}

-- | Constraints required by 'getSubtensorElems'.
type GetSubtensorElems index dims e =
  GetSliceElems (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e
-- | Replace the subtensor selected by @index@ with the given one.
setSubtensor :: forall (index :: [Nat]) (dims :: [Nat]) e.
    (SetSubtensor index dims e)
    => Tensor dims e
    -> Subtensor index dims e
    -> Tensor dims e
setSubtensor t st = setSlice @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t st
{-# INLINE setSubtensor #-}

-- | Constraints required by 'setSubtensor'.
type SetSubtensor index dims e =
  SetSlice (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e

-- | Overwrite the subtensor selected by @index@ with elements taken from a
-- list. Gives 'Nothing' when the list is shorter than the subtensor; extra
-- elements are ignored.
setSubtensorElems :: forall (index :: [Nat]) (dims :: [Nat]) e.
    (SetSubtensorElems index dims e)
    => Tensor dims e
    -> [e]
    -> Maybe (Tensor dims e)
setSubtensorElems t xs = setSliceElems @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t xs
{-# INLINE setSubtensorElems #-}

-- | Constraints required by 'setSubtensorElems'.
type SetSubtensorElems index dims e =
  SetSliceElems (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e

-- | Apply a function to every element of the subtensor selected by @index@.
mapSubtensorElems :: forall (index :: [Nat]) (dims :: [Nat]) e.
    (MapSubtensorElems index dims e)
    => Tensor dims e
    -> (e -> e)
    -> Tensor dims e
mapSubtensorElems t f = mapSliceElems @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t f
{-# INLINE mapSubtensorElems #-}

-- | Constraints required by 'mapSubtensorElems'.
type MapSubtensorElems index dims e =
  MapSliceElems (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e
-- | Lens focusing on the subtensor selected by @index@.
subtensor :: forall (index :: [Nat]) (dims :: [Nat]) e.
    (SubtensorCtx index dims e)
    => Lens' (Tensor dims e) (Subtensor index dims e)
subtensor = lens getter setter
  where
    getter = getSubtensor @index @dims @e
    setter = setSubtensor @index @dims @e
{-# INLINE subtensor #-}

-- | Constraints required by 'subtensor'.
type SubtensorCtx index dims e =
  ( GetSubtensor index dims e
  , SetSubtensor index dims e)

-- | Lens focusing on a single element; @index@ must address every axis
-- whose size is greater than 1, so that the subtensor is a scalar.
tensorElem :: forall (index :: [Nat]) (dims :: [Nat]) e.
    (TensorElem index dims e)
    => Lens' (Tensor dims e) e
tensorElem = subtensor @index @dims @e . lens unScalar (const Scalar)
  where
    unScalar (Scalar x) = x
{-# INLINE tensorElem #-}

-- | Constraints required by 'tensorElem'.
type TensorElem index dims e =
  ( SubtensorCtx index dims e
  , NormalizeDims (SubtensorDims index dims) ~ '[]
  )
-- | Inclusive end index of a slice with the given start index and shape:
-- each component is @start + size - 1@. Mismatched list lengths, a slice
-- dimension smaller than 1, or a slice reaching past the tensor are type errors.
type family SliceEndIndex (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) :: [Nat] where
  SliceEndIndex '[] '[] '[] = '[]
  -- The next six equations report every combination of length mismatch.
  SliceEndIndex '[] '[] (d ': ds) = TypeError ('Text "SliceEndIndex: Slice and its starting index have not enough dimensions.")
  SliceEndIndex '[] (sd ': sds) '[] = TypeError ('Text "SliceEndIndex: Slice has too many dimensions.")
  SliceEndIndex '[] (sd ': sds) (d ': ds) = TypeError ('Text "SliceEndIndex: Starting index of the slice has not enough dimensions.")
  SliceEndIndex (si ': sis) '[] '[] = TypeError ('Text "SliceEndIndex: Starting index of the slice has too many dimensions.")
  SliceEndIndex (si ': sis) '[] (d ': ds) = TypeError ('Text "SliceEndIndex: Slice has not enough dimensions.")
  SliceEndIndex (si ': sis) (sd ': sds) '[] = TypeError ('Text "SliceEndIndex: Slice and its starting index have too many dimensions.")
  SliceEndIndex (si ': sis) (sd ': sds) (d ': ds) = SliceEndIndex' (si ': sis) (sd ': sds) (d ': ds) (1 <=? sd)

-- | Step 2: reject a slice dimension smaller than 1.
type family SliceEndIndex' (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) (sliceDimPositive :: Bool) :: [Nat] where
  SliceEndIndex' (si ': sis) (sd ': sds) (d ': ds) 'True = SliceEndIndex'' (si ': sis) (sd ': sds) (d ': ds) (si + sd <=? d)
  SliceEndIndex' _ (sd ': _) _ 'False =
    TypeError ('Text "SliceEndIndex: Slice has non-positive dimension: " ':<>: 'ShowType sd)

-- | Step 3: reject a slice that extends past the tensor on this axis, then
-- continue with the remaining axes.
type family SliceEndIndex'' (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) (sliceDimInside :: Bool) :: [Nat] where
  SliceEndIndex'' (si ': sis) (sd ': sds) (d ': ds) 'True = (si + sd - 1 ': SliceEndIndex sis sds ds)
  SliceEndIndex'' (si ': sis) (sd ': sds) (d ': ds) 'False =
    (TypeError
      ( 'Text "SliceEndIndex: Slice dimension is outside of the tensor. It starts at "
      ':<>: 'ShowType si
      ':<>: 'Text " and ends at "
      ':<>: 'ShowType (si + sd - 1)
      ':<>: 'Text " which is outside of the range of the tensor's dimension [0.."
      ':<>: 'ShowType (d - 1)
      ':<>: 'Text "]."))
-- | Mask with one 'Bool' per tensor element, in 'AllIndexes' (row-major)
-- order: 'True' when the element lies inside the slice on every axis.
type ElemsInSlice (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) =
  ElemsInSlice' startIndex (SliceEndIndex startIndex sliceDims dims) (AllIndexes dims)

-- | Test each index in turn against the inclusive box @[lo, hi]@.
type family ElemsInSlice' (startIndex :: [Nat]) (endIndex :: [Nat]) (indexes :: [[Nat]]) :: [Bool] where
  ElemsInSlice' _ _ '[] = '[]
  ElemsInSlice' lo hi (ix ': ixs) = ElemsInSlice'' ix lo hi ': ElemsInSlice' lo hi ixs

-- | Whether every component of an index lies between the matching start and
-- end components (both inclusive).
type family ElemsInSlice'' (index :: [Nat]) (startIndex :: [Nat]) (endIndex :: [Nat]) :: Bool where
  ElemsInSlice'' '[] '[] '[] = 'True
  ElemsInSlice'' (i ': is) (lo ': los) (hi ': his) = lo <=? i && i <=? hi && ElemsInSlice'' is los his
-- | Lens focusing on the slice that starts at @startIndex@ and has shape
-- @sliceDims@ (size-1 axes are dropped from the focused tensor's shape).
slice :: forall startIndex sliceDims dims e.
    (Slice startIndex sliceDims dims e)
    => Lens' (Tensor dims e) (Tensor (NormalizeDims sliceDims) e)
slice = lens getter setter
  where
    getter = getSlice @startIndex @sliceDims @dims @e
    setter = setSlice @startIndex @sliceDims @dims @e
{-# INLINE slice #-}

-- | Constraints required by 'slice'.
type Slice startIndex sliceDims dims e =
  ( IsTensor dims e
  , IsTensor (NormalizeDims sliceDims) e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Extract the slice that starts at @startIndex@ and has shape @sliceDims@.
getSlice :: forall startIndex sliceDims dims e.
    (GetSlice startIndex sliceDims dims e)
    => Tensor dims e
    -> Tensor (NormalizeDims sliceDims) e
getSlice t = unsafeFromList (getSliceElems @startIndex @sliceDims @dims @e t)
{-# INLINE getSlice #-}

-- | Constraints required by 'getSlice'.
type GetSlice startIndex sliceDims dims e =
  ( IsTensor dims e
  , IsTensor (NormalizeDims sliceDims) e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )

-- | Elements of the slice, in the order they appear in 'toList'.
getSliceElems :: forall startIndex sliceDims dims e.
    (GetSliceElems startIndex sliceDims dims e)
    => Tensor dims e
    -> [e]
getSliceElems t = getSliceElemsWrk @(ElemsInSlice startIndex sliceDims dims) (toList t)
{-# INLINE getSliceElems #-}

-- | Constraints required by 'getSliceElems'.
type GetSliceElems startIndex sliceDims dims e =
  ( IsTensor dims e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Error raised by the slice/remove workers when the element list runs out
-- before the type-level mask does. Shapes are checked at compile time, so
-- this is expected to be unreachable and signals an internal bug.
impossible_notEnoughTensorElems :: a
impossible_notEnoughTensorElems =
  error "Impossible happened! Not enough elements in the tensor. Please report this bug."
{-# INLINE impossible_notEnoughTensorElems #-}
-- | Keep the elements whose mask entry is 'True', in order, and drop the
-- ones marked 'False'. Elements beyond the end of the mask are ignored.
class GetSliceElemsWrk (elemsInSlice :: [Bool]) where
  getSliceElemsWrk :: [e] -> [e]

instance GetSliceElemsWrk '[] where
  getSliceElemsWrk = const []
  {-# INLINE getSliceElemsWrk #-}

instance (GetSliceElemsWrk xs) => GetSliceElemsWrk ('True ': xs) where
  getSliceElemsWrk elems = case elems of
    [] -> impossible_notEnoughTensorElems
    y : rest -> y : getSliceElemsWrk @xs rest
  {-# INLINE getSliceElemsWrk #-}

instance (GetSliceElemsWrk xs) => GetSliceElemsWrk ('False ': xs) where
  getSliceElemsWrk elems = case elems of
    [] -> impossible_notEnoughTensorElems
    _ : rest -> getSliceElemsWrk @xs rest
  {-# INLINE getSliceElemsWrk #-}
-- | Overwrite the slice that starts at @startIndex@ and has shape
-- @sliceDims@ with the elements of the given tensor.
setSlice :: forall startIndex sliceDims dims e.
    (SetSlice startIndex sliceDims dims e)
    => Tensor dims e
    -> Tensor (NormalizeDims sliceDims) e
    -> Tensor dims e
setSlice t st =
  maybe impossible_notEnoughTensorElems id (setSliceElems @startIndex @sliceDims @dims @e t (toList st))
{-# INLINE setSlice #-}

-- | Constraints required by 'setSlice'.
type SetSlice startIndex sliceDims dims e =
  ( IsTensor dims e
  , IsTensor (NormalizeDims sliceDims) e
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )

-- | Overwrite the slice's elements, in 'toList' order, with values from a
-- list. Gives 'Nothing' when the list is shorter than the slice; extra
-- elements are ignored.
setSliceElems :: forall startIndex sliceDims dims e.
    (SetSliceElems startIndex sliceDims dims e)
    => Tensor dims e
    -> [e]
    -> Maybe (Tensor dims e)
setSliceElems t xs =
  fmap unsafeFromList (setSliceElemsWrk @(ElemsInSlice startIndex sliceDims dims) (toList t) xs)
{-# INLINE setSliceElems #-}

-- | Constraints required by 'setSliceElems'.
type SetSliceElems startIndex sliceDims dims e =
  ( IsTensor dims e
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Walk the tensor elements alongside the mask: positions marked 'True'
-- take the next replacement value, positions marked 'False' keep the
-- original. 'Nothing' when the replacements run out first.
class SetSliceElemsWrk (elemsInSlice :: [Bool]) where
  setSliceElemsWrk :: [e] -> [e] -> Maybe [e]

instance SetSliceElemsWrk '[] where
  setSliceElemsWrk _ _ = Just []
  {-# INLINE setSliceElemsWrk #-}

instance (SetSliceElemsWrk xs) => SetSliceElemsWrk ('True ': xs) where
  -- The tensor list is inspected before the replacement list, as in the clauses of the original.
  setSliceElemsWrk olds news = case (olds, news) of
    ([], _) -> impossible_notEnoughTensorElems
    (_, []) -> Nothing
    (_ : olds', new : news') -> (new :) <$> setSliceElemsWrk @xs olds' news'
  {-# INLINE setSliceElemsWrk #-}

instance (SetSliceElemsWrk xs) => SetSliceElemsWrk ('False ': xs) where
  setSliceElemsWrk olds news = case olds of
    [] -> impossible_notEnoughTensorElems
    old : olds' -> (old :) <$> setSliceElemsWrk @xs olds' news
  {-# INLINE setSliceElemsWrk #-}
-- | Apply a function to every element of the slice that starts at
-- @startIndex@ and has shape @sliceDims@; other elements are unchanged.
mapSliceElems :: forall startIndex sliceDims dims e.
    (MapSliceElems startIndex sliceDims dims e)
    => Tensor dims e
    -> (e -> e)
    -> Tensor dims e
mapSliceElems t f =
    maybe impossible_notEnoughTensorElems id (setSliceElems @startIndex @sliceDims @dims @e t mapped)
  where
    mapped = U.map @(ElemsNumber sliceDims) f (getSliceElems @startIndex @sliceDims @dims @e t)
{-# INLINE mapSliceElems #-}

-- | Constraints required by 'mapSliceElems'.
type MapSliceElems startIndex sliceDims dims e =
  ( IsTensor dims e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  , U.Map (ElemsNumber sliceDims)
  )
-- | Remove the hyperplane at position @indexOnAxis@ along @axis@, shrinking
-- that axis by one. The removed elements are the slice described by
-- 'RemoveSliceStartIndex' and 'RemoveSliceDims'.
remove :: forall (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) e.
    (Remove axis indexOnAxis dims e)
    => Tensor dims e
    -> Tensor (DimsAfterRemove axis indexOnAxis dims) e
remove t =
  unsafeFromList
    (removeWrk @(ElemsInSlice (RemoveSliceStartIndex axis indexOnAxis dims) (RemoveSliceDims axis indexOnAxis dims) dims) (toList t))
{-# INLINE remove #-}

-- | Constraints required by 'remove'.
type Remove (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) e =
  ( IsTensor dims e
  , IsTensor (DimsAfterRemove axis indexOnAxis dims) e
  , RemoveWrk (ElemsInSlice (RemoveSliceStartIndex axis indexOnAxis dims) (RemoveSliceDims axis indexOnAxis dims) dims)
  )
-- | Shape after 'remove': the dimension at @axis@ shrinks by one. An axis
-- past the end of @dims@, or an index outside that dimension, is a type error.
-- (Removing from a dimension of size 1 yields a 0 that 'PositiveDims'
-- rejects through the 'IsTensor' constraint in 'Remove'.)
type family DimsAfterRemove (axis :: Nat) (index :: Nat) (dims :: [Nat]) :: [Nat] where
  DimsAfterRemove _ _ '[] = TypeError ('Text "DimsAfterRemove: axis must be in range [0..(number of dimensions in the tensor)].")
  DimsAfterRemove 0 i (d ': ds) =
    If (i <=? d - 1)
      (d - 1 ': ds)
      (TypeError (
        'Text "DimsAfterRemove: Index "
        ':<>: 'ShowType i
        ':<>: 'Text " is outside of the range of dimension [0.."
        ':<>: 'ShowType (d - 1)
        ':<>: 'Text "]."))
  -- Not yet at the target axis: keep this dimension and count down.
  DimsAfterRemove a i (d ': ds) = d ': DimsAfterRemove (a - 1) i ds

-- | Start index of the removed slice: @indexOnAxis@ at position @axis@ and 0
-- everywhere else. Performs no validation itself ('RemoveSliceDims' does).
type RemoveSliceStartIndex (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) = RemoveSliceStartIndex' axis indexOnAxis dims 0

-- | Worker for 'RemoveSliceStartIndex'; @n@ is the number of the current axis.
type family RemoveSliceStartIndex' (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) (n :: Nat) :: [Nat] where
  RemoveSliceStartIndex' _ _ '[] _ = '[]
  RemoveSliceStartIndex' a i (d ': ds) n =
    If (a == n) i 0 ': RemoveSliceStartIndex' a i ds (n + 1)

-- | Shape of the removed slice: 1 along @axis@, the full size elsewhere.
-- Applies the same axis and index checks as 'DimsAfterRemove'.
type family RemoveSliceDims (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) :: [Nat] where
  RemoveSliceDims _ _ '[] = TypeError ('Text "RemoveSliceDims: axis must be in range [0..(number of dimensions in the tensor)].")
  RemoveSliceDims 0 i (d ': ds) =
    If (i <=? d - 1)
      (1 ': ds)
      (TypeError (
        'Text "RemoveSliceDims: Index "
        ':<>: 'ShowType i
        ':<>: 'Text " is outside of the range of dimension [0.."
        ':<>: 'ShowType (d - 1)
        ':<>: 'Text "]."))
  RemoveSliceDims a i (d ': ds) = d ': RemoveSliceDims (a - 1) i ds
-- | Drop the elements whose mask entry is 'True' and keep the ones marked
-- 'False' (the complement of 'GetSliceElemsWrk').
class RemoveWrk (elemsInSlice :: [Bool]) where
  removeWrk :: [e] -> [e]

instance RemoveWrk '[] where
  removeWrk = const []
  {-# INLINE removeWrk #-}

instance (RemoveWrk xs) => RemoveWrk ('False ': xs) where
  removeWrk elems = case elems of
    [] -> impossible_notEnoughTensorElems
    y : rest -> y : removeWrk @xs rest
  {-# INLINE removeWrk #-}

instance (RemoveWrk xs) => RemoveWrk ('True ': xs) where
  removeWrk elems = case elems of
    [] -> impossible_notEnoughTensorElems
    _ : rest -> removeWrk @xs rest
  {-# INLINE removeWrk #-}
-- | Prepend the subtensor @st@ to @t@ along @axis@, growing that axis by one.
--
-- A tensor of the new shape is first filled with the first element of @t@
-- (an arbitrary value, since every position is overwritten). Then @t@ is
-- written at offset 1 along @axis@, and @st@ at offset 0.
cons :: forall (axis :: Nat) (dims :: [Nat]) e.
    (Cons axis dims e) =>
    Tensor (NormalizeDims (ConsSubtensorDims axis dims)) e
    -> Tensor dims e
    -> Tensor (DimsAfterCons axis dims) e
cons st t =
    setSlice @(ConsSubtensorStartingIndex dims) @(ConsSubtensorDims axis dims) @(DimsAfterCons axis dims) @e shifted st
  where
    filler = fill @(DimsAfterCons axis dims) @e (head (toList t))
    shifted = setSlice @(ConsTensorStartingIndex axis dims) @dims @(DimsAfterCons axis dims) filler t
{-# INLINE cons #-}

-- | Constraints required by 'cons'.
type Cons (axis :: Nat) (dims :: [Nat]) e =
  ( SetSlice (ConsSubtensorStartingIndex dims) (ConsSubtensorDims axis dims) (DimsAfterCons axis dims) e
  , SetSlice (ConsTensorStartingIndex axis dims) dims (DimsAfterCons axis dims) e
  , dims ~ NormalizeDims dims
  , Fill (DimsAfterCons axis dims) e
  )
-- | Where 'cons' writes the new subtensor: the origin (one 0 per dimension).
type family ConsSubtensorStartingIndex (dims :: [Nat]) :: [Nat] where
  ConsSubtensorStartingIndex '[] = '[]
  ConsSubtensorStartingIndex (_ ': rest) = 0 ': ConsSubtensorStartingIndex rest

-- | Where 'cons' writes the original tensor: 1 along @axis@, 0 elsewhere.
type ConsTensorStartingIndex (axis :: Nat) (dims :: [Nat]) = ConsTensorStartingIndex' axis dims 0

-- | Worker for 'ConsTensorStartingIndex'; the last argument is the current axis number.
type family ConsTensorStartingIndex' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  ConsTensorStartingIndex' _ '[] _ = '[]
  ConsTensorStartingIndex' axis (_ ': rest) pos =
    If (axis == pos) 1 0 ': ConsTensorStartingIndex' axis rest (pos + 1)

-- | Shape of the prepended subtensor: @dims@ with 1 along @axis@.
type ConsSubtensorDims (axis :: Nat) (dims :: [Nat]) = ConsSubtensorDims' axis dims 0

-- | Worker for 'ConsSubtensorDims'; the last argument is the current axis number.
type family ConsSubtensorDims' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  ConsSubtensorDims' _ '[] _ = '[]
  ConsSubtensorDims' axis (dim ': rest) pos =
    If (axis == pos) 1 dim ': ConsSubtensorDims' axis rest (pos + 1)

-- | Shape after 'cons': the dimension at @axis@ grows by one.
type family DimsAfterCons (axis :: Nat) (dims :: [Nat]) :: [Nat] where
  DimsAfterCons 0 (dim ': rest) = dim + 1 ': rest
  DimsAfterCons axis (dim ': rest) = dim ': DimsAfterCons (axis - 1) rest
  DimsAfterCons _ '[] = TypeError ('Text "DimsAfterCons: axis must be in range [0..(number of dimensions in the tensor)].")
-- | Append the subtensor @st@ to the end of @t@ along @axis@, growing that
-- axis by one.
--
-- A tensor of the new shape is first filled with the first element of @t@
-- (an arbitrary value, since every position is overwritten). Then @t@ is
-- written at the origin, and @st@ at offset @d@ along @axis@, where @d@ is
-- the original size of that axis.
snoc :: forall (axis :: Nat) (dims :: [Nat]) e.
    (Snoc axis dims e) =>
    Tensor dims e
    -> Tensor (NormalizeDims (SnocSubtensorDims axis dims)) e
    -> Tensor (DimsAfterSnoc axis dims) e
snoc t st =
    setSlice @(SnocSubtensorStartingIndex axis dims) @(SnocSubtensorDims axis dims) @(DimsAfterSnoc axis dims) @e placed st
  where
    filler = fill @(DimsAfterSnoc axis dims) @e (head (toList t))
    placed = setSlice @(SnocTensorStartingIndex dims) @dims @(DimsAfterSnoc axis dims) filler t
{-# INLINE snoc #-}

-- | Constraints required by 'snoc'.
type Snoc (axis :: Nat) (dims :: [Nat]) e =
  ( SetSlice (SnocSubtensorStartingIndex axis dims) (SnocSubtensorDims axis dims) (DimsAfterSnoc axis dims) e
  , SetSlice (SnocTensorStartingIndex dims) dims (DimsAfterSnoc axis dims) e
  , dims ~ NormalizeDims dims
  , Fill (DimsAfterSnoc axis dims) e
  )
-- | Where 'snoc' writes the original tensor: the origin (one 0 per dimension).
type family SnocTensorStartingIndex (dims :: [Nat]) :: [Nat] where
  SnocTensorStartingIndex '[] = '[]
  SnocTensorStartingIndex (_ ': rest) = 0 ': SnocTensorStartingIndex rest

-- | Where 'snoc' writes the new subtensor: the size of @dims@ along @axis@,
-- and 0 elsewhere.
type SnocSubtensorStartingIndex (axis :: Nat) (dims :: [Nat]) = SnocSubtensorStartingIndex' axis dims 0

-- | Worker for 'SnocSubtensorStartingIndex'; the last argument is the current axis number.
type family SnocSubtensorStartingIndex' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  SnocSubtensorStartingIndex' _ '[] _ = '[]
  SnocSubtensorStartingIndex' axis (dim ': rest) pos =
    If (axis == pos) dim 0 ': SnocSubtensorStartingIndex' axis rest (pos + 1)

-- | Shape of the appended subtensor: @dims@ with 1 along @axis@.
type SnocSubtensorDims (axis :: Nat) (dims :: [Nat]) = SnocSubtensorDims' axis dims 0

-- | Worker for 'SnocSubtensorDims'; the last argument is the current axis number.
type family SnocSubtensorDims' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  SnocSubtensorDims' _ '[] _ = '[]
  SnocSubtensorDims' axis (dim ': rest) pos =
    If (axis == pos) 1 dim ': SnocSubtensorDims' axis rest (pos + 1)

-- | Shape after 'snoc': the dimension at @axis@ grows by one.
type family DimsAfterSnoc (axis :: Nat) (dims :: [Nat]) :: [Nat] where
  DimsAfterSnoc 0 (dim ': rest) = dim + 1 ': rest
  DimsAfterSnoc axis (dim ': rest) = dim ': DimsAfterSnoc (axis - 1) rest
  DimsAfterSnoc _ '[] = TypeError ('Text "DimsAfterSnoc: axis must be in range [0..(number of dimensions in the tensor)].")
-- | Concatenate two tensors along @axis@. Every other dimension must match
-- (checked by 'DimsAfterAppend').
--
-- A tensor of the result shape is filled with the first element of @t0@ (an
-- arbitrary value, since every position is overwritten). Then @t0@ is written
-- at the origin, and @t1@ right after it, at offset @dims0[axis]@ along @axis@.
append :: forall (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) e.
    (Append axis dims0 dims1 e)
    => Tensor dims0 e
    -> Tensor dims1 e
    -> Tensor (DimsAfterAppend axis dims0 dims1) e
append t0 t1 =
    -- The second tensor starts where the first one ends along @axis@, i.e. at
    -- the size of the FIRST tensor's axis. Using dims1 here was wrong whenever
    -- the two axis sizes differed.
    setSlice @(AppendSndTensorStartingIndex axis dims0) @dims1 @(DimsAfterAppend axis dims0 dims1) @e t0' t1
  where
    t0' = setSlice @(AppendFstTensorStartingIndex dims0) @dims0 @(DimsAfterAppend axis dims0 dims1) z t0
    z = fill @(DimsAfterAppend axis dims0 dims1) @e (head (toList t0))
{-# INLINE append #-}

-- | Constraints required by 'append'.
type Append (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) e =
  ( SetSlice (AppendFstTensorStartingIndex dims0) dims0 (DimsAfterAppend axis dims0 dims1) e
  , SetSlice (AppendSndTensorStartingIndex axis dims0) dims1 (DimsAfterAppend axis dims0 dims1) e
  , dims0 ~ NormalizeDims dims0
  , dims1 ~ NormalizeDims dims1
  , Fill (DimsAfterAppend axis dims0 dims1) e
  )
-- | Where 'append' writes the first tensor: the origin (one 0 per dimension).
-- Recurses on itself instead of borrowing 'SnocTensorStartingIndex', so
-- append's layout does not silently depend on snoc's helper.
type family AppendFstTensorStartingIndex (dims :: [Nat]) :: [Nat] where
  AppendFstTensorStartingIndex '[] = '[]
  AppendFstTensorStartingIndex (_ ': ds) = 0 ': AppendFstTensorStartingIndex ds
-- | Index list holding the size of @dims@ along @axis@ and 0 elsewhere;
-- 'append' uses it as the offset at which the second tensor is written.
type AppendSndTensorStartingIndex (axis :: Nat) (dims :: [Nat]) = AppendSndTensorStartingIndex' axis dims 0

-- | Worker for 'AppendSndTensorStartingIndex'; the last argument is the current axis number.
type family AppendSndTensorStartingIndex' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  AppendSndTensorStartingIndex' _ '[] _ = '[]
  AppendSndTensorStartingIndex' axis (dim ': rest) pos =
    If (axis == pos) dim 0 ': AppendSndTensorStartingIndex' axis rest (pos + 1)
-- | Shape after 'append': the sizes along @axis@ are added, and every other
-- dimension must be equal in both tensors. A rank mismatch, unequal
-- non-axis dimensions, or an axis not smaller than the rank is a type error.
type DimsAfterAppend (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) = DimsAfterAppend' axis dims0 dims1 0

-- | Worker for 'DimsAfterAppend'; @i@ is the number of the current axis.
type family DimsAfterAppend' (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) (i :: Nat) :: [Nat] where
  DimsAfterAppend' _ '[] (_ ': _) _ = TypeError ('Text "DimsAfterAppend: Tensors must have the same number of dimensions.")
  DimsAfterAppend' _ (_ ': _) '[] _ = TypeError ('Text "DimsAfterAppend: Tensors must have the same number of dimensions.")
  -- Both shapes exhausted: @i@ now equals the rank, so the axis was valid
  -- only if it is smaller. Before, only @axis == rank@ was rejected, and any
  -- larger axis silently produced the unchanged shape.
  DimsAfterAppend' a '[] '[] i =
    If (i <=? a)
      (TypeError ('Text "DimsAfterAppend: axis must be in range [0..(number of dimensions in the tensor)]."))
      '[]
  DimsAfterAppend' a (d0 ': d0s) (d1 ': d1s) a = d0 + d1 ': DimsAfterAppend' a d0s d1s (a + 1)
  DimsAfterAppend' a (d ': d0s) (d ': d1s) i = d ': DimsAfterAppend' a d0s d1s (i + 1)
  DimsAfterAppend' a (d0 ': d0s) (d1 ': d1s) i = TypeError ('Text "DimsAfterAppend: Tensors have incompatible dimensions.")
-- | Traverse every element of a tensor (in 'toList' order), possibly
-- changing the element type.
instance (IsTensor dims a, IsTensor dims b) => Each (Tensor dims a) (Tensor dims b) a b where
  {-# INLINE each #-}
  each f t = fmap unsafeFromList (traversed f (toList t))

-- | The element type of a tensor, for the mono-traversable classes.
type instance Element (Tensor dims e) = e
-- | Element-wise mapping through the unrolled 'U.map'.
instance (MonoFunctorCtx dims e) => MonoFunctor (Tensor dims e) where
  {-# INLINE omap #-}
  omap f t = unsafeFromList (U.map @(ElemsNumber dims) f (toList t))

-- | Constraints required by the 'MonoFunctor' instance.
type MonoFunctorCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Map (ElemsNumber dims)
  )
-- | Folds over the elements in 'toList' order, using the unrolled folds
-- from "Data.List.Unrolled".
instance (MonoFoldableCtx dims e) => MonoFoldable (Tensor dims e) where
  {-# INLINE ofoldr #-}
  {-# INLINE ofoldMap #-}
  {-# INLINE ofoldl' #-}
  {-# INLINE ofoldr1Ex #-}
  {-# INLINE ofoldl1Ex' #-}
  ofoldr f z t = U.foldr @(ElemsNumber dims) f z (toList t)
  ofoldMap f t = U.foldMap @(ElemsNumber dims) f (toList t)
  ofoldl' f z t = U.foldl @(ElemsNumber dims) f z (toList t)
  ofoldr1Ex f t = U.foldr1 @(ElemsNumber dims) f (toList t)
  ofoldl1Ex' f t = U.foldl1 @(ElemsNumber dims) f (toList t)

-- | Constraints required by the 'MonoFoldable' instance.
type MonoFoldableCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Foldr (ElemsNumber dims)
  , U.Foldl (ElemsNumber dims)
  , U.Foldr1 (ElemsNumber dims)
  , U.Foldl1 (ElemsNumber dims)
  )
-- | Effectful traversal of the elements in 'toList' order.
instance (MonoTraversableCtx dims e) => MonoTraversable (Tensor dims e) where
  {-# INLINE otraverse #-}
  otraverse f t = fmap unsafeFromList (traverse f (toList t))

-- | Constraints required by the 'MonoTraversable' instance (which also needs
-- the 'MonoFunctor' and 'MonoFoldable' contexts).
type MonoTraversableCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Map (ElemsNumber dims)
  , U.Foldr (ElemsNumber dims)
  , U.Foldl (ElemsNumber dims)
  , U.Foldr1 (ElemsNumber dims)
  , U.Foldl1 (ElemsNumber dims)
  )
-- | Element-wise zipping of two tensors of the same shape.
instance (MonoZipCtx dims e) => MonoZip (Tensor dims e) where
  {-# INLINE ozipWith #-}
  {-# INLINE ozip #-}
  {-# INLINE ounzip #-}
  ozipWith f t1 t2 = unsafeFromList (U.zipWith @(ElemsNumber dims) f (toList t1) (toList t2))
  ozip t1 t2 = U.zip @(ElemsNumber dims) (toList t1) (toList t2)
  ounzip pairs =
    let (xs, ys) = U.unzip @(ElemsNumber dims) pairs
        left = unsafeFromList xs
        right = unsafeFromList ys
    -- Both tensors are forced to WHNF before the pair is returned.
    in left `seq` right `seq` (left, right)

-- | Constraints required by the 'MonoZip' instance.
type MonoZipCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Map (ElemsNumber dims)
  , U.ZipWith (ElemsNumber dims)
  , U.Zip (ElemsNumber dims)
  , U.Unzip (ElemsNumber dims)
  )
-- | A tensor is stored as a contiguous array of its elements in 'toList'
-- order. Each element occupies 'offsetDiff' bytes, i.e. its size rounded up
-- to a multiple of its own alignment.
instance (IsTensor dims e, Storable e, KnownNat (ElemsNumber dims)) => Storable (Tensor dims e) where
  {-# INLINE alignment #-}
  {-# INLINE sizeOf #-}
  {-# INLINE peek #-}
  {-# INLINE poke #-}
  alignment _ = alignment (undefined :: e)

  sizeOf _ = elemsNumber @dims * stride
    where
      stride = offsetDiff (undefined :: e) (undefined :: e)

  peek ptr = unsafeFromList <$> traverse readAt [0 .. n - 1]
    where
      stride = offsetDiff (undefined :: e) (undefined :: e)
      n = elemsNumber @dims
      readAt i = peekByteOff ptr (i * stride)

  poke ptr t = sequence_ [pokeByteOff ptr (stride * i) x | (i, x) <- zip [0 .. n - 1] (toList t)]
    where
      stride = offsetDiff (undefined :: e) (undefined :: e)
      n = elemsNumber @dims
-- | Run an action with a pointer to the tensor's elements, laid out as by its
-- 'Storable' instance. The tensor is marshalled into temporarily allocated
-- memory with 'with', so the pointer is valid only inside the action and
-- writes through it do not affect @t@.
unsafeWithTensorPtr :: (IsTensor dims e, Storable e, KnownNat (ElemsNumber dims)) => Tensor dims e -> (Ptr e -> IO a) -> IO a
unsafeWithTensorPtr t f = with t (\ptr -> f (castPtr ptr))
{-# INLINE unsafeWithTensorPtr #-}
-- | Bytes of padding needed to round @sizeOf a@ up to a multiple of
-- @alignment b@. Only the argument types matter: 'sizeOf' and 'alignment'
-- do not inspect their argument.
padding :: (Storable a, Storable b) => a -> b -> Int
padding x y = mod (boundary - sizeOf x) boundary
  where
    boundary = alignment y
{-# INLINE padding #-}
-- | Size of @a@ plus the 'padding' needed before a @b@: the offset from a
-- value of type @a@ to a following @b@, assuming the @a@ value itself starts
-- at an address aligned for @b@.
offsetDiff :: (Storable a, Storable b) => a -> b -> Int
offsetDiff x y = padding x y + sizeOf x
{-# INLINE offsetDiff #-}