module Data.Basis (HasBasis(..), linearCombo, recompose) where
import Control.Arrow (first)
import Data.Ratio
import Foreign.C.Types (CFloat, CDouble)
import Data.VectorSpace
import Data.VectorSpace.Generic
import qualified GHC.Generics as Gnrx
import GHC.Generics (Generic, (:*:)(..))
-- | Vector spaces equipped with a distinguished basis.
--
-- The associated type and every method have defaults that go through the
-- generic representation @'VRep' v@. A type with a 'Generic' instance whose
-- generic representation is itself a 'HasBasis' instance can therefore
-- omit the method definitions.
class VectorSpace v => HasBasis v where
  -- | Index type naming the basis elements of @v@. Defaults to the basis
  -- of the generic representation.
  type Basis v :: *
  type Basis v = Basis (VRep v)

  -- | The vector denoted by a basis index.
  basisValue :: Basis v -> v
  default basisValue
    :: (Generic v, HasBasis (VRep v), Basis (VRep v) ~ Basis v)
    => Basis v -> v
  basisValue basisIx = Gnrx.to (basisValue basisIx :: VRep v)

  -- | Coordinates of a vector as a list of (basis index, scalar) pairs.
  decompose :: v -> [(Basis v, Scalar v)]
  default decompose
    :: ( Generic v
       , HasBasis (VRep v)
       , Scalar (VRep v) ~ Scalar v
       , Basis (VRep v) ~ Basis v
       )
    => v -> [(Basis v, Scalar v)]
  decompose vec = decompose (Gnrx.from vec :: VRep v)

  -- | Coordinates of a vector as a function from basis index to scalar.
  decompose' :: v -> Basis v -> Scalar v
  default decompose'
    :: ( Generic v
       , HasBasis (VRep v)
       , Scalar (VRep v) ~ Scalar v
       , Basis (VRep v) ~ Basis v
       )
    => v -> Basis v -> Scalar v
  decompose' vec = decompose' (Gnrx.from vec :: VRep v)
-- | Rebuild a vector from (basis index, coefficient) pairs. Each index is
-- turned into its basis vector with 'basisValue', and the resulting
-- (vector, coefficient) list is handed to 'linearCombo'.
recompose :: HasBasis v => [(Basis v, Scalar v)] -> v
recompose terms =
  -- The irrefutable pattern takes each pair apart only on demand, matching
  -- the laziness of Control.Arrow.first on plain functions.
  linearCombo (map (\ ~(ix, coeff) -> (basisValue ix, coeff)) terms)
-- Instances for scalar types, generated with CPP. The two-argument macro
-- takes an instance context and a type. The one-argument macro uses the
-- empty context (). Each generated instance treats the scalar type as a
-- one-dimensional space: the basis index type is (), the only basis vector
-- is 1, and a scalar s decomposes to the single coordinate s.
#define ScalarTypeCon(con,t) \
instance con => HasBasis (t) where \
{ type Basis (t) = () \
; basisValue () = 1 \
; decompose s = [((),s)] \
; decompose' s = const s }
#define ScalarType(t) ScalarTypeCon((),t)
-- Floating-point scalars (Haskell and C types) need no instance context.
-- Rationals require Integral on the underlying integer type.
ScalarType(Float)
ScalarType(CFloat)
ScalarType(Double)
ScalarType(CDouble)
ScalarTypeCon(Integral a, Ratio a)
-- | Pairs use the disjoint union of the component bases. 'Left' indexes a
-- basis vector of the first component and 'Right' one of the second. Both
-- components must share the scalar type @s@.
instance (HasBasis u, HasBasis v, s ~ Scalar u, s ~ Scalar v)
      => HasBasis (u, v) where
  type Basis (u, v) = Either (Basis u) (Basis v)
  -- A basis vector of one component, paired with zero in the other.
  basisValue =
    either (\ix -> (basisValue ix, zeroV)) (\ix -> (zeroV, basisValue ix))
  decompose (x, y) = decomp2 Left x ++ decomp2 Right y
  decompose' (x, y) = either (decompose' x) (decompose' y)
-- | 'decompose' a value, then re-tag every basis index with @inject@. Used
-- to embed a component's coordinates into a sum-typed basis.
decomp2 :: HasBasis w => (Basis w -> b) -> w -> [(b, Scalar w)]
decomp2 inject w =
  -- Lazy pattern, so this is exactly as lazy as Control.Arrow.first.
  map (\ ~(ix, coeff) -> (inject ix, coeff)) (decompose w)
-- | Triples reuse the pair instance by viewing @(u, v, w)@ as
-- @(u, (v, w))@. Basis indices are therefore nested 'Either's. All three
-- components must share the scalar type @s@.
instance ( HasBasis u, HasBasis v, HasBasis w
         , s ~ Scalar u, s ~ Scalar v, s ~ Scalar w )
      => HasBasis (u, v, w) where
  type Basis (u, v, w) = Basis (u, (v, w))
  basisValue ix = unnest3 (basisValue ix)
  decompose triple = decompose (nest3 triple)
  decompose' triple = decompose' (nest3 triple)
-- | Flatten a right-nested pair into a triple. This is the inverse of
-- 'nest3'.
unnest3 :: (a, (b, c)) -> (a, b, c)
unnest3 nested = case nested of
  (x, (y, z)) -> (x, y, z)
-- | View a triple as a right-nested pair, so pair-based instances can be
-- reused.
nest3 :: (a, b, c) -> (a, (b, c))
nest3 triple = case triple of
  (x, y, z) -> (x, (y, z))
-- | A generic field of type @a@ has exactly the basis of @a@. 'Gnrx.K1' is
-- a newtype, so matching on it costs nothing and forces nothing.
instance HasBasis a => HasBasis (Gnrx.Rec0 a s) where
  type Basis (Gnrx.Rec0 a s) = Basis a
  basisValue ix = Gnrx.K1 (basisValue ix)
  decompose (Gnrx.K1 field) = decompose field
  decompose' (Gnrx.K1 field) = decompose' field
-- | Generic metadata wrappers are transparent: same basis as the wrapped
-- representation. 'Gnrx.M1' is a newtype, so the pattern match is free.
instance HasBasis (f p) => HasBasis (Gnrx.M1 i c f p) where
  type Basis (Gnrx.M1 i c f p) = Basis (f p)
  basisValue ix = Gnrx.M1 (basisValue ix)
  decompose (Gnrx.M1 inner) = decompose inner
  decompose' (Gnrx.M1 inner) = decompose' inner
-- | Generic products behave like pairs. 'Left' indexes the left factor and
-- 'Right' the right factor, and both factors must share a scalar type.
instance (HasBasis (f p), HasBasis (g p), Scalar (f p) ~ Scalar (g p))
      => HasBasis ((f :*: g) p) where
  type Basis ((f :*: g) p) = Either (Basis (f p)) (Basis (g p))
  -- A basis vector of one factor, with zero in the other factor.
  basisValue =
    either (\ix -> basisValue ix :*: zeroV) (\ix -> zeroV :*: basisValue ix)
  decompose (lhs :*: rhs) = decomp2 Left lhs ++ decomp2 Right rhs
  decompose' (lhs :*: rhs) = either (decompose' lhs) (decompose' rhs)