{-# LANGUAGE TypeFamilies          #-}

{-# LANGUAGE MultiParamTypeClasses #-}

{-# LANGUAGE TypeSynonymInstances  #-}           

{-# LANGUAGE DataKinds             #-}

{-# LANGUAGE KindSignatures        #-}

{-# LANGUAGE TypeOperators         #-}

{-# LANGUAGE FlexibleInstances     #-}

{-# LANGUAGE FlexibleContexts      #-}           

{-# LANGUAGE ScopedTypeVariables   #-}

{-# LANGUAGE DeriveGeneric         #-}

{-# LANGUAGE DeriveAnyClass        #-}

{-# LANGUAGE TemplateHaskell       #-}

{-# LANGUAGE StandaloneDeriving    #-}           

{-# LANGUAGE UndecidableInstances  #-}

{-# LANGUAGE TypeApplications      #-}

{-# LANGUAGE TypeInType            #-}

{-# LANGUAGE AllowAmbiguousTypes   #-}

{-# LANGUAGE ConstraintKinds       #-}



-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Vector.Static
-- Copyright   :  (C) 2017 Alexey Vagarenko
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Alexey Vagarenko (vagarenko@gmail.com)
-- Stability   :  experimental
-- Portability :  non-portable
--
----------------------------------------------------------------------------



module Data.Vector.Static (

    -- * Vector types

      Vector

    , VectorConstructor

    , IsVector

    -- ** Construction

    , vector

    -- ** Vector Operations

    , normalize

    , NormalizedVector

    , Normalize

    , unNormalizedVector

    , norm

    , vectorLenSquare

    , VectorLenSquare

    , vectorLen

    , VectorLen

    , dot

    , Dot

    , cross

    , toHomogenous

    , fromHomogenous

    -- ** Generating vector instances

    , genVectorInstance

) where

  

import Data.Containers       (MonoZip(..))

import Data.MonoTraversable  (omap, osum)

import Data.Tensor.Static    ( IsTensor(..), scale, Scale, TensorConstructor, withTensor

                             , MonoFunctorCtx, MonoFoldableCtx, MonoZipCtx)

import Data.Tensor.Static.TH (genTensorInstance)

import GHC.Generics          (Generic)

import GHC.TypeLits          (Nat)

import Language.Haskell.TH   (Q, Name, Dec)

import qualified Data.List.NonEmpty as N



---------------------------------------------------------------------------------------------------

-- | N-dimensional vector: a 'Tensor' whose dimension list holds the single
--   size @n@.
type Vector (n :: Nat) e = Tensor '[n] e



-- | Type of the data constructor of a vector, i.e. the 'TensorConstructor'
--   for the one-element dimension list @'[n]@.
type VectorConstructor (n :: Nat) e = TensorConstructor '[n] e



-- | Vector constraint: the 'IsTensor' constraint specialised to the
--   one-element dimension list @'[n]@.
type IsVector (n :: Nat) e = IsTensor '[n] e



-- | Normalized vector: a wrapper around 'Vector' marking the result of
--   'normalize'.
newtype NormalizedVector (n :: Nat) e = NormalizedVector
    { unNormalizedVector :: Vector n e
      -- ^ Unwrap a 'NormalizedVector'. Note: this does not give you the
      --   original vector back.
      --
      -- @unNormalizedVector . normalize /= id@
    }
    deriving Generic

    

-- 'Eq' and 'Show' are derived standalone because their contexts mention the
-- underlying 'Vector' type: each instance exists whenever the wrapped
-- 'Vector' has the corresponding instance.
deriving instance Eq   (Vector n e) => Eq   (NormalizedVector n e)
deriving instance Show (Vector n e) => Show (NormalizedVector n e)



---------------------------------------------------------------------------------------------------

-- | Alias for a concrete vector data constructor.
--
--   This is 'tensor' applied to the one-element dimension list @'[n]@ and
--   the element type @e@; see 'cross' for an example of choosing the size
--   with a type application.
vector
    :: forall (n :: Nat) e
     . IsVector n e
    => VectorConstructor n e
vector = tensor @'[n] @e
{-# INLINE vector #-}



-- | Get the square of the length of a vector: the sum of the squares of its
--   elements. Unlike 'vectorLen' no square root is taken, so only 'Num' is
--   required of the element type.
vectorLenSquare
    :: VectorLenSquare n e
    => Vector n e
    -> e
vectorLenSquare = osum . omap square
  where
    -- Kept as a local helper so the definition stays point-free.
    square x = x * x
{-# INLINE vectorLenSquare #-}



-- | Constraints for the 'vectorLenSquare' function: numeric elements plus the
--   contexts needed to map over and fold the vector.
type VectorLenSquare (n :: Nat) e =
    (Num e, IsVector n e, MonoFunctorCtx '[n] e, MonoFoldableCtx '[n] e)



-- | Get the length of a vector, computed as the square root of
--   'vectorLenSquare'.
vectorLen
    :: VectorLen n e
    => Vector n e
    -> e
vectorLen = sqrt . vectorLenSquare
{-# INLINE vectorLen #-}



-- | Constraints for the 'vectorLen' function: everything 'vectorLenSquare'
--   needs, plus 'Floating' for the square root.
type VectorLen (n :: Nat) e = (Floating e, VectorLenSquare n e)



-- | Normalize a vector: scale it by the reciprocal of its length and wrap the
--   result in 'NormalizedVector'.
--
--   The length is not checked, so a vector of length zero leads to a
--   division by zero.
normalize
    :: Normalize n e
    => Vector n e
    -> NormalizedVector n e
normalize v = NormalizedVector (scale v factor)
  where
    factor = 1 / vectorLen v
{-# INLINE normalize #-}



-- | Constraints for the 'normalize' function: everything 'vectorLen' needs,
--   plus the ability to 'scale' the vector.
type Normalize (n :: Nat) e = (VectorLen n e, Scale '[n] e)



-- | Normalize a vector but return it as a plain 'Vector' instead of wrapping
--   it in 'NormalizedVector'.
norm
    :: Normalize n e
    => Vector n e
    -> Vector n e
norm = unNormalizedVector . normalize
{-# INLINE norm #-}



-- | Dot product of two vectors: the sum of the element-wise products.
dot
    :: Dot n e
    => Vector n e
    -> Vector n e
    -> e
dot lhs rhs = osum (ozipWith (*) lhs rhs)
{-# INLINE dot #-}



-- | Constraints for the 'dot' function: numeric elements plus the contexts
--   needed to zip and fold the vectors.
type Dot (n :: Nat) e =
    ( Num e, IsVector n e
    , MonoFunctorCtx '[n] e, MonoFoldableCtx '[n] e, MonoZipCtx '[n] e
    )



---------------------------------------------------------------------------------------------------

-- | Cross product of two vectors. It is only defined for 3-dimensional
--   vectors.
cross
    :: (Num e, IsVector 3 e)
    => Vector 3 e
    -> Vector 3 e
    -> Vector 3 e
cross a b =
    withTensor a $ \a1 a2 a3 ->
        withTensor b $ \b1 b2 b3 ->
            -- Each result coordinate combines the other two coordinates of
            -- both operands.
            let c1 = a2 * b3 - a3 * b2
                c2 = a3 * b1 - a1 * b3
                c3 = a1 * b2 - a2 * b1
            in  vector @3 c1 c2 c3
{-# INLINE cross #-}



---------------------------------------------------------------------------------------------------

-- | Convert a 3-dimensional vector to a 4-dimensional vector by keeping the
--   three coordinates and setting the last element to @1@.
toHomogenous
    :: (Num e, IsVector 3 e, IsVector 4 e)
    => Vector 3 e
    -> Vector 4 e
toHomogenous v3 =
    withTensor v3 (\x y z -> vector @4 x y z 1)
{-# INLINE toHomogenous #-}



-- | Convert a 4-dimensional vector to a 3-dimensional vector by scaling the
--   first three coordinates by the reciprocal of the last one.
--   The last element must not be zero!
fromHomogenous
    :: (Fractional e, IsVector 3 e, IsVector 4 e)
    => Vector 4 e
    -> Vector 3 e
fromHomogenous v4 =
    withTensor v4 $ \x y z w ->
        let invW = 1 / w
        in  scale (vector @3 x y z) invW
{-# INLINE fromHomogenous #-}



---------------------------------------------------------------------------------------------------

-- | Generate an instance of a vector by calling 'genTensorInstance' with a
--   dimension list that contains only the given size.
genVectorInstance
    :: Int       -- ^ Size of the vector.
    -> Name      -- ^ Type of elements.
    -> Q [Dec]
genVectorInstance size elemTypeName =
    genTensorInstance dims elemTypeName
  where
    -- Single-element non-empty list, built directly with the constructor.
    dims = size N.:| []