from typing import Optional, List, Dict, Union, Callable
import tensorflow as tf
from tensorflow.keras.layers import Activation
import numpy as np
from .generic import TransformerLayer, MeshLayer, CompoundTransformerLayer, PermutationLayer
from ..meshmodel import RectangularMeshModel, TriangularMeshModel, PermutingRectangularMeshModel, ButterflyMeshModel
from ..helpers import rectangular_permutation, butterfly_layer_permutation
from ..config import DEFAULT_BASIS, TF_FLOAT, TF_COMPLEX
class RM(MeshLayer):
    """Rectangular mesh network layer for unitary operators implemented in tensorflow

    Args:
        units: The dimension of the unitary matrix (:math:`N`)
        num_layers: The number of layers (:math:`L`) of the mesh; :code:`None` leaves the choice to
            :code:`RectangularMeshModel`
        hadamard: Hadamard convention for the beamsplitters
        incoherent: Flag forwarded unchanged to :code:`MeshLayer` (see it for the exact semantics)
        basis: Phase basis to use
        bs_error: Beamsplitter split ratio error
        theta_init: Initializer for :code:`theta` (:math:`\\boldsymbol{\\theta}` or :math:`\\theta_{n\\ell}`)
        phi_init: Initializer for :code:`phi` (:math:`\\boldsymbol{\\phi}` or :math:`\\phi_{n\\ell}`)
        gamma_init: Initializer for :code:`gamma` (:math:`\\boldsymbol{\\gamma}` or :math:`\\gamma_{n}`)
        phase_loss_fn: Phase loss function for the layer
        activation: Nonlinear activation function (:code:`None` if there's no nonlinearity)
        **kwargs: Additional keyword arguments forwarded to :code:`MeshLayer`
    """

    def __init__(self, units: int, num_layers: Optional[int] = None, hadamard: bool = False,
                 incoherent: bool = False, basis: str = DEFAULT_BASIS, bs_error: float = 0.0,
                 theta_init: Union[str, tuple, np.ndarray] = "haar_rect",
                 phi_init: Union[str, tuple, np.ndarray] = "random_phi",
                 gamma_init: Union[str, tuple, np.ndarray] = "random_gamma",
                 phase_loss_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
                 activation: Optional[Activation] = None, **kwargs):
        # NOTE(review): the model receives bs_error before basis (reverse of this signature's order);
        # presumably this matches RectangularMeshModel's positional parameter order -- confirm.
        super(RM, self).__init__(
            RectangularMeshModel(units, num_layers, hadamard, bs_error, basis, theta_init, phi_init, gamma_init),
            activation=activation, incoherent=incoherent,
            phase_loss_fn=phase_loss_fn, **kwargs
        )
class TM(MeshLayer):
    """Triangular mesh network layer for unitary operators implemented in tensorflow

    Args:
        units: The dimension of the unitary matrix (:math:`N`)
        hadamard: Hadamard convention for the beamsplitters
        incoherent: Flag forwarded unchanged to :code:`MeshLayer` (see it for the exact semantics)
        basis: Phase basis to use
        bs_error: Beamsplitter split ratio error
        theta_init: Initializer for :code:`theta` (:math:`\\boldsymbol{\\theta}` or :math:`\\theta_{n\\ell}`)
        phi_init: Initializer for :code:`phi` (:math:`\\boldsymbol{\\phi}` or :math:`\\phi_{n\\ell}`)
        gamma_init: Initializer for :code:`gamma` (:math:`\\boldsymbol{\\gamma}` or :math:`\\gamma_{n}`)
        phase_loss_fn: Phase loss function for the layer
        activation: Nonlinear activation function (:code:`None` if there's no nonlinearity)
        **kwargs: Additional keyword arguments forwarded to :code:`MeshLayer`
    """

    def __init__(self, units: int, hadamard: bool = False, incoherent: bool = False, basis: str = DEFAULT_BASIS,
                 bs_error: float = 0.0, theta_init: Union[str, tuple, np.ndarray] = "haar_tri",
                 phi_init: Union[str, tuple, np.ndarray] = "random_phi",
                 gamma_init: Union[str, tuple, np.ndarray] = "random_gamma",
                 phase_loss_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
                 activation: Optional[Activation] = None, **kwargs):
        # NOTE(review): the model receives bs_error before basis (reverse of this signature's order);
        # presumably this matches TriangularMeshModel's positional parameter order -- confirm.
        super(TM, self).__init__(
            TriangularMeshModel(units, hadamard, bs_error, basis, theta_init, phi_init, gamma_init),
            activation=activation, incoherent=incoherent, phase_loss_fn=phase_loss_fn, **kwargs
        )
class PRM(MeshLayer):
    """Permuting rectangular mesh unitary layer

    Args:
        units: The dimension of the unitary matrix (:math:`N`) to be modeled by this transformer
        tunable_layers_per_block: The number of tunable layers per block (overrides
            :code:`num_tunable_layers_list`, :code:`sampling_frequencies`)
        num_tunable_layers_list: Number of tunable layers in each block in order from left to right
        sampling_frequencies: Sampling frequencies between the tunable layers
        bs_error: Photonic error in the beamsplitter
        hadamard: Hadamard convention for the beamsplitters
        incoherent: Flag forwarded unchanged to :code:`MeshLayer` (see it for the exact semantics)
        theta_init: Initializer for :code:`theta` (:math:`\\boldsymbol{\\theta}` or :math:`\\theta_{n\\ell}`)
        phi_init: Initializer for :code:`phi` (:math:`\\boldsymbol{\\phi}` or :math:`\\phi_{n\\ell}`)
        gamma_init: Initializer for :code:`gamma` (:math:`\\boldsymbol{\\gamma}` or :math:`\\gamma_{n}`)
        phase_loss_fn: Phase loss function for the layer
        activation: Nonlinear activation function (:code:`None` if there's no nonlinearity)
        **kwargs: Additional keyword arguments forwarded to :code:`MeshLayer`

    Raises:
        NotImplementedError: If the :code:`'haar_prm'` theta initializer is combined with
            :code:`tunable_layers_per_block`.
    """

    def __init__(self, units: int, tunable_layers_per_block: Optional[int] = None,
                 num_tunable_layers_list: Optional[List[int]] = None, sampling_frequencies: Optional[List[int]] = None,
                 bs_error: float = 0.0, hadamard: bool = False, incoherent: bool = False,
                 theta_init: Union[str, tuple, np.ndarray] = "haar_prm",
                 phi_init: Union[str, tuple, np.ndarray] = "random_phi",
                 gamma_init: Union[str, tuple, np.ndarray] = "random_gamma",
                 phase_loss_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
                 activation: Optional[Activation] = None, **kwargs):
        # Only string initializers can name 'haar_prm'. Comparing an ndarray to a str can yield an
        # elementwise array whose truth value is ambiguous (ValueError), so check the type first.
        if isinstance(theta_init, str) and theta_init == 'haar_prm' and tunable_layers_per_block is not None:
            raise NotImplementedError('haar_prm initializer is incompatible with setting tunable_layers_per_block.')
        super(PRM, self).__init__(
            PermutingRectangularMeshModel(units, tunable_layers_per_block, num_tunable_layers_list,
                                          sampling_frequencies, bs_error, hadamard, theta_init, phi_init, gamma_init),
            activation=activation, incoherent=incoherent, phase_loss_fn=phase_loss_fn, **kwargs)
class BM(MeshLayer):
    """Butterfly mesh unitary layer

    Args:
        num_layers: Number of layers of the butterfly mesh, passed as the first argument to
            :code:`ButterflyMeshModel` (which presumably derives the matrix dimension from it,
            e.g. :math:`N = 2^L` -- confirm against the model)
        hadamard: Hadamard convention for the beamsplitters
        incoherent: Flag forwarded unchanged to :code:`MeshLayer` (see it for the exact semantics)
        basis: Phase basis to use
        bs_error: Beamsplitter split ratio error
        theta_init: Initializer for :code:`theta` (:math:`\\boldsymbol{\\theta}` or :math:`\\theta_{n\\ell}`)
        phi_init: Initializer for :code:`phi` (:math:`\\boldsymbol{\\phi}` or :math:`\\phi_{n\\ell}`)
        phase_loss_fn: Phase loss function for the layer
        activation: Nonlinear activation function (:code:`None` if there's no nonlinearity)
        **kwargs: Additional keyword arguments forwarded to :code:`MeshLayer`
    """

    def __init__(self, num_layers: int, hadamard: bool = False, incoherent: bool = False, basis: str = DEFAULT_BASIS,
                 bs_error: float = 0.0, theta_init: Union[str, tuple, np.ndarray] = "random_theta",
                 phi_init: Union[str, tuple, np.ndarray] = "random_phi",
                 phase_loss_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
                 activation: Optional[Activation] = None, **kwargs):
        # No gamma initializer here: ButterflyMeshModel is only given theta and phi initializers.
        super(BM, self).__init__(
            ButterflyMeshModel(num_layers, hadamard, bs_error, basis, theta_init, phi_init),
            activation=activation, incoherent=incoherent, phase_loss_fn=phase_loss_fn, **kwargs
        )
class SVD(CompoundTransformerLayer):
    """Singular value decomposition transformer for implementing a matrix.

    Notes:
        SVD requires you to specify the mesh layer used for both unitaries of the SVD in :code:`mesh_dict`:
        the mesh name under :code:`'name'` (one of :code:`'rm'`, :code:`'prm'`, :code:`'tm'`) and, optionally,
        keyword arguments for that mesh layer under :code:`'properties'`.

    Args:
        units: The number of inputs (:math:`M`) of the :math:`M \\times N` matrix to be modelled by the SVD
        mesh_dict: The name and properties of the mesh layer used for the SVD
        output_units: The dimension of the output (:math:`N`) of the :math:`M \\times N` matrix to be modelled by the SVD
        pos_singular_values: Whether to allow only positive singular values
        activation: Nonlinear activation function (:code:`None` if there's no nonlinearity)

    Raises:
        NotImplementedError: If :code:`output_units` is given and differs from :code:`units`.
        KeyError: If :code:`mesh_dict` has no :code:`'name'` entry or names an unsupported mesh.
    """

    def __init__(self, units: int, mesh_dict: Dict, output_units: Optional[int] = None,
                 pos_singular_values: bool = False, activation: Optional[Activation] = None):
        # Reject unsupported shapes before assigning any state on the layer.
        if output_units is not None and output_units != units:
            raise NotImplementedError("Still working out a clean implementation of non-square linear operators.")
        self.units = units
        self.output_units = output_units if output_units is not None else units
        self.mesh_name = mesh_dict['name']
        self.mesh_properties = mesh_dict.get('properties', {})
        self.pos = pos_singular_values
        # Only meshes constructed from `units` are supported (BM is built from num_layers instead).
        mesh_name2layer = {
            'rm': RM,
            'prm': PRM,
            'tm': TM
        }
        if self.mesh_name not in mesh_name2layer:
            raise KeyError("Unsupported mesh name {!r} for SVD; expected one of {}.".format(
                self.mesh_name, sorted(mesh_name2layer)))
        mesh_layer = mesh_name2layer[self.mesh_name]
        # Transformer order is v -> diag -> u, matching the transformer_list passed to the base class.
        self.v = mesh_layer(units=units, name="v", **self.mesh_properties)
        self.diag = Diagonal(units, output_units=output_units, pos=self.pos)
        self.u = mesh_layer(units=units, name="u", **self.mesh_properties)
        # Stored on the layer only; it is not forwarded to CompoundTransformerLayer here.
        self.activation = activation
        super(SVD, self).__init__(
            units=self.units,
            transformer_list=[self.v, self.diag, self.u]
        )
class DiagonalPhaseLayer(TransformerLayer):
    """Diagonal matrix of phase shifts

    The diagonal is :math:`e^{i\\gamma_n}` (exposed as :code:`diag_vec`), and its inverse is
    :math:`e^{-i\\gamma_n}` (exposed as :code:`inv_diag_vec`). The phases :math:`\\gamma_n` are initialized
    uniformly at random in :math:`[0, 2\\pi)`.

    Args:
        units: Dimension of the input (number of input waveguide ports), :math:`N`
        **kwargs: Extra keyword arguments forwarded to the :code:`gamma` :code:`tf.Variable`
            (not to the layer base class)
    """

    def __init__(self, units: int, **kwargs):
        super(DiagonalPhaseLayer, self).__init__(units=units)
        self.gamma = tf.Variable(
            name="gamma",
            initial_value=tf.constant(2 * np.pi * np.random.rand(units), dtype=TF_FLOAT),
            dtype=TF_FLOAT,
            **kwargs
        )
        # NOTE(review): kept for backward compatibility; if `variables` is a computed property on the
        # base class (as on tf.keras layers), this append has no lasting effect -- confirm.
        self.variables.append(self.gamma)

    # The diagonals are recomputed from the current `gamma` on every access. Computing them once in
    # __init__ would, under eager execution, freeze the initial phases and cut gradients to `gamma`.
    @property
    def diag_vec(self) -> tf.Tensor:
        """Complex diagonal :math:`\\cos\\gamma_n + i\\sin\\gamma_n` from the current :code:`gamma`."""
        return tf.complex(tf.cos(self.gamma), tf.sin(self.gamma))

    @property
    def inv_diag_vec(self) -> tf.Tensor:
        """Complex inverse diagonal :math:`\\cos\\gamma_n - i\\sin\\gamma_n` from the current :code:`gamma`."""
        return tf.complex(tf.cos(-self.gamma), tf.sin(-self.gamma))
class Diagonal(TransformerLayer):
    """Diagonal matrix of gains and losses (not necessarily real)

    Args:
        units: Dimension of the input (number of input waveguide ports), :math:`N`
        is_complex: Whether to use complex values or not
        output_units: Dimension of the output (number of output waveguide ports), :math:`M`.
            If :math:`M < N`, remove last :math:`N - M` elements.
            If :math:`M > N`, pad with :math:`M - N` zeros.
        pos: Enforce positive definite matrix (only positive singular values)
        **kwargs: Additional keyword arguments forwarded to :code:`TransformerLayer`
    """

    def __init__(self, units: int, is_complex: bool = True, output_units: Optional[int] = None,
                 pos: bool = False, **kwargs):
        super().__init__(units=units, **kwargs)
        # Output dimension defaults to a square operator.
        self.output_dim = units if output_units is None else output_units
        self.pos = pos
        self.is_complex = is_complex
        # One trainable value per singular value: the smaller of the input and output dimensions.
        num_singular_values = min(self.units, self.output_dim)
        initial_sigma = 2 * np.pi * np.random.randn(num_singular_values)
        self.sigma = tf.Variable(
            initial_value=tf.constant(initial_sigma, dtype=TF_FLOAT),
            dtype=TF_FLOAT,
            name="sigma",
        )
class RectangularPerm(PermutationLayer):
    """Rectangular permutation layer

    The rectangular permutation layer for a frequency :math:`f` is effectively equivalent to adding
    :math:`f` layers of cross state MZIs in a grid configuration to the existing mesh.

    Args:
        units: Dimension of the input (number of input waveguide ports), :math:`N`
        frequency: Frequency of interacting mesh wires (waveguides)
    """

    def __init__(self, units: int, frequency: int):
        self.frequency = frequency
        indices = rectangular_permutation(units, frequency)
        super().__init__(permuted_indices=indices)
class ButterflyPerm(PermutationLayer):
    """Butterfly (FFT) permutation layer

    For a frequency :math:`f`, the butterfly (FFT) permutation swaps every pair of inputs that are
    :math:`f` inputs apart. It fits most cleanly in a butterfly mesh architecture where the number of
    inputs, :math:`N`, and the frequencies, :math:`f`, are powers of two.

    Args:
        units: Dimension of the input (number of input waveguide ports), :math:`N`
        frequency: Frequency of interacting mesh wires (waveguides)
    """

    def __init__(self, units: int, frequency: int):
        self.frequency = frequency
        indices = butterfly_layer_permutation(units, frequency)
        super().__init__(permuted_indices=indices)