{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}
module NumHask.Array.Shape
( Shape (..),
HasShape (..),
type (++),
type (!!),
Take,
Drop,
Reverse,
ReverseGo,
Filter,
rank,
Rank,
ranks,
Ranks,
size,
Size,
dimension,
Dimension,
flatten,
shapen,
minimum,
Minimum,
checkIndex,
CheckIndex,
checkIndexes,
CheckIndexes,
addIndex,
AddIndex,
dropIndex,
DropIndex,
posRelative,
PosRelative,
PosRelativeGo,
DecMap,
addIndexes,
AddIndexes,
AddIndexesGo,
dropIndexes,
DropIndexes,
DropIndexesGo,
takeIndexes,
TakeIndexes,
exclude,
Exclude,
Enumerate,
EnumerateGo,
concatenate',
Concatenate,
CheckConcatenate,
Insert,
CheckInsert,
reorder',
Reorder,
CheckReorder,
squeeze',
Squeeze,
incAt,
decAt,
KnownNats (..),
KnownNatss (..),
)
where
import Data.Proxy
import Data.Type.Bool
import Data.Type.Equality
import GHC.TypeLits as L
import NumHask.Prelude as P hiding (Last, minimum)
-- | A list of dimension sizes tagged with a type-level list of naturals.
--
-- The phantom parameter @s@ carries the shape in the type; 'shapeVal'
-- holds the value-level dimensions.
newtype Shape (s :: [Nat]) = Shape
  { shapeVal :: [Int]
  }
  deriving (Show)
-- | Shapes that can be reflected from the type level to a 'Shape' value.
class HasShape s where
  toShape :: Shape s

-- | The empty type-level shape reflects to an empty dimension list.
instance HasShape '[] where
  toShape = Shape []

-- | A non-empty shape reflects to its leading dimension (read with
-- 'natVal') followed by the reflected remaining dimensions.
instance (KnownNat n, HasShape s) => HasShape (n : s) where
  toShape = Shape (headDim : restDims)
    where
      headDim = fromInteger (natVal (Proxy @n))
      restDims = shapeVal (toShape :: Shape s)
-- | Number of dimensions of a shape, i.e. the length of the list.
--
-- >>> rank [2,3,4]
-- 3
rank :: [a] -> Int
rank xs = length xs
{-# INLINE rank #-}
-- | Type-level 'rank': the length of a type-level list.
type family Rank (xs :: [k]) :: Nat where
  Rank '[] = 0
  Rank (_ : rest) = Rank rest + 1
-- | The 'rank' of each list in a list of shapes.
ranks :: [[a]] -> [Int]
ranks shapes = rank <$> shapes
{-# INLINE ranks #-}
-- | Type-level 'ranks': the 'Rank' of each list in a type-level list.
type family Ranks (xss :: [[k]]) :: [Nat] where
  Ranks '[] = '[]
  Ranks (xs : rest) = Rank xs : Ranks rest
-- | Number of elements described by a shape: the product of its
-- dimensions. The empty shape has size 1.
--
-- >>> size [2,3,4]
-- 24
size :: [Int] -> Int
size dims = case dims of
  [] -> 1
  [d] -> d
  _ -> foldr (*) 1 dims
{-# INLINE size #-}
-- | Type-level 'size': the product of the dimensions, 1 for the empty shape.
type family Size (dims :: [Nat]) :: Nat where
  Size '[] = 1
  Size (d : rest) = d L.* Size rest
-- | Convert a shape index into a flat offset.
--
-- The stride of each dimension is the product of the dimensions to its
-- right, so the last index varies fastest. The offset is the sum of each
-- index times its stride. A single-element index is returned unchanged
-- without consulting the shape. An empty shape always gives 0.
--
-- >>> flatten [2,3,4] [1,1,1]
-- 17
--
-- >>> flatten [] [1,1,1]
-- 0
flatten :: [Int] -> [Int] -> Int
flatten dims idx = case (dims, idx) of
  ([], _) -> 0
  (_, [i]) -> i
  _ -> sum (zipWith (*) idx strides)
  where
    -- Suffix products of the shape, dropping the full product at the head.
    strides = drop 1 (scanr (*) 1 dims)
{-# INLINE flatten #-}
-- | Convert a flat offset back into a shape index. For in-range offsets
-- this is intended to undo 'flatten'.
--
-- Shapes of rank 1 and 2 take shortcuts that do not reduce the outermost
-- index modulo its dimension.
--
-- >>> shapen [2,3,4] 17
-- [1,1,1]
shapen :: [Int] -> Int -> [Int]
shapen dims offset = case dims of
  [] -> []
  [_] -> [offset]
  [_, inner] ->
    let (outer, within) = divMod offset inner
     in [outer, within]
  _ -> fst (foldr step ([], offset) dims)
  where
    -- Work from the innermost dimension outwards: the remainder is this
    -- dimension's index, and the quotient carries into the next one left.
    step d (digits, carry) =
      let (q, r) = divMod carry d
       in (r : digits, q)
{-# INLINE shapen #-}
-- | Is @i@ a valid index into a dimension of size @n@, i.e. @0 <= i < n@?
--
-- The upper bound is tested as @i < n@ rather than @i + 1 <= n@. The latter
-- overflows when @i == maxBound@ (wrapping to 'minBound') and would then
-- wrongly accept that index.
checkIndex :: Int -> Int -> Bool
checkIndex i n = 0 <= i && i < n
-- | Type-level 'checkIndex'. Reduces to @'True@ when @0 <= i@ and
-- @i + 1 <= n@; otherwise it reports the custom type error
-- \"index outside range\".
type family CheckIndex (i :: Nat) (n :: Nat) :: Bool where
  CheckIndex idx dim =
    If
      ((0 <=? idx) && (idx + 1 <=? dim))
      'True
      (L.TypeError ('Text "index outside range"))
-- | Are all the given indices valid for a dimension of size @n@, as judged
-- by 'checkIndex'? True for an empty list of indices.
checkIndexes :: [Int] -> Int -> Bool
checkIndexes idxs n = all (\idx -> checkIndex idx n) idxs
-- | Type-level 'checkIndexes': conjunction of 'CheckIndex' over a list.
type family CheckIndexes (i :: [Nat]) (n :: Nat) :: Bool where
  CheckIndexes '[] _ = 'True
  CheckIndexes (idx : rest) dim = CheckIndex idx dim && CheckIndexes rest dim
-- | The size of dimension @i@ (0-based) of a shape.
--
-- Walks down the list, decrementing @i@ until it reaches 0. Throws a
-- 'NumHaskException' (\"dimension overflow\") if the list runs out first,
-- which includes any negative @i@.
dimension :: [Int] -> Int -> Int
dimension dims i = case dims of
  [] -> throw (NumHaskException "dimension overflow")
  (d : rest)
    | i == 0 -> d
    | otherwise -> dimension rest (i - 1)
-- | Type-level 'dimension'. Reports the type error \"dimension overflow\"
-- when the index runs past the end of the shape.
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (d : _) 0 = d
  Dimension (_ : rest) n = Dimension rest (n - 1)
  Dimension _ _ = L.TypeError ('Text "dimension overflow")
-- | The smallest dimension of a shape.
--
-- Throws a 'NumHaskException' (\"dimension underflow\") for the empty
-- shape.
minimum :: [Int] -> Int
minimum dims = case dims of
  [] -> throw (NumHaskException "dimension underflow")
  (d : rest) -> foldr P.min d rest
-- | Type-level 'minimum'. The empty shape is the type error
-- \"zero dimension\".
type family Minimum (s :: [Nat]) :: Nat where
  Minimum '[] = L.TypeError ('Text "zero dimension")
  Minimum '[d] = d
  Minimum (d : rest) = If (d <=? Minimum rest) d (Minimum rest)
-- | Type-level 'take': the first @n@ elements of a list. No equation covers
-- a non-zero count on an empty list, so such applications do not reduce.
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _ = '[]
  Take count (x : rest) = x : Take (count - 1) rest

-- | Type-level 'drop': the list without its first @n@ elements. Like
-- 'Take', it does not reduce for a non-zero count on an empty list.
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 xs = xs
  Drop count (_ : rest) = Drop (count - 1) rest

-- | All but the first element; a type error on the empty list.
type family Tail (a :: [k]) :: [k] where
  Tail '[] = L.TypeError ('Text "No tail")
  Tail (_ : rest) = rest

-- | All but the last element; a type error on the empty list.
type family Init (a :: [k]) :: [k] where
  Init '[] = L.TypeError ('Text "No init")
  Init '[_] = '[]
  Init (x : rest) = x : Init rest

-- | The first element; a type error on the empty list.
type family Head (a :: [k]) :: k where
  Head '[] = L.TypeError ('Text "No head")
  Head (x : _) = x

-- | The last element; a type error on the empty list.
type family Last (a :: [k]) :: k where
  Last '[] = L.TypeError ('Text "No last")
  Last '[x] = x
  Last (_ : rest) = Last rest

-- | Type-level list append.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[] ++ ys = ys
  (x : xs) ++ ys = x : (xs ++ ys)
dropIndex :: [Int] -> Int -> [Int]
dropIndex :: [Int] -> Int -> [Int]
dropIndex [Int]
s Int
i = forall a. Int -> [a] -> [a]
take Int
i [Int]
s forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
drop (Int
i forall a. Additive a => a -> a -> a
+ Int
1) [Int]
s
-- | Type-level 'dropIndex': remove the element at position @n@.
type DropIndex dims n = Take n dims ++ Drop (n + 1) dims
-- | Insert a dimension of size @d@ at position @i@.
--
-- >>> addIndex [2,4] 1 3
-- [2,3,4]
addIndex :: [Int] -> Int -> Int -> [Int]
addIndex dims i d = front ++ (d : back)
  where
    front = take i dims
    back = drop i dims
-- | Type-level 'addIndex': insert @d@ at position @n@.
type AddIndex dims n d = Take n dims ++ (d : Drop n dims)
-- | Type-level 'reverse'.
type Reverse (xs :: [k]) = ReverseGo xs '[]

-- | Accumulator helper for 'Reverse': moves the elements of the first list,
-- one by one, onto the front of the second.
type family ReverseGo (a :: [k]) (b :: [k]) :: [k] where
  ReverseGo '[] acc = acc
  ReverseGo (x : rest) acc = ReverseGo rest (x : acc)
-- | Turn positions given in terms of one original list into positions that
-- stay valid when the elements are removed one after another.
--
-- Each position is taken in turn. Every later position that is not below
-- it is then decremented by one.
--
-- >>> posRelative [0,1]
-- [0,0]
posRelative :: [Int] -> [Int]
posRelative ps = reverse (go [] ps)
  where
    go :: [Int] -> [Int] -> [Int]
    go done [] = done
    go done (p : rest) = go (p : done) (shift p <$> rest)
    -- Once position @p@ is removed, positions at or above it move down one.
    shift p q
      | q < p = q
      | otherwise = q - 1
-- | Type-level 'posRelative'.
type family PosRelative (s :: [Nat]) where
  PosRelative ps = PosRelativeGo ps '[]

-- | Worker for 'PosRelative'. The first argument holds the positions still
-- to process; the second accumulates processed positions in reverse.
type family PosRelativeGo (s :: [Nat]) (r :: [Nat]) where
  PosRelativeGo '[] done = Reverse done
  PosRelativeGo (p : rest) done = PosRelativeGo (DecMap p rest) (p : done)

-- | Keep each element below @x@ and decrement every other element by one.
type family DecMap (x :: Nat) (ys :: [Nat]) :: [Nat] where
  DecMap _ '[] = '[]
  DecMap p (q : rest) = If (q + 1 <=? p) q (q - 1) : DecMap p rest
-- | Remove the dimensions at several positions, all given in terms of the
-- original shape.
--
-- >>> dropIndexes [2,3,4] [0,2]
-- [3]
dropIndexes :: [Int] -> [Int] -> [Int]
dropIndexes dims ps = foldl' dropIndex dims relative
  where
    -- Positions re-based so each stays valid after the earlier removals.
    relative = posRelative ps
-- | Type-level 'dropIndexes'.
type family DropIndexes (s :: [Nat]) (i :: [Nat]) where
  DropIndexes dims ps = DropIndexesGo dims (PosRelative ps)

-- | Remove the already re-based positions one at a time.
type family DropIndexesGo (s :: [Nat]) (i :: [Nat]) where
  DropIndexesGo dims '[] = dims
  DropIndexesGo dims (p : rest) = DropIndexesGo (DropIndex dims p) rest
-- | Insert dimensions @ds@ at positions @ps@, pairing them element by
-- element. The positions refer to places in the resulting shape.
--
-- Throws a 'NumHaskException' (\"mismatched ranks\") when there are more
-- positions than sizes. Surplus sizes are ignored.
--
-- >>> addIndexes [2,4] [0,2] [5,6]
-- [5,2,6,4]
addIndexes :: [Int] -> [Int] -> [Int] -> [Int]
addIndexes dims ps ds = insertAll dims (reverse (posRelative (reverse ps))) ds
  where
    insertAll acc [] _ = acc
    insertAll acc (p : ps') (d : ds') = insertAll (addIndex acc p d) ps' ds'
    insertAll _ _ _ = throw (NumHaskException "mismatched ranks")
-- | Type-level 'addIndexes'.
type family AddIndexes (dims :: [Nat]) (ps :: [Nat]) (ds :: [Nat]) where
  AddIndexes dims ps ds = AddIndexesGo dims (Reverse (PosRelative (Reverse ps))) ds

-- | Insert the re-based positions one at a time. Running out of sizes
-- before positions is the type error \"mismatched ranks\".
type family AddIndexesGo (dims :: [Nat]) (ps :: [Nat]) (ds :: [Nat]) where
  AddIndexesGo acc '[] _ = acc
  AddIndexesGo acc (p : ps) (d : ds) = AddIndexesGo (AddIndex acc p d) ps ds
  AddIndexesGo _ _ _ = L.TypeError ('Text "mismatched ranks")
-- | Select the dimensions at the given positions, in the order given.
--
-- Indexes with '!!', so an out-of-range position is a runtime error.
--
-- >>> takeIndexes [2,3,4] [2,0]
-- [4,2]
takeIndexes :: [Int] -> [Int] -> [Int]
takeIndexes dims ps = fmap (dims !!) ps
-- | Type-level 'takeIndexes'. Unlike the value-level version, an empty
-- shape yields '[] for any positions.
type family TakeIndexes (s :: [Nat]) (i :: [Nat]) where
  TakeIndexes '[] _ = '[]
  TakeIndexes _ '[] = '[]
  TakeIndexes dims (p : ps) = (dims !! p) ': TakeIndexes dims ps

-- | Type-level list indexing.
--
-- NOTE: the error text says \"Index Underflow\", but it fires when the index
-- runs past the end of the list.
type family (a :: [k]) !! (b :: Nat) :: k where
  '[] !! _ = L.TypeError ('Text "Index Underflow")
  (x : _) !! 0 = x
  (_ : rest) !! n = rest !! (n - 1)
-- | The ascending list @0, 1, .., n - 1@ at the type level.
type family Enumerate (n :: Nat) where
  Enumerate count = Reverse (EnumerateGo count)

-- | The descending list @n - 1, .., 1, 0@ at the type level.
type family EnumerateGo (n :: Nat) where
  EnumerateGo 0 = '[]
  EnumerateGo count = (count - 1) : EnumerateGo (count - 1)
-- | The axes @0 .. r - 1@ of a rank-@r@ shape, in ascending order, with
-- the given positions removed.
--
-- >>> exclude 4 [1,2]
-- [0,3]
exclude :: Int -> [Int] -> [Int]
exclude r ps = dropIndexes axes ps
  where
    axes = [0 .. r - 1]
-- | Type-level 'exclude': the axes @0 .. r - 1@, in ascending order, with
-- the positions @i@ removed.
--
-- Built on 'Enumerate', which is ascending, so the result is ordered like
-- the value-level 'exclude' (which drops from @[0 .. r - 1]@). Building on
-- 'EnumerateGo', which is descending, would return the surviving axes
-- reversed: e.g. @'[2, 0]@ instead of @'[0, 2]@ for @Exclude 3 '[1]@.
type family Exclude (r :: Nat) (i :: [Nat]) where
  Exclude r i = DropIndexes (Enumerate r) i
-- | The shape @s0@ with dimension @i@ replaced by the sum of dimension @i@
-- of @s0@ and of @s1@. Presumably the result shape of concatenating along
-- axis @i@.
--
-- >>> concatenate' 1 [2,3,4] [2,5,4]
-- [2,8,4]
concatenate' :: Int -> [Int] -> [Int] -> [Int]
concatenate' i s0 s1 = prefix ++ (joined : suffix)
  where
    prefix = take i s0
    joined = dimension s0 i + dimension s1 i
    suffix = drop (i + 1) s0
-- | Type-level 'concatenate''.
type Concatenate axis xs ys = Take axis xs ++ (Dimension xs axis + Dimension ys axis : Drop (axis + 1) xs)

-- | Constraint: @axis@ is a valid axis of @xs@, @xs@ and @ys@ agree on every
-- other dimension, and they have the same rank. The fourth parameter does
-- not appear on the right-hand side.
type CheckConcatenate axis xs ys out =
  ( CheckIndex axis (Rank xs)
      && DropIndex xs axis == DropIndex ys axis
      && Rank xs == Rank ys
  )
    ~ 'True

-- | Constraint: @axis@ is a valid axis of @xs@, and @idx@ is a valid index
-- into that dimension.
type CheckInsert axis idx xs =
  (CheckIndex axis (Rank xs) && CheckIndex idx (Dimension xs axis)) ~ 'True

-- | The shape @xs@ with dimension @axis@ grown by one (type-level 'incAt').
type Insert axis xs = Take axis xs ++ (Dimension xs axis + 1 : Drop (axis + 1) xs)
-- | Increase dimension @d@ of a shape by one.
--
-- >>> incAt 1 [2,3,4]
-- [2,4,4]
incAt :: Int -> [Int] -> [Int]
incAt d s = before ++ (grown : after)
  where
    before = take d s
    grown = dimension s d + 1
    after = drop (d + 1) s
-- | Decrease dimension @d@ of a shape by one.
--
-- >>> decAt 1 [2,3,4]
-- [2,2,4]
decAt :: Int -> [Int] -> [Int]
decAt d s = before ++ (shrunk : after)
  where
    before = take d s
    shrunk = dimension s d - 1
    after = drop (d + 1) s
-- | The dimensions of @s@ taken in the order given by the positions @ds@.
-- Returns @[]@ for an empty shape, whatever the positions.
--
-- >>> reorder' [2,3,4] [2,0,1]
-- [4,2,3]
reorder' :: [Int] -> [Int] -> [Int]
reorder' [] _ = []
reorder' s ds = fmap (dimension s) ds
-- | Type-level 'reorder''.
type family Reorder (s :: [Nat]) (ds :: [Nat]) :: [Nat] where
  Reorder '[] _ = '[]
  Reorder _ '[] = '[]
  Reorder dims (p : ps) = Dimension dims p : Reorder dims ps

-- | Constraint: the positions list has the same rank as the shape, and every
-- position is a valid axis. Otherwise reports the type error
-- \"bad dimensions\".
type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where
  CheckReorder ps dims =
    If
      ( Rank ps == Rank dims
          && CheckIndexes ps (Rank dims)
      )
      'True
      (L.TypeError ('Text "bad dimensions"))
      ~ 'True
-- | Drop every dimension equal to 'one'.
squeeze' :: (Eq a, Multiplicative a) => [a] -> [a]
squeeze' dims = filter notOne dims
  where
    notOne d = d /= one
-- | Type-level 'squeeze'': remove every dimension equal to 1.
type family Squeeze (a :: [Nat]) where
  Squeeze '[] = '[]
  Squeeze dims = Filter '[] dims 1

-- | Remove every element equal to @i@. The first argument is the reversed
-- accumulator of kept elements.
type family Filter (r :: [Nat]) (xs :: [Nat]) (i :: Nat) where
  Filter acc '[] _ = Reverse acc
  Filter acc (x : rest) v = Filter (If (x == v) acc (x : acc)) rest v
-- | Type-level quicksort around the head element, using 'Cmp' for ordering.
-- It can only reduce for kinds that have 'Cmp' equations. No such equations
-- appear in this module as shown — confirm before relying on it.
type family Sort (xs :: [k]) :: [k] where
  Sort '[] = '[]
  Sort (pivot ': rest) = (Sort (SFilter 'FMin pivot rest) ++ '[pivot]) ++ Sort (SFilter 'FMax pivot rest)

-- | Selects which side of the pivot 'SFilter' keeps.
data Flag
  = FMin
  | FMax

-- | Open comparison family consumed by 'SFilter'.
type family Cmp (a :: k) (b :: k) :: Ordering

-- | 'FMin' keeps the elements comparing 'LT' to the pivot. 'FMax' keeps
-- those comparing 'GT' or 'EQ'.
type family SFilter (f :: Flag) (p :: k) (xs :: [k]) :: [k] where
  SFilter f pivot '[] = '[]
  SFilter 'FMin pivot (y ': ys) = If (Cmp y pivot == 'LT) (y ': SFilter 'FMin pivot ys) (SFilter 'FMin pivot ys)
  SFilter 'FMax pivot (y ': ys) = If (Cmp y pivot == 'GT || Cmp y pivot == 'EQ) (y ': SFilter 'FMax pivot ys) (SFilter 'FMax pivot ys)

-- | Pair up two type-level lists, truncating to the shorter.
type family Zip xs ys where
  Zip xs ys = ZipWith '(,) xs ys

-- | Combine two type-level lists element-wise, truncating to the shorter.
type family ZipWith f xs ys where
  ZipWith f '[] ys = '[]
  ZipWith f xs '[] = '[]
  ZipWith f (x ': xs) (y ': ys) = f x y ': ZipWith f xs ys

-- | First component of a promoted pair.
type family Fst p where
  Fst '(x, _) = x

-- | Second component of a promoted pair.
type family Snd p where
  Snd '(_, y) = y

-- | Apply a type constructor to every element of a type-level list.
type family FMap f xs where
  FMap f '[] = '[]
  FMap f (x ': xs) = f x ': FMap f xs
-- | Type-level lists of naturals whose elements can be read back as 'Int's.
class KnownNats (ns :: [Nat]) where
  natVals :: Proxy ns -> [Int]

-- | The empty list reflects to @[]@.
instance KnownNats '[] where
  natVals _ = []

-- | A non-empty list reflects to its head (via 'natVal') followed by its
-- reflected tail.
instance (KnownNat n, KnownNats ns) => KnownNats (n : ns) where
  natVals _ = headVal : tailVals
    where
      headVal = fromInteger (natVal (Proxy @n))
      tailVals = natVals (Proxy @ns)
-- | Type-level lists of lists of naturals, reflected as nested 'Int' lists.
class KnownNatss (ns :: [[Nat]]) where
  natValss :: Proxy ns -> [[Int]]

-- | The empty outer list reflects to @[]@.
instance KnownNatss '[] where
  natValss _ = []

-- | A non-empty outer list reflects its head with 'natVals' and recurses on
-- the tail.
instance (KnownNats n, KnownNatss ns) => KnownNatss (n : ns) where
  natValss _ = headVals : tailVals
    where
      headVals = natVals (Proxy @n)
      tailVals = natValss (Proxy @ns)