module MathFlow.Core where
import GHC.TypeLits
import Data.Singletons
import Data.Singletons.TH
import Data.Promotion.Prelude
-- | @IsSubSamp f m n@ checks, axis by axis, that output shape @n@ is input
-- shape @m@ sub-sampled by the per-axis factors @f@: for every factor other
-- than 1, @outDim * factor == inDim@. All three lists must have the same
-- length, otherwise the result is 'False.
type family IsSubSamp (f :: [Nat]) (m :: [Nat]) (n :: [Nat]) :: Bool where
  -- A factor of 1 imposes no relation between the input and output size of
  -- that axis; only the remaining axes are checked.
  IsSubSamp (1 ': factors) (_ ': inRest) (_ ': outRest) =
    IsSubSamp factors inRest outRest
  IsSubSamp (factor ': factors) (inDim ': inRest) (outDim ': outRest) =
    ((outDim * factor) :== inDim) :&& (IsSubSamp factors inRest outRest)
  IsSubSamp '[] '[] '[] = 'True
  -- Mismatched ranks.
  IsSubSamp _ _ _ = 'False
-- | @IsMatMul m o n@ checks the shapes of a (possibly batched) matrix product
-- with left operand @m@, right operand @o@ and result @n@:
--
-- * the last axis of the result equals the last axis of the right operand;
-- * the last axis of the left operand equals the second-to-last axis of the
--   right operand (the contracted dimension);
-- * all axes of the result except the last equal those of the left operand;
-- * all axes except the last two agree between the result and the right
--   operand (the leading axes).
--
-- NOTE(review): for a right operand of rank < 2 the @Head@ term is stuck
-- rather than reducing to 'False.
type family IsMatMul (m :: [Nat]) (o :: [Nat]) (n :: [Nat]) :: Bool where
  IsMatMul lhs rhs out =
    (Last out :== Last rhs)
      :&& (Last lhs :== Head (Tail (Reverse rhs)))
      :&& (Tail (Reverse out) :== Tail (Reverse lhs))
      :&& (Tail (Tail (Reverse out)) :== Tail (Tail (Reverse rhs)))
-- | @IsConcat m o n@ checks that result shape @n@ can be produced by
-- concatenating shapes @m@ and @o@. On every axis the three sizes must
-- either all be equal, or the result size must be the sum of the two operand
-- sizes. The ranks must match, otherwise the result is 'False.
type family IsConcat (m :: [Nat]) (o :: [Nat]) (n :: [Nat]) :: Bool where
  IsConcat (l ': ls) (r ': rs) (out ': outs) =
    (((l :== r) :&& (l :== out)) :|| ((l + r) :== out))
      :&& IsConcat ls rs outs
  IsConcat '[] '[] '[] = 'True
  -- Mismatched ranks.
  IsConcat _ _ _ = 'False
-- | @IsSameProduct m n@ is the reshape check between shapes @m@ and @n@.
--
-- When both shapes are non-empty, the leading dimensions must be equal AND
-- the products of the remaining dimensions must match. That is stricter than
-- equal total element counts; presumably this is meant to keep the leading
-- (batch?) axis fixed — TODO confirm. When either shape is empty, only the
-- products are compared.
type family IsSameProduct (m :: [Nat]) (n :: [Nat]) :: Bool where
  IsSameProduct (src ': srcRest) (dst ': dstRest) =
    (src :== dst) :&& (Product srcRest :== Product dstRest)
  IsSameProduct srcDims dstDims = Product srcDims :== Product dstDims
-- | Expression tree for a tensor computation.
--
-- @n@ is the shape as a type-level list of dimension sizes ('dim' reads it
-- back at the term level). @t@ is the type of scalar constants ('TScalar').
-- @a@ is the payload type held by 'Tensor' leaves. Shape-changing nodes carry
-- type-level constraints relating their operand shapes to @n@.
data Tensor (n::[Nat]) t a =
    -- Scalar constant; the 'Num' instance's 'fromInteger' builds these.
    (Num t) => TScalar t
    -- Leaf wrapping a raw payload value (see 'toValue').
  | Tensor a
    -- Binary nodes on two operands of the same shape @n@; the 'Num'
    -- instance maps (+) and (*) to TAdd and TMul.
  | TAdd (Tensor n t a) (Tensor n t a)
  | TSub (Tensor n t a) (Tensor n t a)
  | TMul (Tensor n t a) (Tensor n t a)
    -- Unary nodes for 'abs' and 'signum'.
  | TAbs (Tensor n t a)
  | TSign (Tensor n t a)
    -- Operand has shape @Tail n@, so the result has one extra leading axis.
    -- Presumably this replicates the operand along that axis — confirm.
  | TRep (Tensor (Tail n) t a)
    -- Operand has shape @Reverse n@: the result's axes are the operand's axes
    -- in reverse order.
  | TTr (Tensor (Reverse n) t a)
    -- Matrix product of shapes @m@ and @o@ into @n@; see 'IsMatMul'.
  | forall o m. (SingI o,SingI m,SingI n,IsMatMul m o n ~ 'True) => TMatMul (Tensor m t a) (Tensor o t a)
    -- Concatenation of shapes @m@ and @o@ into @n@; see 'IsConcat'.
  | forall o m. (SingI o,SingI m,SingI n,IsConcat m o n ~ 'True) => TConcat (Tensor m t a) (Tensor o t a)
    -- Reshape from @m@ to @n@; see 'IsSameProduct'.
  | forall m. (SingI m,IsSameProduct m n ~ 'True) => TReshape (Tensor m t a)
    -- 2-D convolution of input shape @m@ with filter shape @o@. The result
    -- keeps every axis of @m@ except the last, whose size is the last axis of
    -- @o@. The last axis of @m@ must equal the second-to-last axis of @o@.
  | forall o m.
    (SingI o,SingI m,
     Last n ~ Last o,
     Last m ~ Head (Tail (Reverse o)),
     (Tail (Reverse n)) ~ (Tail (Reverse m))
    ) =>
    TConv2d (Tensor m t a) (Tensor o t a)
    -- Max pooling from @m@ to @n@. The singleton @f@ holds one factor per
    -- axis; the shapes are checked by 'IsSubSamp'.
  | forall f m. (SingI f, SingI m,IsSubSamp f m n ~ 'True) => TMaxPool (Sing f) (Tensor m t a)
    -- Shape-preserving unary nodes.
  | TSoftMax (Tensor n t a)
  | TReLu (Tensor n t a)
  | TNorm (Tensor n t a)
    -- Sub-sampling from @m@ to @n@ by the per-axis factors @f@; checked by
    -- 'IsSubSamp'.
  | forall f m. (SingI f,SingI m,IsSubSamp f m n ~ 'True) => TSubSamp (Sing f) (Tensor m t a)
    -- Pairs a tensor with another of any shape and scalar type; the node keeps
    -- the left operand's type. Built by '<+>'.
  | forall m t2. TApp (Tensor n t a) (Tensor m t2 a)
    -- A function, referred to by name, applied to a tensor. How the name is
    -- interpreted is not visible in this module.
  | TFunc String (Tensor n t a)
    -- A symbol referred to only by name.
  | TSym String
    -- Argument nodes; presumably the first String is the argument name and
    -- the second field its value (tensor, String, Integer, Float, Double or
    -- shape singleton) — confirm against the interpreter.
  | TArgT String (Tensor n t a)
  | TArgS String String
  | TArgI String Integer
  | TArgF String Float
  | TArgD String Double
  | forall f. (SingI f) => TArgSing String (Sing (f::[Nat]))
    -- Attaches a String label to a tensor; built by '<--'.
  | TLabel String (Tensor n t a)
infixr 4 <+>

-- | Combine two tensor expressions into a 'TApp' node. The result has the
-- shape and scalar type of the left operand; the right operand may have any
-- shape and scalar type (it must share the payload type @a@). Being
-- right-associative, @x <+> y <+> z@ is @x <+> (y <+> z)@.
(<+>) :: forall n t a m t2. (Tensor n t a) -> (Tensor m t2 a) -> (Tensor n t a)
primary <+> attached = TApp primary attached
-- | Arithmetic on tensors builds expression nodes instead of computing values.
-- Each operator wraps its operands in the matching constructor, and integer
-- literals become 'TScalar' constants. 'negate' uses the class default
-- (@0 - x@), which yields @TSub (TScalar 0) x@.
instance (Num t) => Num (Tensor n t a) where
  (+) = TAdd
  -- The operator name must be written as a section, @(-)@. The previous
  -- @()@ was a pattern binding, which is illegal in an instance declaration
  -- and never defined subtraction.
  (-) = TSub
  (*) = TMul
  abs = TAbs
  signum = TSign
  fromInteger = TScalar . fromInteger
-- | Things whose shape can be read back as a term-level list of sizes.
class Dimension shaped where
  -- | The dimension sizes, in the order they appear in the type-level shape.
  dim :: shaped -> [Integer]
-- | A tensor's dimensions come from the singleton of its type-level shape;
-- the tensor value itself is never inspected.
instance (SingI n) => Dimension (Tensor n t a) where
  dim tensor = dim (shapeSingOf tensor)
    where
      -- Fresh type variables, so this does not rely on ScopedTypeVariables.
      shapeSingOf :: (SingI s) => Tensor s u b -> Sing s
      shapeSingOf _ = sing
-- | A shape singleton is demoted directly to its list of sizes.
instance Dimension (Sing (n::[Nat])) where
  dim = fromSing
-- | Wrap a raw payload value as a leaf 'Tensor'. The singleton argument only
-- fixes the shape @n@ at the type level; its value is ignored.
toValue :: forall n t a. Sing (n::[Nat]) -> a -> Tensor n t a
toValue _shape = Tensor
-- | Matrix product node. The result shape @n@ is usually fixed by the
-- caller's annotation and is checked against the operand shapes by
-- 'IsMatMul'.
(%*) :: forall o m n t a. (SingI o,SingI m,SingI n,IsMatMul m o n ~ 'True)
     => Tensor m t a -> Tensor o t a -> Tensor n t a
lhs %* rhs = TMatMul lhs rhs
-- | Attach a name to a tensor expression, producing a 'TLabel' node. The
-- shape and types are unchanged.
(<--) :: SingI n => String -> Tensor n t a -> Tensor n t a
label <-- tensor = TLabel label tensor
-- | Interpreters for 'Tensor' expression trees whose payload type is @a@.
class FromTensor a where
  -- | Collapse an expression tree into a single payload value.
  fromTensor :: Tensor n t a -> a
  -- | Render an expression tree as a 'String'. Presumably this is generated
  -- source code — confirm against the instances.
  toString :: Tensor n t a -> String
  -- | Execute an expression tree in 'IO'.
  -- NOTE(review): the meaning of the (Int, String, String) result is not
  -- visible here. It looks like (exit code, stdout, stderr) — confirm.
  run :: Tensor n t a -> IO (Int,String,String)