{-# language GADTs #-}
{-|
A small multi-stage expression evaluator: the user-facing syntax 'E1' is
lowered into the internal representation 'E2', which is then evaluated to a
concrete result.

Operations of interest:

* inner product of two vectors

* matrix-vector action

* matrix-matrix product

Of these, only the inner product ('Dot') is expressible with the constructors
defined in this module.
-}
module Data.Tensor.Compiler.MultiStage where

-- | "User-facing" syntax.
data E1 a
  = K1 a
    -- ^ A constant operand.
  | Dot a a -- (E1 a) (E1 a)
    -- ^ Inner product of two operands.
  deriving (Eq, Show)

-- | Evaluate (lower) the user-facing syntax into the internal one.
--
-- A constant maps to a constant; @'Dot' a b@ becomes an elementwise 'Mul' of
-- the two operands followed by an 'Add' reduction.
evalE1 :: E1 a -> E2 a
evalE1 (K1 x) = K2 x
evalE1 (Dot a b) = Fold Add (ZipWith Mul (K2 a) (K2 b))

-- k1 :: a -> E1 a
-- k1 = K1

-- dot :: a -> a -> E1 a
-- dot = Dot

-- | Internal representation.
data E2 a
  = K2 a
    -- ^ A constant operand.
  | ZipWith BinOp (E2 a) (E2 a)
    -- ^ Combine two sub-expressions elementwise with a binary operator.
  | Fold BinOp (E2 a)
    -- ^ Reduce a sub-expression with a binary operator.
  deriving (Eq, Show)

-- | Evaluate the internal representation into a concrete result.
--
-- The first argument is the seed handed to 'foldr' for every 'Fold' node in
-- the expression; the same seed is used regardless of the operator being
-- folded, so the caller must pick one that suits the expression (e.g. @0@
-- for an 'Add' reduction).
--
-- * 'K2' yields the stored list unchanged.
--
-- * 'ZipWith' applies the operator pairwise via 'zipWith', so the result is
--   as long as the shorter of the two evaluated operands.
--
-- * 'Fold' reduces with 'foldr' (right-nested, which matters for 'Sub' and
--   'Div') and wraps the scalar result in a singleton list.
--
-- NB: the internal and returned type, @[b]@, is constrained by the evaluation
-- functions 'zipWith' / 'foldr'. In general this would be an opaque tensor
-- data structure.
evalE2 :: Fractional b => b -> E2 [b] -> [b]
evalE2 _ (K2 xs) = xs
evalE2 z (ZipWith op lhs rhs) =
  zipWith (evalBinOp op) (evalE2 z lhs) (evalE2 z rhs)
evalE2 z (Fold op body) = [foldr (evalBinOp op) z (evalE2 z body)]

-- | 'E2' combinator: embed a constant operand.
k2 :: a -> E2 a
k2 = K2

-- | 'E2' combinator: elementwise 'Add' of two constant operands.
addE2 :: a -> a -> E2 a
addE2 a b = ZipWith Add (k2 a) (k2 b)

-- | 'E2' combinator: elementwise 'Mul' of two constant operands.
mulE2 :: a -> a -> E2 a
mulE2 a b = ZipWith Mul (k2 a) (k2 b)

-- | Chained 'E1' -> 'E2' evaluation: lower with 'evalE1', then run 'evalE2'
-- with the given fold seed.
evalE :: Fractional b => b -> E1 [b] -> [b]
evalE z = evalE2 z . evalE1

-- data UnOp = Sqrt deriving (Eq, Show)

-- | Binary arithmetic operators usable in 'ZipWith' and 'Fold' nodes.
data BinOp = Add | Sub | Mul | Div
  deriving (Eq, Show)

-- | Interpret a 'BinOp' as the corresponding arithmetic function.
--
-- The 'Fractional' constraint is required by 'Div', which maps to '/'.
evalBinOp :: Fractional a => BinOp -> a -> a -> a
evalBinOp Add = (+)
evalBinOp Sub = (-)
evalBinOp Mul = (*)
evalBinOp Div = (/)

-- v1, v2 :: E1 [Int]
-- v1 = k [1..5]
-- v2 = k [3..7]