{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Quipper.Internal.Transformer where
import Quipper.Internal.Circuit
import Quipper.Utils.Auxiliary
import Control.Monad
import Control.Monad.State
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.IntMap as IntMap
import Data.Typeable
-- | The value currently attached to a wire in a 'Bindings' map: either
-- the value of a qubit wire (type @a@) or the value of a bit wire
-- (type @b@).
data B_Endpoint a b
  = Endpoint_Qubit a
  | Endpoint_Bit b
  deriving (Eq, Ord, Typeable, Show)
-- | The current endpoint of every bound wire, keyed by 'Wire'.
type Bindings a b = Map Wire (B_Endpoint a b)

-- | All wires that currently have a binding, in ascending order.
wires_of_bindings :: Bindings a b -> [Wire]
wires_of_bindings bindings = Map.keys bindings

-- | Bindings in which no wire is bound.
bindings_empty :: Bindings a b
bindings_empty = Map.empty

-- | Bind a wire to an endpoint, replacing any previous binding of that wire.
bind :: Wire -> B_Endpoint a b -> Bindings a b -> Bindings a b
bind = Map.insert

-- | Bind a wire to a qubit value.
bind_qubit_wire :: Wire -> a -> Bindings a b -> Bindings a b
bind_qubit_wire r = bind r . Endpoint_Qubit

-- | Bind a wire to a bit value.
bind_bit_wire :: Wire -> b -> Bindings a b -> Bindings a b
bind_bit_wire r = bind r . Endpoint_Bit
-- | Look up the endpoint currently bound to a wire.
--
-- Calls 'error' when the wire has no binding; the message lists the
-- wires that are bound.
unbind :: Bindings a b -> Wire -> B_Endpoint a b
unbind bindings w =
  case Map.lookup w bindings of
    Just endpoint -> endpoint
    Nothing ->
      error ("unbind: wire (" ++ show w ++ ") not in bindings: "
             ++ show (wires_of_bindings bindings))

-- | Look up the value bound to a qubit wire.
--
-- Calls 'error' if the wire is unbound (see 'unbind') or is bound to a bit.
unbind_qubit_wire :: Bindings a b -> Wire -> a
unbind_qubit_wire bindings w =
  case unbind bindings w of
    Endpoint_Qubit x -> x
    -- The wire number is appended so that a kind mismatch can be traced
    -- back to the offending wire.
    Endpoint_Bit _ ->
      error ("Transformer error: expected a qubit, got a bit (wire " ++ show w ++ ")")

-- | Look up the value bound to a bit wire.
--
-- Calls 'error' if the wire is unbound (see 'unbind') or is bound to a qubit.
unbind_bit_wire :: Bindings a b -> Wire -> b
unbind_bit_wire bindings w =
  case unbind bindings w of
    Endpoint_Bit x -> x
    -- As above: report which wire had the wrong kind.
    Endpoint_Qubit _ ->
      error ("Transformer error: expected a bit, got a qubit (wire " ++ show w ++ ")")
-- | Remove a wire's binding (a no-op if the wire is not bound).
bind_delete :: Wire -> Bindings a b -> Bindings a b
bind_delete = Map.delete

-- | Bind each wire to the endpoint at the same position.
--
-- Wires and endpoints are paired with 'zip', so any surplus elements of
-- the longer list are ignored. If a wire occurs more than once, the
-- endpoint paired with its first occurrence ends up bound.
bind_list :: [Wire] -> [B_Endpoint a b] -> Bindings a b -> Bindings a b
bind_list ws xs bindings = foldr (uncurry bind) bindings (zip ws xs)

-- | Like 'bind_list', binding each wire to a qubit value.
bind_qubit_wire_list :: [Wire] -> [a] -> Bindings a b -> Bindings a b
bind_qubit_wire_list ws xs bindings = bind_list ws (map Endpoint_Qubit xs) bindings

-- | Like 'bind_list', binding each wire to a bit value.
bind_bit_wire_list :: [Wire] -> [b] -> Bindings a b -> Bindings a b
bind_bit_wire_list ws xs bindings = bind_list ws (map Endpoint_Bit xs) bindings
-- | Look up the endpoint of each wire, in order (see 'unbind').
unbind_list :: Bindings a b -> [Wire] -> [B_Endpoint a b]
unbind_list bindings ws = [unbind bindings w | w <- ws]

-- | Look up the qubit value of each wire, in order (see 'unbind_qubit_wire').
unbind_qubit_wire_list :: Bindings a b -> [Wire] -> [a]
unbind_qubit_wire_list bindings ws = [unbind_qubit_wire bindings w | w <- ws]

-- | Look up the bit value of each wire, in order (see 'unbind_bit_wire').
unbind_bit_wire_list :: Bindings a b -> [Wire] -> [b]
unbind_bit_wire_list bindings ws = [unbind_bit_wire bindings w | w <- ws]
-- | Controls whose wires have been replaced by their endpoints; each
-- endpoint keeps the sign of the control it came from.
type Ctrls a b = [Signed (B_Endpoint a b)]

-- | Write endpoints back to the wires of a list of controls, position by
-- position. The signs on both sides are dropped; only wires and endpoints
-- are used.
bind_controls :: Controls -> Ctrls a b -> Bindings a b -> Bindings a b
bind_controls controls xs = bind_list (map from_signed controls) (map from_signed xs)

-- | Replace every control wire by its current endpoint, keeping its sign.
unbind_controls :: Bindings a b -> Controls -> Ctrls a b
unbind_controls bindings c = concatMap rebind c
  where
    -- Matching the constructor here keeps the element-by-element
    -- strictness of a list-comprehension pattern binding.
    rebind (Signed w b) = [Signed (unbind bindings w) b]
-- | A gate with its wires abstracted away.
--
-- Each constructor carries the gate's non-wire parameters (names, counts,
-- flags) and a continuation whose argument is a monadic function on the
-- endpoints of the gate's wires. The continuation produces a value of
-- type @x@ ('bind_gate' instantiates @x@ to @BT m a b@). Here @a@ is the
-- qubit endpoint type and @b@ the bit endpoint type.
data T_Gate m a b x =
    T_QGate String Int Int InverseFlag NoControlFlag (([a] -> [a] -> Ctrls a b -> m ([a], [a], Ctrls a b)) -> x)
    -- ^ Named quantum gate. The two Ints are the lengths of its two qubit
    -- wire lists; the function maps both lists of qubit values and the
    -- controls to their updated versions.
  | T_QRot String Int Int InverseFlag Timestep NoControlFlag (([a] -> [a] -> Ctrls a b -> m ([a], [a], Ctrls a b)) -> x)
    -- ^ As 'T_QGate', additionally carrying a 'Timestep' parameter.
  | T_GPhase Double NoControlFlag (([B_Endpoint a b] -> Ctrls a b -> m (Ctrls a b)) -> x)
    -- ^ Global phase with a 'Double' parameter. The function receives the
    -- endpoints of the gate's listed wires and the controls, and returns
    -- only updated controls.
  | T_CNot NoControlFlag ((b -> Ctrls a b -> m (b, Ctrls a b)) -> x)
    -- ^ Classical not on one bit, with controls.
  | T_CGate String NoControlFlag (([b] -> m (b, [b])) -> x)
    -- ^ Named classical gate. The function maps the input bits to a newly
    -- produced bit and the updated input bits.
  | T_CGateInv String NoControlFlag ((b -> [b] -> m [b]) -> x)
    -- ^ Inverse of 'T_CGate'. The function consumes the gate's output bit
    -- together with the input bits and returns updated input bits only.
  | T_CSwap NoControlFlag ((b -> b -> Ctrls a b -> m (b, b, Ctrls a b)) -> x)
    -- ^ Controlled swap of two bits.
  | T_QPrep NoControlFlag ((b -> m a) -> x)
    -- ^ Turn a bit into a qubit.
  | T_QUnprep NoControlFlag ((a -> m b) -> x)
    -- ^ Turn a qubit into a bit.
  | T_QInit Bool NoControlFlag (m a -> x)
    -- ^ Create a new qubit; the Bool is shown as 1/0 by 'show'.
  | T_CInit Bool NoControlFlag (m b -> x)
    -- ^ Create a new bit; the Bool is shown as 1/0 by 'show'.
  | T_QTerm Bool NoControlFlag ((a -> m ()) -> x)
    -- ^ Consume a qubit; the Bool is shown as 1/0 by 'show'.
  | T_CTerm Bool NoControlFlag ((b -> m ()) -> x)
    -- ^ Consume a bit; the Bool is shown as 1/0 by 'show'.
  | T_QMeas ((a -> m b) -> x)
    -- ^ Measurement: qubit in, bit out.
  | T_QDiscard ((a -> m ()) -> x)
    -- ^ Discard a qubit.
  | T_CDiscard ((b -> m ()) -> x)
    -- ^ Discard a bit.
  | T_DTerm Bool ((b -> m ()) -> x)
    -- ^ Consume a bit (no 'NoControlFlag'); the Bool is shown as 1/0.
  | T_Subroutine BoxId InverseFlag NoControlFlag ControllableFlag [Wire] Arity [Wire] Arity RepeatFlag ((Namespace -> [B_Endpoint a b] -> Ctrls a b -> m ([B_Endpoint a b], Ctrls a b)) -> x)
    -- ^ Call of a boxed subroutine: box id, flags, input wires and arity,
    -- output wires and arity, and repeat flag. The function is given a
    -- 'Namespace', the input endpoints and the controls, and returns the
    -- output endpoints and updated controls.
  | T_Comment String InverseFlag (([(B_Endpoint a b, String)] -> m ()) -> x)
    -- ^ Comment text; the function receives endpoints paired with their
    -- labels.
-- | Human-readable gate name with its non-wire parameters; the
-- continuation is not shown. A trailing @*@ marks inverted gates.
instance Show (T_Gate m a b x) where
  show gate =
    case gate of
      T_QGate name n m inv _ _ ->
        "QGate[" ++ name ++ "," ++ show n ++ "," ++ show m ++ "]" ++ optional inv "*"
      T_QRot name n m inv t _ _ ->
        "QRot[" ++ name ++ "," ++ show t ++ "," ++ show n ++ "," ++ show m ++ "]" ++ optional inv "*"
      T_GPhase t _ _ -> "GPhase[" ++ show t ++ "]"
      T_CNot _ _ -> "CNot"
      T_CGate n _ _ -> "CGate[" ++ n ++ "]"
      T_CGateInv n _ _ -> "CGate[" ++ n ++ "]*"
      T_CSwap _ _ -> "CSwap"
      T_QPrep _ _ -> "QPrep"
      T_QUnprep _ _ -> "QUnprep"
      T_QInit b _ _ -> "QInit" ++ flagDigit b
      T_CInit b _ _ -> "CInit" ++ flagDigit b
      T_QTerm b _ _ -> "QTerm" ++ flagDigit b
      T_CTerm b _ _ -> "CTerm" ++ flagDigit b
      T_QMeas _ -> "QMeas"
      T_QDiscard _ -> "QDiscard"
      T_CDiscard _ -> "CDiscard"
      T_DTerm b _ -> "DTerm" ++ flagDigit b
      T_Subroutine n inv _ _ _ _ _ _ rep _ ->
        "Subroutine(x" ++ show rep ++ ")[" ++ show n ++ "]" ++ optional inv "*"
      T_Comment n inv _ -> "Comment[" ++ n ++ "]" ++ optional inv "*"
    where
      -- Boolean gate parameters are rendered as a single digit.
      flagDigit True = "1"
      flagDigit False = "0"
-- | A transformer: given any gate description, it supplies the gate-level
-- function to the continuation and returns the continuation's result.
-- It must work uniformly for every result type @x@ (rank-2 type).
type Transformer m a b = forall x . T_Gate m a b x -> x
-- | A monadic update of the bindings; this is the result type that
-- 'bind_gate' gives its 'T_Gate' continuations.
type BT m a b = Bindings a b -> m (Bindings a b)
-- | Abstract a concrete 'Gate' into a 'T_Gate' whose continuation yields
-- a bindings update ('BT').
--
-- The continuation accepts a gate-level monadic function and returns a
-- 'BT' that works in three steps:
--
--   1. It reads the current endpoints of the gate's input wires and
--      controls from the bindings.
--   2. It runs the supplied function on them.
--   3. It writes the results back to the gate's output wires and controls.
--
-- Wires that are consumed without receiving a new value are removed from
-- the bindings. This covers wires consumed by 'QTerm', 'CTerm',
-- 'QDiscard', 'CDiscard' and 'DTerm', and the bit wire @w@ of 'CGateInv'.
-- The 'Namespace' is only passed to the function supplied for
-- 'Subroutine' gates.
bind_gate :: Monad m => Namespace -> Gate -> T_Gate m a b (BT m a b)
bind_gate namespace gate = case gate of
    QGate name inv ws vs c ncf ->
      T_QGate name (length ws) (length vs) inv ncf (list_binary ws vs c)
    QRot name inv t ws vs c ncf ->
      T_QRot name (length ws) (length vs) inv t ncf (list_binary ws vs c)
    GPhase t w c ncf -> T_GPhase t ncf (phase_ary w c)
    CNot w c ncf -> T_CNot ncf (cunary w c)
    CGate n w vs ncf -> T_CGate n ncf (cgate_ary w vs)
    CGateInv n w vs ncf -> T_CGateInv n ncf (cgateinv_ary w vs)
    CSwap w v c ncf -> T_CSwap ncf (binary_c w v c)
    QPrep w ncf -> T_QPrep ncf (qprep_ary w)
    QUnprep w ncf -> T_QUnprep ncf (qunprep_ary w)
    QInit b w ncf -> T_QInit b ncf (qinit_ary w)
    CInit b w ncf -> T_CInit b ncf (cinit_ary w)
    QTerm b w ncf -> T_QTerm b ncf (qterm_ary w)
    CTerm b w ncf -> T_CTerm b ncf (cterm_ary w)
    QMeas w -> T_QMeas (qunprep_ary w)
    QDiscard w -> T_QDiscard (qterm_ary w)
    CDiscard w -> T_CDiscard (cterm_ary w)
    DTerm b w -> T_DTerm b (cterm_ary w)
    Subroutine n inv ws a1 vs a2 c ncf scf rep ->
      T_Subroutine n inv ncf scf ws a1 vs a2 rep
        (\f -> subroutine_ary ws vs c (f namespace))
    Comment s inv ws -> T_Comment s inv (comment_ary ws)
  where
    -- Two bit wires plus controls, each rebound in place.
    binary_c :: Monad m => Wire -> Wire -> Controls -> (b -> b -> Ctrls a b -> m (b, b, Ctrls a b)) -> BT m a b
    binary_c w v c f bindings = do
      let w' = unbind_bit_wire bindings w
      let v' = unbind_bit_wire bindings v
      let c' = unbind_controls bindings c
      (w'', v'', c'') <- f w' v' c'
      let bindings1 = bind_bit_wire w w'' bindings
      let bindings2 = bind_bit_wire v v'' bindings1
      let bindings3 = bind_controls c c'' bindings2
      return bindings3

    -- Two lists of qubit wires plus controls, each rebound in place.
    list_binary :: Monad m => [Wire] -> [Wire] -> Controls -> ([a] -> [a] -> Ctrls a b -> m ([a], [a], Ctrls a b)) -> BT m a b
    list_binary ws vs c f bindings = do
      let ws' = unbind_qubit_wire_list bindings ws
      let vs' = unbind_qubit_wire_list bindings vs
      let c' = unbind_controls bindings c
      (ws'', vs'', c'') <- f ws' vs' c'
      let bindings1 = bind_qubit_wire_list ws ws'' bindings
      let bindings2 = bind_qubit_wire_list vs vs'' bindings1
      let bindings3 = bind_controls c c'' bindings2
      return bindings3

    -- Bit in, qubit out, on the same wire.
    qprep_ary :: Monad m => Wire -> (b -> m a) -> BT m a b
    qprep_ary w f bindings = do
      let w' = unbind_bit_wire bindings w
      w'' <- f w'
      let bindings1 = bind_qubit_wire w w'' bindings
      return bindings1

    -- Qubit in, bit out, on the same wire.
    qunprep_ary :: Monad m => Wire -> (a -> m b) -> BT m a b
    qunprep_ary w f bindings = do
      let w' = unbind_qubit_wire bindings w
      w'' <- f w'
      let bindings1 = bind_bit_wire w w'' bindings
      return bindings1

    -- One bit wire plus controls, rebound in place.
    cunary :: Monad m => Wire -> Controls -> (b -> Ctrls a b -> m (b, Ctrls a b)) -> BT m a b
    cunary w c f bindings = do
      let w' = unbind_bit_wire bindings w
      let c' = unbind_controls bindings c
      (w'', c'') <- f w' c'
      let bindings1 = bind_bit_wire w w'' bindings
      let bindings2 = bind_controls c c'' bindings1
      return bindings2

    -- A fresh qubit value is bound to the wire.
    qinit_ary :: Monad m => Wire -> m a -> BT m a b
    qinit_ary w f bindings = do
      w'' <- f
      let bindings1 = bind_qubit_wire w w'' bindings
      return bindings1

    -- A fresh bit value is bound to the wire.
    cinit_ary :: Monad m => Wire -> m b -> BT m a b
    cinit_ary w f bindings = do
      w'' <- f
      let bindings1 = bind_bit_wire w w'' bindings
      return bindings1

    -- The qubit value is consumed and the wire unbound.
    qterm_ary :: Monad m => Wire -> (a -> m ()) -> BT m a b
    qterm_ary w f bindings = do
      let w' = unbind_qubit_wire bindings w
      () <- f w'
      let bindings1 = bind_delete w bindings
      return bindings1

    -- The bit value is consumed and the wire unbound.
    cterm_ary :: Monad m => Wire -> (b -> m ()) -> BT m a b
    cterm_ary w f bindings = do
      let w' = unbind_bit_wire bindings w
      () <- f w'
      let bindings1 = bind_delete w bindings
      return bindings1

    -- Input bits are read and rebound; the wire w receives a new bit.
    cgate_ary :: Monad m => Wire -> [Wire] -> ([b] -> m (b, [b])) -> BT m a b
    cgate_ary w vs f bindings = do
      let vs' = unbind_bit_wire_list bindings vs
      (w'', vs'') <- f vs'
      let bindings1 = bind_bit_wire w w'' bindings
      let bindings2 = bind_bit_wire_list vs vs'' bindings1
      return bindings2

    -- Inverse of cgate_ary: the bit on w is consumed together with the
    -- input bits, and only the input bits come back.
    cgateinv_ary :: Monad m => Wire -> [Wire] -> (b -> [b] -> m [b]) -> BT m a b
    cgateinv_ary w vs f bindings = do
      let w' = unbind_bit_wire bindings w
      let vs' = unbind_bit_wire_list bindings vs
      vs'' <- f w' vs'
      -- f returns no new value for w, so (as in cterm_ary) w is unbound
      -- rather than left holding its stale endpoint.
      let bindings1 = bind_delete w bindings
      let bindings2 = bind_bit_wire_list vs vs'' bindings1
      return bindings2

    -- Input endpoints are read from ws and output endpoints bound to vs.
    subroutine_ary :: Monad m => [Wire] -> [Wire] -> Controls
                   -> ([B_Endpoint a b] -> Ctrls a b -> m ([B_Endpoint a b], Ctrls a b))
                   -> BT m a b
    subroutine_ary ws vs c f bindings = do
      let c' = unbind_controls bindings c
      let ws' = unbind_list bindings ws
      (vs'', c'') <- f ws' c'
      let bindings1 = bind_list vs vs'' bindings
      let bindings2 = bind_controls c c'' bindings1
      return bindings2

    -- The listed wires are read but not rebound; only controls are updated.
    phase_ary :: Monad m => [Wire] -> Controls -> ([B_Endpoint a b] -> Ctrls a b -> m (Ctrls a b)) -> BT m a b
    phase_ary w c f bindings = do
      let w' = map (unbind bindings) w
      let c' = unbind_controls bindings c
      c'' <- f w' c'
      let bindings1 = bind_controls c c'' bindings
      return bindings1

    -- Labeled endpoints are passed to f; bindings are returned unchanged.
    comment_ary :: Monad m => [(Wire, String)] -> (([(B_Endpoint a b, String)] -> m ()) -> BT m a b)
    comment_ary ws f bindings = do
      let ws' = zip (unbind_list bindings $ map fst ws) (map snd ws)
      f ws'
      return bindings
-- | Apply a transformer to each gate of a 'Circuit' in order, threading
-- the bindings through the monad. Gates are abstracted with the empty
-- namespace.
transform_circuit :: Monad m => Transformer m a b -> Circuit -> Bindings a b -> m (Bindings a b)
transform_circuit transformer (_, gates, _, _) bindings =
  foldM step bindings gates
  where
    step acc g = transformer (bind_gate namespace_empty g) acc
-- | Apply a transformer to each gate of the circuit in a 'BCircuit', in
-- order, threading the bindings through the monad. Gates are abstracted
-- with the 'BCircuit''s own namespace.
transform_bcircuit_rec :: Monad m => Transformer m a b -> BCircuit -> Bindings a b -> m (Bindings a b)
transform_bcircuit_rec transformer ((_, gates, _, _), namespace) bindings =
  foldM step bindings gates
  where
    step acc g = transformer (bind_gate namespace g) acc
-- | 'transform_bcircuit_rec' run in the 'Id' monad, with the result
-- unwrapped by 'getId'.
transform_bcircuit_id :: Transformer Id a b -> BCircuit -> Bindings a b -> Bindings a b
transform_bcircuit_id t c b = getId $ transform_bcircuit_rec t c b
-- | Everything 'transform_dbcircuit' needs to run a dynamic-lifting
-- circuit.
data DynamicTransformer m a b = DT
  { transformer :: Transformer m a b
    -- ^ Applied to every gate the circuit writes.
  , define_subroutine :: BoxId -> TypedSubroutine -> m ()
    -- ^ Called for every subroutine the circuit defines.
  , lifting_function :: b -> m Bool
    -- ^ Turns the endpoint of a bit wire into a 'Bool' whenever the
    -- circuit reads that wire.
  }
-- | Run a 'DynamicTransformer' over a 'DBCircuit'.
--
-- The namespace starts empty, and each subroutine definition the circuit
-- emits is added to it for the gates that follow. Each read of a bit wire
-- is resolved with 'lifting_function', and its result selects the
-- continuation of the circuit. The circuit's final value is returned
-- together with the final bindings.
transform_dbcircuit :: Monad m => DynamicTransformer m a b -> DBCircuit x -> Bindings a b -> m (x,Bindings a b)
transform_dbcircuit dt (_, rw0) bindings0 = evalStateT (go rw0 bindings0) namespace_empty
  where
    -- The arity component of the circuit never changes, so only the
    -- read/write tree is threaded through the recursion.
    go rw bnds = case rw of
      RW_Return (_, _, result) -> return (result, bnds)
      RW_Write gate rest -> do
        ns <- get
        bnds' <- lift (transformer dt (bind_gate ns gate) bnds)
        go rest bnds'
      RW_Read wire continue -> do
        answer <- lift (lifting_function dt (unbind_bit_wire bnds wire))
        go (continue answer) bnds
      RW_Subroutine name sub rest -> do
        lift (define_subroutine dt name sub)
        ns <- get
        put (map_provide name sub ns)
        go rest bnds