module Quantum.Synthesis.Matrix where
import Quantum.Synthesis.Ring
-- | Type-level zero. It has no constructors and is used only as a type index.
data Zero
-- | Type-level successor of the type-level natural number @a@.
data Succ a
-- | The type-level natural number 1.
type One = Succ Zero
-- | The type-level natural number 2.
type Two = Succ One
-- | The type-level natural number 3.
type Three = Succ Two
-- | The type-level natural number 4.
type Four = Succ Three
-- | The type-level natural number 5.
type Five = Succ Four
-- | The type-level natural number 6.
type Six = Succ Five
-- | The type-level natural number 7.
type Seven = Succ Six
-- | The type-level natural number 8.
type Eight = Succ Seven
-- | The type-level natural number 9.
type Nine = Succ Eight
-- | The type-level natural number 10.
type Ten = Succ Nine
-- | The type-level natural number 10 + @a@ (ten applications of 'Succ').
type Ten_and a = Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ (Succ a)))))))))
-- | Value-level numerals indexed by the type-level natural number they
-- denote: a value of type @NNat n@ has exactly as many 'Succ' layers
-- as the type @n@.
data NNat n where
  -- | The numeral for 'Zero'.
  Zero :: NNat Zero
  -- | The numeral following @n@; carries a 'Nat' dictionary for @n@.
  Succ :: Nat n => NNat n -> NNat (Succ n)
-- | Convert a numeral to the 'Integer' it denotes, by counting its
-- 'Succ' constructors.
fromNNat :: NNat n -> Integer
fromNNat = go 0
  where
    -- The helper is polymorphic in the index, since each step peels
    -- one 'Succ' off the type.
    go :: Integer -> NNat k -> Integer
    go acc Zero = acc
    go acc (Succ k) =
      let acc' = acc + 1
       in acc' `seq` go acc' k
-- | Numerals are shown as the decimal integer they denote.
instance Show (NNat n) where
  show numeral = show (fromNNat numeral)
-- | The class of type-level natural numbers built from 'Zero' and 'Succ'.
class Nat n where
  -- | The value-level numeral corresponding to the type @n@.
  nnat :: NNat n

  -- | The 'Integer' denoted by @n@. The argument only selects the
  -- type; the instances below never evaluate it.
  nat :: n -> Integer

instance Nat Zero where
  nnat = Zero
  nat _ = 0

instance Nat a => Nat (Succ a) where
  nnat = Succ nnat
  nat s = 1 + nat (predecessorOf s)
    where
      -- Type-only witness for the predecessor; never evaluated.
      predecessorOf :: Succ b -> b
      predecessorOf _ = undefined
-- | Type-level addition, defined by recursion on the first argument:
-- @0 + m = m@ and @(n + 1) + m = (n + m) + 1@.
type family Plus n m
type instance Plus Zero m = m
type instance Plus (Succ n) m = Succ (Plus n m)
-- | Type-level multiplication, defined by recursion on the first
-- argument: @0 * m = 0@ and @(n + 1) * m = m + (n * m)@.
type family Times n m
type instance Times Zero m = Zero
type instance Times (Succ n) m = Plus m (Times n m)
-- | @Vector n a@ is a list of elements of type @a@ whose length is the
-- type-level natural number @n@. Both fields of 'Cons' are strict.
data Vector n a where
  -- | The empty vector.
  Nil :: Vector Zero a
  -- | Prepend one element to a vector, increasing its length by one.
  Cons :: !a -> !(Vector n a) -> Vector (Succ n) a

infixr 5 `Cons`
-- | Two vectors of the same length are equal when their elements are
-- pairwise equal; compared left to right, stopping at the first
-- difference.
instance (Eq a) => Eq (Vector n a) where
  v == w = list_of_vector v == list_of_vector w
-- | A vector is shown as the word @vector@ followed by the list of its
-- elements, parenthesised when it appears as a function argument.
instance (Show a) => Show (Vector n a) where
  showsPrec prec v =
    showParen (prec >= 11) (showString "vector " . showString (show (list_of_vector v)))
-- | Element-wise conversion: 'maybe_dyadic' is applied to every entry
-- and the per-entry results are combined with 'vector_sequence'.
instance (ToDyadic a b) => ToDyadic (Vector n a) (Vector n b) where
  maybe_dyadic = vector_sequence . vector_map maybe_dyadic
-- | 'from_whole' and 'to_whole' act on a vector entry by entry.
instance (WholePart a b) => WholePart (Vector n a) (Vector n b) where
  from_whole v = vector_map from_whole v
  to_whole v = vector_map to_whole v
-- | The denominator exponent of a vector is that of the list of its
-- entries; 'denomexp_factor' is applied to each entry with the same @k@.
instance (DenomExp a) => DenomExp (Vector n a) where
  denomexp = denomexp . list_of_vector
  denomexp_factor as k = vector_map (`denomexp_factor` k) as
-- | Build the one-element vector containing just the given value.
vector_singleton :: a -> Vector One a
vector_singleton x = Cons x Nil
-- | The length of a vector, read off from its type index via 'nat'.
-- The vector itself is not inspected.
vector_length :: (Nat n) => Vector n a -> Integer
vector_length v = nat (lengthWitness v)
  where
    -- Type-only witness for the length index; never evaluated by 'nat'.
    lengthWitness :: Vector k b -> k
    lengthWitness _ = undefined
-- | Convert a vector to an ordinary list, keeping the element order.
list_of_vector :: Vector n a -> [a]
list_of_vector v = case v of
  Nil -> []
  Cons x rest -> x : list_of_vector rest
-- | Combine two vectors of equal length element by element using the
-- given binary function.
vector_zipwith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
vector_zipwith _ Nil Nil = Nil
vector_zipwith f (Cons x xs) (Cons y ys) = f x y `Cons` vector_zipwith f xs ys
-- | Apply a function to every element of a vector.
vector_map :: (a -> b) -> Vector n a -> Vector n b
vector_map f v = case v of
  Nil -> Nil
  Cons x rest -> f x `Cons` vector_map f rest
-- | The vector @0, 1, ..., n-1@, whose length @n@ is taken from the type.
vector_enum :: (Num a, Nat n) => Vector n a
vector_enum = countFrom 0 nnat
  where
    -- Emit consecutive values starting at @start@, one per 'Succ'
    -- layer of the numeral.
    countFrom :: (Num b) => b -> NNat k -> Vector k b
    countFrom _ Zero = Nil
    countFrom start (Succ k) = start `Cons` countFrom (start + 1) k
-- | Tabulate a function: the resulting vector holds
-- @f 0, f 1, ..., f (n-1)@, with @n@ taken from the type.
vector_of_function :: (Num a, Nat n) => (a -> b) -> Vector n b
vector_of_function f = vector_map f indices
  where
    indices = vector_enum
-- | Build a vector from a list. The expected length is fixed by the
-- result type; a list of any other length raises an error.
vector :: (Nat n) => [a] -> Vector n a
vector = build nnat
  where
    -- Consume one list element per 'Succ' layer of the numeral.
    build :: NNat k -> [b] -> Vector k b
    build size items = case (size, items) of
      (Zero, []) -> Nil
      (Succ rest, x : xs) -> x `Cons` build rest xs
      _ -> error "vector: length mismatch"
-- | Return the element at the given zero-based position. Indexing goes
-- through the list form, so an out-of-range position fails like '!!'.
vector_index :: (Integral i) => Vector n a -> i -> a
vector_index v i = entries !! position
  where
    entries = list_of_vector v
    position = fromIntegral i
-- | The vector whose every entry is the given value; its length is
-- taken from the type.
vector_repeat :: (Nat n) => a -> Vector n a
vector_repeat x = vector_of_function (\_ -> x)
-- | Turn a vector of @n@ vectors of length @m@ into a vector of @m@
-- vectors of length @n@. The 'Nat' constraint supplies the length of
-- the result when the outer vector is empty.
vector_transpose :: (Nat m) => Vector n (Vector m a) -> Vector m (Vector n a)
vector_transpose outer = case outer of
  Nil -> vector_repeat Nil
  Cons first rest -> vector_zipwith Cons first (vector_transpose rest)
-- | Left fold over the elements of a vector (as 'foldl' on its list form).
vector_foldl :: (a -> b -> a) -> a -> Vector n b -> a
vector_foldl f x = foldl f x . list_of_vector
-- | Right fold over the elements of a vector (as 'foldr' on its list form).
vector_foldr :: (a -> b -> b) -> b -> Vector n a -> b
vector_foldr f x = foldr f x . list_of_vector
-- | Drop the first element of a non-empty vector. Non-emptiness is
-- guaranteed by the type.
vector_tail :: Vector (Succ n) a -> Vector n a
vector_tail v = case v of
  Cons _ rest -> rest
-- | Return the first element of a non-empty vector. Non-emptiness is
-- guaranteed by the type.
vector_head :: Vector (Succ n) a -> a
vector_head v = case v of
  Cons first _ -> first
-- | Concatenate two vectors; the length of the result is the
-- type-level sum of the two lengths.
vector_append :: Vector n a -> Vector m a -> Vector (n `Plus` m) a
vector_append front back = case front of
  Nil -> back
  Cons x rest -> x `Cons` vector_append rest back
-- | Run the monadic actions stored in a vector from first to last and
-- collect their results in a vector of the same length.
vector_sequence :: (Monad m) => Vector n (m a) -> m (Vector n a)
vector_sequence Nil = return Nil
vector_sequence (Cons action rest) =
  action >>= \x ->
    vector_sequence rest >>= \xs ->
      return (Cons x xs)
-- | An @m@-by-@n@ matrix (@m@ rows, @n@ columns) with entries of type
-- @a@, stored as a vector of @n@ columns, each holding @m@ entries.
data Matrix m n a
  = Matrix !(Vector n (Vector m a))
  deriving Eq
-- | A matrix is shown as the word @matrix@ followed by the list of its
-- rows, parenthesised when it appears as a function argument.
instance (Nat m, Show a) => Show (Matrix m n a) where
  showsPrec prec mat =
    showParen (prec >= 11) (showString "matrix " . showString (show (rows_of_matrix mat)))
-- | Matrices over 'DRootTwo' are shown using 'showsPrec_DenomExp'.
instance (Nat m) => Show (Matrix m n DRootTwo) where
  showsPrec prec mat = showsPrec_DenomExp prec mat

-- | Matrices over 'DRComplex' are shown using 'showsPrec_DenomExp'.
instance (Nat m) => Show (Matrix m n DRComplex) where
  showsPrec prec mat = showsPrec_DenomExp prec mat

-- | Matrices over 'DOmega' are shown using 'showsPrec_DenomExp'.
instance (Nat m) => Show (Matrix m n DOmega) where
  showsPrec prec mat = showsPrec_DenomExp prec mat
-- | Entry-wise conversion, delegated to the instance for the
-- underlying vector of columns.
instance (ToDyadic a b) => ToDyadic (Matrix m n a) (Matrix m n b) where
  maybe_dyadic (Matrix cols) =
    maybe_dyadic cols >>= \cols' -> return (Matrix cols')
-- | 'from_whole' and 'to_whole' are delegated to the instance for the
-- underlying vector of columns.
instance (WholePart a b) => WholePart (Matrix m n a) (Matrix m n b) where
  from_whole = Matrix . from_whole . unMatrix
  to_whole = Matrix . to_whole . unMatrix
-- | Both methods are delegated to the instance for the underlying
-- vector of columns.
instance (DenomExp a) => DenomExp (Matrix m n a) where
  denomexp = denomexp . unMatrix
  denomexp_factor mat k = Matrix (unMatrix mat `denomexp_factor` k)
-- | Unwrap a matrix, returning its vector of columns.
unMatrix :: Matrix m n a -> (Vector n (Vector m a))
unMatrix mat = case mat of
  Matrix cols -> cols
-- | The dimensions @(m, n)@ of a matrix, read off from its type
-- indices via 'nat'. The matrix itself is not inspected.
matrix_size :: (Nat m, Nat n) => Matrix m n a -> (Integer, Integer)
matrix_size op = (nat (firstIndex op), nat (secondIndex op))
  where
    -- Type-only witnesses for the two indices; never evaluated by 'nat'.
    firstIndex :: Matrix r c b -> r
    firstIndex _ = undefined
    secondIndex :: Matrix r c b -> c
    secondIndex _ = undefined
-- | Entry-wise sum of two matrices of the same dimensions.
(.+.) :: (Num a) => Matrix m n a -> Matrix m n a -> Matrix m n a
(.+.) (Matrix xs) (Matrix ys) = Matrix (vector_zipwith addColumns xs ys)
  where
    addColumns = vector_zipwith (+)
infixl 6 .+.
-- | Entry-wise difference of two matrices of the same dimensions:
-- each entry of the result is the left entry minus the right entry.
(.-.) :: (Num a) => Matrix m n a -> Matrix m n a -> Matrix m n a
Matrix a .-. Matrix b = Matrix c
  where
    -- Subtract column by column, and entry by entry within a column.
    -- The inner operator must be '(-)'; the unit value '()' that stood
    -- here is not a function and does not type-check.
    c = vector_zipwith (vector_zipwith (-)) a b
infixl 6 .-.
-- | Apply a function to every entry of a matrix.
matrix_map :: (a -> b) -> Matrix m n a -> Matrix m n b
matrix_map f = Matrix . vector_map (vector_map f) . unMatrix
-- | The matrix whose entry in row @r@ and column @c@ (both counted
-- from 0) is the pair @(r, c)@.
matrix_enum :: (Num a, Nat n, Nat m) => Matrix m n (a,a)
matrix_enum =
  Matrix (vector_of_function (\col -> vector_of_function (\row -> (row, col))))
-- | Tabulate a function of two indices: the entry in row @r@ and
-- column @c@ (both counted from 0) is @f r c@.
matrix_of_function :: (Num a, Nat n, Nat m) => (a -> a -> b) -> Matrix m n b
matrix_of_function f = matrix_map applyAt matrix_enum
  where
    applyAt = uncurry f
-- | Multiply every entry of a matrix by a scalar, with the scalar as
-- the left factor.
scalarmult :: (Num a) => a -> Matrix m n a -> Matrix m n a
scalarmult x = matrix_map (\entry -> x * entry)
infixl 7 `scalarmult`
-- | Divide every entry of a matrix by a scalar.
scalardiv :: (Fractional a) => Matrix m n a -> a -> Matrix m n a
scalardiv m x = matrix_map (\entry -> entry / x) m
infixl 7 `scalardiv`
-- | Matrix product. Each column of the result is the left matrix
-- applied to the corresponding column of the right matrix.
(.*.) :: (Num a, Nat m) => Matrix m n a -> Matrix n p a -> Matrix m p a
(.*.) (Matrix lhs) (Matrix rhs) = Matrix (vector_map (applyTo lhs) rhs)
  where
    -- Matrix-times-vector: the linear combination of the columns with
    -- the vector's entries as coefficients. A single remaining column
    -- is returned scaled as-is, without adding the zero vector of the
    -- empty case.
    applyTo :: (Num a, Nat m) => Vector n (Vector m a) -> Vector n a -> Vector m a
    applyTo Nil Nil = vector_repeat 0
    applyTo (Cons col Nil) (Cons k Nil) = scale k col
    applyTo (Cons col cols) (Cons k ks) = addVec (scale k col) (applyTo cols ks)

    -- Scalar-times-vector, with the scalar as the left factor.
    scale :: (Num b) => b -> Vector n b -> Vector n b
    scale k col = vector_map (k *) col

    -- Entry-wise sum of two vectors.
    addVec :: (Num c) => Vector n c -> Vector n c -> Vector n c
    addVec v w = vector_zipwith (+) v w
infixl 7 .*.
-- | The matrix all of whose entries are 0; its dimensions are taken
-- from the type.
null_matrix :: (Num a, Nat n, Nat m) => Matrix m n a
null_matrix = Matrix (vector_repeat zeroColumn)
  where
    zeroColumn = vector_repeat 0
-- | Transpose a matrix, exchanging rows and columns.
matrix_transpose :: (Nat m) => Matrix m n a -> Matrix n m a
matrix_transpose = Matrix . vector_transpose . unMatrix
-- | The adjoint of a matrix: 'adj' is applied to every entry and the
-- result is transposed.
adjoint :: (Nat m, Adjoint a) => Matrix m n a -> Matrix n m a
adjoint (Matrix cols) =
  Matrix (vector_transpose (vector_map (vector_map adj) cols))
-- | @matrix_index m i j@ returns the entry in row @i@ and column @j@,
-- both counted from 0.
matrix_index :: (Integral i) => Matrix m n a -> i -> i -> a
matrix_index (Matrix cols) i j = vector_index (vector_index cols j) i
-- | All entries of a matrix as a flat list, column after column.
matrix_entries :: Matrix m n a -> [a]
matrix_entries (Matrix cols) = concatMap list_of_vector (list_of_vector cols)
-- | Run the monadic actions stored in a matrix, column after column,
-- and collect their results in a matrix of the same dimensions.
matrix_sequence :: (Monad m) => Matrix n p (m a) -> m (Matrix n p a)
matrix_sequence (Matrix cols) =
  vector_sequence (vector_map vector_sequence cols) >>= \cols' ->
    return (Matrix cols')
-- | The trace of a square matrix: the sum of its diagonal entries.
tr :: (Ring a) => Matrix n n a -> a
tr (Matrix cols) = diagonalSum cols
  where
    -- Take the top entry of the first column, then drop the first
    -- column and the first row and repeat on what is left.
    diagonalSum :: (Num b) => Vector k (Vector k b) -> b
    diagonalSum Nil = 0
    diagonalSum (Cons (Cons corner _) others) =
      corner + diagonalSum (vector_map vector_tail others)
-- | The trace of the product of a matrix with its adjoint.
hs_sqnorm :: (Ring a, Adjoint a, Nat n) => Matrix n m a -> a
hs_sqnorm m = tr gram
  where
    gram = m .*. adjoint m
-- | Square matrices form a ring under entry-wise addition and matrix
-- multiplication. An integer literal @k@ denotes @k@ times the
-- identity matrix.
--
-- 'abs' and 'signum' have no natural meaning for matrices: 'abs' is
-- the identity function and 'signum' always returns the identity
-- matrix.
instance (Num a, Nat n) => Num (Matrix n n a) where
  (+) = (.+.)
  (*) = (.*.)
  -- Negation scales by minus one; scaling by plus one would leave the
  -- matrix unchanged.
  negate = scalarmult (-1)
  (-) = (.-.)
  -- Diagonal entries get the integer, all other entries get zero.
  fromInteger x = matrix_of_function (\i j -> if i == j then fromInteger x else 0)
  abs a = a
  signum _ = 1
-- | The adjoint of a square matrix: 'adj' on every entry, followed by
-- transposition (the same computation as 'adjoint').
instance (Nat n, Adjoint a) => Adjoint (Matrix n n a) where
  adj = adjoint
-- | 'adj2' acts on a square matrix entry by entry, without transposing.
instance (Nat n, Adjoint2 a) => Adjoint2 (Matrix n n a) where
  adj2 = matrix_map adj2
-- | 'half' for square matrices is the scalar 'half' times the identity matrix.
instance (HalfRing a, Nat n) => HalfRing (Matrix n n a) where
  half = half `scalarmult` 1

-- | 'roothalf' for square matrices is the scalar 'roothalf' times the
-- identity matrix.
instance (RootHalfRing a, Nat n) => RootHalfRing (Matrix n n a) where
  roothalf = roothalf `scalarmult` 1

-- | 'roottwo' for square matrices is the scalar 'roottwo' times the
-- identity matrix.
instance (RootTwoRing a, Nat n) => RootTwoRing (Matrix n n a) where
  roottwo = roottwo `scalarmult` 1

-- | 'i' for square matrices is the scalar 'i' times the identity matrix.
instance (ComplexRing a, Nat n) => ComplexRing (Matrix n n a) where
  i = i `scalarmult` 1
-- | Stack two matrices with the same number of columns on top of each
-- other, the first one on top.
stack_vertical :: Matrix m n a -> Matrix p n a -> Matrix (m `Plus` p) n a
stack_vertical (Matrix top) (Matrix bottom) =
  Matrix (vector_zipwith vector_append top bottom)
-- | Place two matrices with the same number of rows side by side, the
-- first one on the left.
stack_horizontal :: Matrix m n a -> Matrix m p a -> Matrix m (n `Plus` p) a
stack_horizontal (Matrix left) (Matrix right) =
  Matrix (left `vector_append` right)
-- | Scale the matrix by each entry of the vector in turn and stack the
-- scaled copies vertically, first entry on top.
tensor_vertical :: (Num a, Nat n) => Vector p a -> Matrix m n a -> Matrix (p `Times` m) n a
tensor_vertical v m = concat_vertical (vector_map (\k -> scalarmult k m) v)
-- | Stack a vector of equally-sized matrices vertically, first one on
-- top. An empty vector yields a matrix with zero rows.
concat_vertical :: (Num a, Nat n) => Vector p (Matrix m n a) -> Matrix (p `Times` m) n a
concat_vertical blocks = case blocks of
  Nil -> null_matrix
  Cons first rest -> stack_vertical first (concat_vertical rest)
-- | Scale the matrix by each entry of the vector in turn and place the
-- scaled copies side by side, first entry leftmost.
tensor_horizontal :: (Num a, Nat m) => Vector p a -> Matrix m n a -> Matrix m (p `Times` n) a
tensor_horizontal v m = concat_horizontal (vector_map (\k -> scalarmult k m) v)
-- | Place a vector of equally-sized matrices side by side, first one
-- leftmost. An empty vector yields a matrix with zero columns.
concat_horizontal :: (Num a, Nat m) => Vector p (Matrix m n a) -> Matrix m (p `Times` n) a
concat_horizontal blocks = case blocks of
  Nil -> null_matrix
  Cons first rest -> stack_horizontal first (concat_horizontal rest)
-- | Tensor (Kronecker) product of two matrices: every entry @x@ of the
-- first matrix is replaced by the block @x * b@, and the blocks are
-- joined into a single matrix.
tensor :: (Num a, Nat n, Nat (p `Times` m)) => Matrix p q a -> Matrix m n a -> Matrix (p `Times` m) (q `Times` n) a
tensor a b = concat_horizontal blockColumns
  where
    -- One scaled copy of @b@ per entry of @a@, arranged like @a@.
    blocks = unMatrix (matrix_map (\k -> scalarmult k b) a)
    -- Join the blocks of each column of @a@ vertically.
    blockColumns = vector_map concat_vertical blocks
-- | Block-diagonal sum of two matrices: @a@ in the upper-left corner,
-- @b@ in the lower-right corner, and zeros elsewhere.
oplus :: (Num a, Nat m, Nat q, Nat n, Nat p) => Matrix p q a -> Matrix m n a -> Matrix (p `Plus` m) (q `Plus` n) a
oplus (a :: Matrix p q a) (b :: Matrix m n a) = stack_horizontal leftBlock rightBlock
  where
    -- @a@ above a zero block with as many rows as @b@.
    leftBlock = stack_vertical a (null_matrix :: Matrix m q a)
    -- A zero block with as many rows as @a@, above @b@.
    rightBlock = stack_vertical (null_matrix :: Matrix p n a) b
-- | The block-diagonal matrix with the identity in the upper-left
-- block and the given matrix in the lower-right block.
matrix_controlled :: (Eq a, Num a, Nat n) => Matrix n n a -> Matrix (n `Plus` n) (n `Plus` n) a
matrix_controlled (m :: Matrix n n a) = identity `oplus` m
  where
    identity = 1 :: Matrix n n a
-- | The type of 2-by-2 matrices with entries of type @a@.
type U2 a = Matrix Two Two a
-- | The type of 3-by-3 matrices with entries of type @a@.
type SO3 a = Matrix Three Three a
-- | Build a matrix from a list of columns, each a list of entries.
-- The dimensions are fixed by the result type; a column count or
-- column length that does not match raises the error from 'vector'.
matrix_of_columns :: (Nat n, Nat m) => [[a]] -> Matrix n m a
matrix_of_columns columns = Matrix (vector (map vector columns))
-- | Build a matrix from a list of rows, each a list of entries. The
-- rows are first loaded as columns and the result is then transposed.
matrix_of_rows :: (Nat n, Nat m) => [[a]] -> Matrix n m a
matrix_of_rows rows = matrix_transpose (matrix_of_columns rows)
-- | Synonym for 'matrix_of_rows'.
matrix :: (Nat n, Nat m) => [[a]] -> Matrix n m a
matrix rows = matrix_of_rows rows
-- | The columns of a matrix, each as a list of entries.
columns_of_matrix :: Matrix n m a -> [[a]]
columns_of_matrix = map list_of_vector . list_of_vector . unMatrix
-- | The rows of a matrix, each as a list of entries; obtained as the
-- columns of the transpose.
rows_of_matrix :: (Nat n) => Matrix n m a -> [[a]]
rows_of_matrix mat = columns_of_matrix (matrix_transpose mat)
-- | Build a 2-by-2 matrix from its two rows.
matrix2x2 :: (a, a) -> (a, a) -> Matrix Two Two a
matrix2x2 (a, b) (c, d) = matrix_of_columns [firstColumn, secondColumn]
  where
    firstColumn = [a, c]
    secondColumn = [b, d]
-- | Take a 2-by-2 matrix apart into its two rows; the inverse of
-- 'matrix2x2'.
from_matrix2x2 :: Matrix Two Two a -> ((a, a), (a, a))
from_matrix2x2 (Matrix (Cons (Cons a (Cons c Nil)) (Cons (Cons b (Cons d Nil)) Nil))) =
  ((a, b), (c, d))
-- | Build a 3-by-3 matrix from its three rows.
matrix3x3 :: (a, a, a) -> (a, a, a) -> (a, a, a) -> Matrix Three Three a
matrix3x3 (a0, a1, a2) (b0, b1, b2) (c0, c1, c2) =
  matrix_of_columns [column0, column1, column2]
  where
    column0 = [a0, b0, c0]
    column1 = [a1, b1, c1]
    column2 = [a2, b2, c2]
-- | Build a 4-by-4 matrix from its four rows.
matrix4x4 :: (a, a, a, a) -> (a, a, a, a) -> (a, a, a, a) -> (a, a, a, a) -> Matrix Four Four a
matrix4x4 (a0, a1, a2, a3) (b0, b1, b2, b3) (c0, c1, c2, c3) (d0, d1, d2, d3) =
  matrix_of_columns [column0, column1, column2, column3]
  where
    column0 = [a0, b0, c0, d0]
    column1 = [a1, b1, c1, d1]
    column2 = [a2, b2, c2, d2]
    column3 = [a3, b3, c3, d3]
-- | Build a 3-by-1 matrix (a single column) from three entries.
column3 :: (a, a, a) -> Matrix Three One a
column3 (a, b, c) = matrix_of_columns [onlyColumn]
  where
    onlyColumn = [a, b, c]
-- | Take a 3-by-1 matrix apart into its three entries; the inverse of
-- 'column3'.
from_column3 :: Matrix Three One a -> (a, a, a)
from_column3 (Matrix (Cons (Cons a (Cons b (Cons c Nil))) Nil)) = (a, b, c)
-- | Turn a vector into the matrix that has it as its only column.
column_matrix :: Vector n a -> Matrix n One a
column_matrix = Matrix . vector_singleton
-- | The controlled-not gate: the 4-by-4 permutation matrix that
-- exchanges the last two basis vectors.
cnot :: (Num a) => Matrix Four Four a
cnot = matrix4x4 row0 row1 row2 row3
  where
    row0 = (1, 0, 0, 0)
    row1 = (0, 1, 0, 0)
    row2 = (0, 0, 0, 1)
    row3 = (0, 0, 1, 0)
-- | The swap gate: the 4-by-4 permutation matrix that exchanges the
-- two middle basis vectors.
swap :: (Num a) => Matrix Four Four a
swap = matrix4x4 row0 row1 row2 row3
  where
    row0 = (1, 0, 0, 0)
    row1 = (0, 0, 1, 0)
    row2 = (0, 1, 0, 0)
    row3 = (0, 0, 0, 1)
-- | A /z/-rotation by the angle @theta@, following the usual
-- convention R_z(theta) = exp(-i theta Z / 2): the diagonal matrix
-- with entries exp(-i theta / 2) and exp(i theta / 2).
--
-- NOTE(review): the sign of the imaginary part is restored here on
-- the assumption that the standard convention is intended; without
-- the minus sign the function returns the rotation by @-theta@.
-- Confirm against callers.
zrot :: (Eq r, Floating r, Adjoint r) => r -> Matrix Two Two (Cplx r)
zrot theta = matrix2x2 (u, 0)
                       (0, adj u)
  where
    -- u = cos (theta/2) - i sin (theta/2) = exp (-i theta / 2)
    u = Cplx (cos (theta/2)) (-sin (theta/2))