{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE TypeInType #-}

{-# LANGUAGE OverloadedStrings #-}


module MathFlow.Core where

import GHC.TypeLits
import Data.Singletons
import Data.Singletons.TH
import Data.Promotion.Prelude

-- |IsSubSamp // Subsampling constraint
--
-- * (f :: [Nat]) // strides for subsampling
-- * (m :: [Nat]) // dimensions of original tensor 
-- * (n :: [Nat]) // dimensions of subsampled tensor 
-- * :: Bool
type family IsSubSamp (f :: [Nat]) (m :: [Nat]) (n :: [Nat]) :: Bool where
  IsSubSamp (1:fs) (m:ms) (n:ns) = IsSubSamp fs ms ns
  IsSubSamp (f:fs) (m:ms) (n:ns) = ((n * f) :== m) :&& (IsSubSamp fs ms ns)
  IsSubSamp '[] '[] '[] = 'True
  IsSubSamp _ _ _ = 'False

-- |IsMatMul // A constraint for matrix multiplication
--
-- * (m :: [Nat]) // dimensions of a[..., i, k] 
-- * (o :: [Nat]) // dimensions of b[..., k, j]
-- * (n :: [Nat]) // dimensions of output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j.
-- * :: Bool
type family IsMatMul (m :: [Nat]) (o :: [Nat]) (n :: [Nat]) :: Bool where
  IsMatMul m o n =
    Last n :== Last o :&&
    Last m :== Head (Tail (Reverse o)) :&&
    (Tail (Reverse n)) :== (Tail (Reverse m)) :&&
    (Tail (Tail (Reverse n))) :== (Tail (Tail (Reverse o)))

-- |IsConcat // A constraint for concatination of tensor 
--
-- * (m :: [Nat]) // dimensions of a[..., i, ...] 
-- * (o :: [Nat]) // dimensions of b[..., k, ...]
-- * (n :: [Nat]) // dimensions of output[..., i+k, ...] = concat (a,b) 
-- * :: Bool
type family IsConcat (m :: [Nat]) (o :: [Nat]) (n :: [Nat]) :: Bool where
  IsConcat (m:mx) (o:ox) (n:nx) = (m :== o :&& m:== n :|| m + o :== n) :&& IsConcat mx ox nx
  IsConcat '[] '[] '[] = 'True
  IsConcat _ _ _ = 'False

-- |IsSameProduct // A constraint for reshaping tensor
--
-- * (m :: [Nat]) // dimensions of original tensor
-- * (n :: [Nat]) // dimensions of reshaped tensor
-- * :: Bool
type family IsSameProduct (m :: [Nat]) (n :: [Nat]) :: Bool where
  IsSameProduct (m:mx) (n:nx) = m :== n :&& (Product mx :== Product nx)
  IsSameProduct mx nx = Product mx :== Product nx


-- |Dependently typed tensor model
--
-- This model includes basic arithmetic operators and tensorflow functions.
data Tensor (n::[Nat]) t a =
    (Num t) => TScalar t -- ^ Scalar value
  | Tensor a -- ^ Transform a value to dependently typed value
  | TAdd (Tensor n t a) (Tensor n t a) -- ^ + of Num
  | TSub (Tensor n t a) (Tensor n t a) -- ^ - of Num
  | TMul (Tensor n t a) (Tensor n t a) -- ^ * of Num
  | TAbs (Tensor n t a) -- ^ abs of Num
  | TSign (Tensor n t a) -- ^ signum of Num
  | TRep (Tensor (Tail n) t a) -- ^ vector wise operator
  | TTr (Tensor (Reverse n) t a) -- ^ tensor tansporse operator
  | forall o m. (SingI o,SingI m,SingI n,IsMatMul m o n ~ 'True) => TMatMul (Tensor m t a) (Tensor o t a) -- ^ matrix multiply
  | forall o m. (SingI o,SingI m,SingI n,IsConcat m o n ~ 'True) => TConcat (Tensor m t a) (Tensor o t a) -- ^ concat operator
  | forall m. (SingI m,IsSameProduct m n ~ 'True) => TReshape (Tensor m t a) -- ^ reshape function
  | forall o m.
    (SingI o,SingI m,
     Last n ~ Last o,
     Last m ~ Head (Tail (Reverse o)),
     (Tail (Reverse n)) ~ (Tail (Reverse m))
    ) =>
    TConv2d (Tensor m t a) (Tensor o t a) -- ^ conv2d function
  | forall f m. (SingI f, SingI m,IsSubSamp f m n ~ 'True) => TMaxPool (Sing f) (Tensor m t a) -- ^ max pool
  | TSoftMax (Tensor n t a)
  | TReLu (Tensor n t a)
  | TNorm (Tensor n t a)
  | forall f m. (SingI f,SingI m,IsSubSamp f m n ~ 'True) => TSubSamp (Sing f) (Tensor m t a) -- ^ subsampling function
  | forall m t2. TApp (Tensor n t a) (Tensor m t2 a)
  | TFunc String (Tensor n t a)
  | TSym String
  | TArgT String (Tensor n t a)
  | TArgS String String
  | TArgI String Integer
  | TArgF String Float
  | TArgD String Double
  | forall f. (SingI f) => TArgSing String (Sing (f::[Nat]))
  | TLabel String (Tensor n t a) -- ^ When generating code, this label is used.

(<+>) :: forall n t a m t2. (Tensor n t a) -> (Tensor m t2 a) -> (Tensor n t a)
(<+>) = TApp

infixr 4 <+>

instance (Num t) => Num (Tensor n t a) where
  (+) = TAdd
  (-) = TSub
  (*) = TMul
  abs = TAbs
  signum = TSign
  fromInteger = TScalar . fromInteger


-- | get dimension from tensor
-- 
-- >>> dim (Tensor 1 :: Tensor '[192,10] Float Int)
-- [192,10]
class Dimension a where
  dim :: a -> [Integer]

instance (SingI n) => Dimension (Tensor n t a) where
  dim t = dim $ ty t
    where
      ty :: (SingI n) => Tensor n t a -> Sing n
      ty _ = sing

instance Dimension (Sing (n::[Nat])) where
  dim t = fromSing t

toValue :: forall n t a. Sing (n::[Nat]) -> a -> Tensor n t a
toValue _ a = Tensor a

(%*) :: forall o m n t a. (SingI o,SingI m,SingI n,IsMatMul m o n ~ 'True)
     => Tensor m t a -> Tensor o t a -> Tensor n t a
(%*) a b = TMatMul a b

(<--) :: SingI n => String -> Tensor n t a  -> Tensor n t a 
(<--) = TLabel


class FromTensor a where
  fromTensor :: Tensor n t a -> a
  toString :: Tensor n t a -> String
  run :: Tensor n t a -> IO (Int,String,String)