from typing import List, Tuple, Optional, Callable
import tensorflow as tf
from tensorflow.keras.layers import Layer, Activation
import numpy as np
from ..numpy.generic import MeshPhases
from ..meshmodel import MeshModel
from ..helpers import pairwise_off_diag_permutation, plot_complex_matrix, inverse_permutation
from ..config import TF_COMPLEX, BLOCH, SINGLEMODE
class PermutationLayer(TransformerLayer):
    """Permutation layer

    Args:
        permuted_indices: order of indices for the permutation matrix (efficient permutation representation);
            any integer array-like is accepted and stored as an ``int32`` numpy array
    """

    def __init__(self, permuted_indices: np.ndarray):
        # Normalize before reading ``shape`` so list/tuple inputs also produce a valid ``units``.
        permuted_indices = np.asarray(permuted_indices, dtype=np.int32)
        super(PermutationLayer, self).__init__(units=permuted_indices.shape[0])
        self.permuted_indices = permuted_indices
        # Inverse permutation, computed once at construction time.
        self.inv_permuted_indices = inverse_permutation(self.permuted_indices)
class MeshVerticalLayer(TransformerLayer):
    """Holds the coupling terms and permutations of one vertical layer of the mesh.

    Args:
        pairwise_perm_idx: index permutation used for the off-diagonal coupling
            (presumably pairs each mode with its coupling partner — confirm against the transform);
            its length sets the layer's ``units``
        diag: the diagonal terms to multiply
        off_diag: the off-diagonal terms to multiply
        right_perm: the right permutation for the mesh vertical layer
            (usually for the final layer and after the coupling operation); ``None`` for no permutation
        left_perm: the permutation for the mesh vertical layer (prior to the coupling operation);
            ``None`` for no permutation
    """

    def __init__(self, pairwise_perm_idx: np.ndarray, diag: tf.Tensor, off_diag: tf.Tensor,
                 right_perm: Optional[PermutationLayer] = None, left_perm: Optional[PermutationLayer] = None):
        self.diag = diag
        self.off_diag = off_diag
        self.left_perm = left_perm
        self.right_perm = right_perm
        self.pairwise_perm_idx = pairwise_perm_idx
        super(MeshVerticalLayer, self).__init__(pairwise_perm_idx.shape[0])
class MeshParamTensorflow:
    """A class that cleanly arranges parameters into a specific arrangement that can be used to simulate any mesh

    Args:
        param: parameter to arrange in mesh; its transpose supplies the rows
            :math:`\\boldsymbol{\\theta}_m` of the arrangements below
        units: number of inputs/outputs of the mesh
    """

    def __init__(self, param: tf.Tensor, units: int):
        self.param = param
        self.units = units

    @property
    def single_mode_arrangement(self) -> tf.Tensor:
        r"""
        The single-mode arrangement based on the :math:`L(\theta)` transfer matrix for :code:`PhaseShiftUpper`
        is one where elements of `param` are on the even rows and all odd rows are zero.

        In particular, given the transposed :code:`param` array
        :math:`\boldsymbol{\theta} = [\boldsymbol{\theta}_1, \boldsymbol{\theta}_2, \ldots \boldsymbol{\theta}_M]^T`,
        where :math:`\boldsymbol{\theta}_m` represent row vectors and :math:`M = \lfloor\frac{N}{2}\rfloor`,
        the single-mode arrangement has the stripe array form
        :math:`\widetilde{\boldsymbol{\theta}} = [\boldsymbol{\theta}_1, \boldsymbol{0}, \boldsymbol{\theta}_2, \boldsymbol{0}, \ldots \boldsymbol{\theta}_M, \boldsymbol{0}]^T`,
        where :math:`\widetilde{\boldsymbol{\theta}} \in \mathbb{R}^{N \times L}` defines the :math:`\boldsymbol{\theta}` of the final mesh
        and :math:`\boldsymbol{0}` represents an array of zeros of the same size as :math:`\boldsymbol{\theta}_m`.
        When :math:`N` is odd, one extra row of zeros is appended so the result has :math:`N` rows.

        Returns:
            Single-mode arrangement array of phases
        """
        tensor_t = tf.transpose(self.param)
        # Concatenating [theta | 0] along columns and reshaping to twice the rows interleaves
        # each theta_m row with a zero row.
        stripe_tensor = tf.reshape(tf.concat((tensor_t, tf.zeros_like(tensor_t)), 1),
                                   shape=(tensor_t.shape[0] * 2, tensor_t.shape[1]))
        if self.units % 2:
            # Match the parameter dtype so float64 params can be concatenated (tf.zeros defaults to float32).
            padding = tf.zeros(shape=(1, tensor_t.shape[1]), dtype=stripe_tensor.dtype)
            return tf.concat([stripe_tensor, padding], axis=0)
        return stripe_tensor

    @property
    def common_mode_arrangement(self) -> tf.Tensor:
        r"""
        The common-mode arrangement based on the :math:`C(\theta)` transfer matrix for :code:`PhaseShiftCommonMode`
        is one where elements of `param` are on the even rows and repeated on respective odd rows.

        In particular, given the transposed :code:`param` array
        :math:`\boldsymbol{\theta} = [\boldsymbol{\theta}_1, \boldsymbol{\theta}_2, \ldots \boldsymbol{\theta}_M]^T`,
        where :math:`\boldsymbol{\theta}_m` represent row vectors and :math:`M = \lfloor\frac{N}{2}\rfloor`,
        the common-mode arrangement has the stripe array form
        :math:`\widetilde{\boldsymbol{\theta}} = [\boldsymbol{\theta}_1, \boldsymbol{\theta}_1, \boldsymbol{\theta}_2, \boldsymbol{\theta}_2, \ldots \boldsymbol{\theta}_M, \boldsymbol{\theta}_M]^T`,
        where :math:`\widetilde{\boldsymbol{\theta}} \in \mathbb{R}^{N \times L}` defines the :math:`\boldsymbol{\theta}` of the final mesh.

        Returns:
            Common-mode arrangement array of phases
        """
        phases = self.single_mode_arrangement
        # Shifting rows down by one copies each theta_m into the zero row that follows it.
        return phases + _roll_tensor(phases)

    @property
    def differential_mode_arrangement(self) -> tf.Tensor:
        r"""
        The differential-mode arrangement is based on the :math:`D(\theta)` transfer matrix
        for :code:`PhaseShiftDifferentialMode`.

        Given the transposed :code:`param` array
        :math:`\boldsymbol{\theta} = [\cdots \boldsymbol{\theta}_m \cdots]^T`,
        where :math:`\boldsymbol{\theta}_m` represent row vectors and :math:`M = \lfloor\frac{N}{2}\rfloor`,
        the differential-mode arrangement has the form
        :math:`\widetilde{\boldsymbol{\theta}} = \left[\cdots \frac{\boldsymbol{\theta}_m}{2}, -\frac{\boldsymbol{\theta}_m}{2} \cdots \right]^T`,
        where :math:`\widetilde{\boldsymbol{\theta}} \in \mathbb{R}^{N \times L}` defines the :math:`\boldsymbol{\theta}` of the final mesh.

        Returns:
            Differential-mode arrangement array of phases
        """
        phases = self.single_mode_arrangement
        return phases / 2 - _roll_tensor(phases / 2)

    def __add__(self, other):
        return MeshParamTensorflow(self.param + other.param, self.units)

    def __sub__(self, other):
        return MeshParamTensorflow(self.param - other.param, self.units)

    def __mul__(self, other):
        return MeshParamTensorflow(self.param * other.param, self.units)
class MeshPhasesTensorflow:
    r"""Organizes the phases in the mesh into appropriate arrangements

    Args:
        theta: Array to be converted to :math:`\boldsymbol{\theta}`
        phi: Array to be converted to :math:`\boldsymbol{\phi}`
        mask: Mask over values of :code:`theta` and :code:`phi` that are not in bar state;
            :code:`None` means an all-ones mask (every value is used)
        gamma: Array to be converted to :math:`\boldsymbol{\gamma}`
        units: Number of inputs/outputs of the mesh
        basis: Phase basis to use
        hadamard: Whether to use Hadamard convention
        theta_fn: TF-friendly theta function call to reparametrize theta (example use cases: see `neurophox.helpers`).
            By default, use identity function.
        phi_fn: TF-friendly phi function call to reparametrize phi (example use cases: see `neurophox.helpers`).
            By default, use identity function.
        gamma_fn: TF-friendly gamma function call to reparametrize gamma (example use cases: see `neurophox.helpers`).
            By default, use identity function.
        phase_loss_fn: Incorporate phase shift-dependent loss into the model.
            The function is of the form phase_loss_fn(phases),
            which returns the loss; each complex phase factor is scaled by :code:`1 - loss`.

    Raises:
        ValueError: if the arranged :code:`theta` and :code:`phi` parameters differ in shape.
    """

    def __init__(self, theta: tf.Variable, phi: tf.Variable, mask: np.ndarray, gamma: tf.Variable, units: int,
                 basis: str = SINGLEMODE, hadamard: bool = False,
                 theta_fn: Optional[Callable] = None, phi_fn: Optional[Callable] = None,
                 gamma_fn: Optional[Callable] = None,
                 phase_loss_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None):
        # Use self.mask (not the raw argument) below so that mask=None falls back to all ones.
        self.mask = mask if mask is not None else np.ones_like(theta)
        self.theta_fn = (lambda x: x) if theta_fn is None else theta_fn
        self.phi_fn = (lambda x: x) if phi_fn is None else phi_fn
        self.gamma_fn = (lambda x: x) if gamma_fn is None else gamma_fn
        # Masked-out entries (mask == 0) are pinned to (1 - hadamard) * pi.
        self.theta = MeshParamTensorflow(
            self.theta_fn(theta) * self.mask + (1 - self.mask) * (1 - hadamard) * np.pi, units=units)
        self.phi = MeshParamTensorflow(
            self.phi_fn(phi) * self.mask + (1 - self.mask) * (1 - hadamard) * np.pi, units=units)
        self.gamma = self.gamma_fn(gamma)
        self.basis = basis
        self.phase_loss_fn = (lambda x: 0) if phase_loss_fn is None else phase_loss_fn
        # exp(i * phase), attenuated by (1 - loss(phase)).
        self.phase_fn = lambda phase: tf.complex(tf.cos(phase), tf.sin(phase)) * (1 - _to_complex(self.phase_loss_fn(phase)))
        self.input_phase_shift_layer = self.phase_fn(self.gamma)
        if self.theta.param.shape != self.phi.param.shape:
            raise ValueError("Internal phases (theta) and external phases (phi) need to have the same shape.")

    @property
    def internal_phase_shifts(self):
        r"""
        The internal phase shift matrix of the mesh corresponds to an :math:`N \times L` array of phase shifts
        (in between beamsplitters, thus internal) where :math:`N` is number of inputs/outputs (rows)
        and :math:`L` is number of layers (columns).

        Returns:
            Internal phase shift matrix corresponding to :math:`\boldsymbol{\theta}`

        Raises:
            NotImplementedError: if :code:`basis` is neither :code:`BLOCH` nor :code:`SINGLEMODE`.
        """
        if self.basis == BLOCH:
            return self.theta.differential_mode_arrangement
        elif self.basis == SINGLEMODE:
            return self.theta.single_mode_arrangement
        else:
            raise NotImplementedError(f"{self.basis} is not yet supported or invalid.")

    @property
    def external_phase_shifts(self):
        r"""The external phase shift matrix of the mesh corresponds to an :math:`N \times L` array of phase shifts
        (outside of beamsplitters, thus external) where :math:`N` is number of inputs/outputs (rows)
        and :math:`L` is number of layers (columns).

        Returns:
            External phase shift matrix corresponding to :math:`\boldsymbol{\phi}`

        Raises:
            NotImplementedError: if :code:`basis` is neither :code:`BLOCH` nor :code:`SINGLEMODE`.
        """
        if self.basis == BLOCH or self.basis == SINGLEMODE:
            return self.phi.single_mode_arrangement
        else:
            raise NotImplementedError(f"{self.basis} is not yet supported or invalid.")

    @property
    def internal_phase_shift_layers(self):
        r"""Elementwise applying complex exponential to :code:`internal_phase_shifts`.

        Returns:
            Internal phase shift layers corresponding to :math:`\boldsymbol{\theta}`
        """
        return self.phase_fn(self.internal_phase_shifts)

    @property
    def external_phase_shift_layers(self):
        r"""Elementwise applying complex exponential to :code:`external_phase_shifts`.

        Returns:
            External phase shift layers corresponding to :math:`\boldsymbol{\phi}`
        """
        return self.phase_fn(self.external_phase_shifts)
class Mesh:
    """General mesh network layer defined by `neurophox.meshmodel.MeshModel`.

    Precomputes the pairwise permutation, the MZI error tensors (as :code:`TF_COMPLEX` constants)
    and one :code:`PermutationLayer` per layer boundary (``num_layers + 1`` of them).
    """

    def __init__(self, model: MeshModel):
        """
        Args:
            model: The `MeshModel` model of the mesh network (e.g., rectangular, triangular, custom, etc.)
        """
        self.model = model
        self.units, self.num_layers = self.model.units, self.model.num_layers
        self.pairwise_perm_idx = pairwise_off_diag_permutation(self.units)
        ss, cs, sc, cc = self.model.mzi_error_tensors
        self.ss = tf.constant(ss, dtype=TF_COMPLEX)
        self.cs = tf.constant(cs, dtype=TF_COMPLEX)
        self.sc = tf.constant(sc, dtype=TF_COMPLEX)
        self.cc = tf.constant(cc, dtype=TF_COMPLEX)
        self.perm_layers = [PermutationLayer(self.model.perm_idx[layer]) for layer in range(self.num_layers + 1)]

    def mesh_layers(self, phases: MeshPhasesTensorflow) -> List[MeshVerticalLayer]:
        r"""
        Args:
            phases: The :code:`MeshPhasesTensorflow` object containing :math:`\boldsymbol{\theta}, \boldsymbol{\phi}, \boldsymbol{\gamma}`

        Returns:
            List of :code:`num_layers` mesh layers to be used by any instance of :code:`MeshLayer`;
            only the first layer carries a left permutation.
        """
        internal_psl = phases.internal_phase_shift_layers
        external_psl = phases.external_phase_shift_layers
        # The same upward roll is needed by every coupling term, so compute it once.
        internal_psl_up = _roll_tensor(internal_psl, up=True)
        # smooth trick to efficiently perform the layerwise coupling computation
        if self.model.hadamard:
            s11 = self.cc * internal_psl + self.ss * internal_psl_up
            s22 = _roll_tensor(self.ss * internal_psl + self.cc * internal_psl_up)
            s12 = _roll_tensor(self.cs * internal_psl - self.sc * internal_psl_up)
            s21 = self.sc * internal_psl - self.cs * internal_psl_up
        else:
            s11 = self.cc * internal_psl - self.ss * internal_psl_up
            s22 = _roll_tensor(-self.ss * internal_psl + self.cc * internal_psl_up)
            s12 = 1j * _roll_tensor(self.cs * internal_psl + self.sc * internal_psl_up)
            s21 = 1j * (self.sc * internal_psl + self.cs * internal_psl_up)
        diag_layers = external_psl * (s11 + s22) / 2
        off_diag_layers = _roll_tensor(external_psl) * (s21 + s12) / 2
        if self.units % 2:
            # With an odd number of modes the last mode has a unit diagonal term in every layer.
            diag_layers = tf.concat((diag_layers[:-1], tf.ones_like(diag_layers[-1:])), axis=0)
        # Transpose from (modes x layers) to (layers x modes) so each row is one vertical layer.
        diag_layers, off_diag_layers = tf.transpose(diag_layers), tf.transpose(off_diag_layers)
        mesh_layers = [MeshVerticalLayer(self.pairwise_perm_idx, diag_layers[0], off_diag_layers[0],
                                         right_perm=self.perm_layers[1], left_perm=self.perm_layers[0])]
        mesh_layers.extend(
            MeshVerticalLayer(self.pairwise_perm_idx, diag_layers[layer], off_diag_layers[layer],
                              right_perm=self.perm_layers[layer + 1])
            for layer in range(1, self.num_layers)
        )
        return mesh_layers
class MeshLayer(TransformerLayer):
    """Mesh network layer for unitary operators implemented in tensorflow

    Args:
        mesh_model: The `MeshModel` model of the mesh network (e.g., rectangular, triangular, custom, etc.)
        activation: Nonlinear activation function (:code:`None` if there's no nonlinearity)
        incoherent: Use an incoherent representation for the layer (no phase coherent between respective inputs...)
        phase_loss_fn: Optional phase shift-dependent loss function, forwarded to :code:`MeshPhasesTensorflow`
        **kwargs: Additional keyword arguments forwarded to :code:`TransformerLayer`
    """

    def __init__(self, mesh_model: MeshModel, activation: Optional[Activation] = None,
                 incoherent: bool = False,
                 phase_loss_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
                 **kwargs):
        self.mesh = Mesh(mesh_model)
        self.units, self.num_layers = self.mesh.units, self.mesh.num_layers
        self.incoherent = incoherent
        self.phase_loss_fn = phase_loss_fn
        super(MeshLayer, self).__init__(self.units, activation=activation, **kwargs)
        # Trainable phase variables are created from the model's initializers.
        theta_init, phi_init, gamma_init = self.mesh.model.init
        self.theta, self.phi, self.gamma = theta_init.to_tf("theta"), phi_init.to_tf("phi"), gamma_init.to_tf("gamma")
        self.theta_fn, self.phi_fn, self.gamma_fn = self.mesh.model.theta_fn, self.mesh.model.phi_fn, self.mesh.model.gamma_fn

    @property
    def phases_and_layers(self) -> Tuple[MeshPhasesTensorflow, List[MeshVerticalLayer]]:
        """Build the phase arrangement and vertical layers from the current variable values.

        Returns:
            Phases and layers for this mesh layer
        """
        mesh_phases = MeshPhasesTensorflow(
            theta=self.theta, phi=self.phi, gamma=self.gamma,
            theta_fn=self.theta_fn, phi_fn=self.phi_fn, gamma_fn=self.gamma_fn,
            mask=self.mesh.model.mask, hadamard=self.mesh.model.hadamard,
            units=self.units, basis=self.mesh.model.basis, phase_loss_fn=self.phase_loss_fn,
        )
        mesh_layers = self.mesh.mesh_layers(mesh_phases)
        return mesh_phases, mesh_layers

    @property
    def phases(self) -> MeshPhases:
        """
        Returns:
            The :code:`MeshPhases` object for this layer; masked-out :code:`theta`/:code:`phi`
            entries are reported as zero (the reparametrization functions are not applied)
        """
        return MeshPhases(
            theta=self.theta.numpy() * self.mesh.model.mask,
            phi=self.phi.numpy() * self.mesh.model.mask,
            mask=self.mesh.model.mask,
            gamma=self.gamma.numpy()
        )
def _roll_tensor(tensor: tf.Tensor, up=False):
    """Cyclically shift ``tensor`` by one row along axis 0.

    Built from slicing and concatenation instead of ``tf.roll``: a complex
    number-friendly roll that works on gpu.

    Args:
        tensor: Tensor to shift; indexing its first/last row means it must be non-empty along axis 0.
        up: If ``True``, rows move up and row 0 wraps to the end; otherwise rows
            move down and the last row wraps to the front.

    Returns:
        The shifted tensor, with the same shape as ``tensor``.
    """
    if up:
        leading, trailing = tensor[1:], tensor[tf.newaxis, 0]
    else:
        leading, trailing = tensor[tf.newaxis, -1], tensor[:-1]
    return tf.concat([leading, trailing], axis=0)
def _to_complex(tensor: tf.Tensor):
    """Promote a real floating-point tensor to complex with a zero imaginary part.

    float32 tensors become complex64 and float64 tensors become complex128, so the result
    can multiply complex phase factors of the matching precision. Anything else
    (complex tensors, Python scalars such as the default ``0`` loss) is returned unchanged.

    Args:
        tensor: Value to promote.

    Returns:
        The complex-valued tensor, or ``tensor`` itself when no promotion applies.
    """
    if isinstance(tensor, tf.Tensor) and tensor.dtype in (tf.float32, tf.float64):
        return tf.complex(tensor, tf.zeros_like(tensor))
    return tensor