neurophox.ml package
neurophox.ml.linear module
class neurophox.ml.linear.LinearMultiModelRunner(experiment_name, layer_names, layers, optimizer, batch_size, iterations_per_epoch=50, iterations_per_tb_update=5, logdir=None, train_on_test=False, store_params=True)
Bases: object
Complex mean square error linear optimization experiment that can run and track multiple model optimizations in parallel.
- Parameters
experiment_name (str) – Name of the experiment
layer_names (List[str]) – List of layer names
layers (List[MeshLayer]) – List of transformer layers
optimizer (Union[OptimizerV2, List[OptimizerV2]]) – Optimizer for all layers, or a list with one optimizer per layer
batch_size (int) – Batch size for the optimization
iterations_per_epoch (int) – Number of iterations per epoch
iterations_per_tb_update (int) – Number of iterations between TensorBoard updates
logdir (Optional[str]) – Logging directory for TensorBoard to track the losses of each layer (defaults to None, meaning no logging)
train_on_test (bool) – Use the same set for training and testing
store_params (bool) – Store parameters during training for later visualization
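The optimizer parameter accepts either one optimizer for all layers or a list with one optimizer per layer. The sketch below shows one way such an argument can be normalized into a per-layer list. The helper name optimizers_per_layer is hypothetical, placeholder strings stand in for OptimizerV2 instances, and whether the library shares a single optimizer instance across layers (as this sketch does) or copies it is an assumption not stated in these docs.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")  # stands in for OptimizerV2 in this sketch


def optimizers_per_layer(optimizer: Union[T, Sequence[T]], num_layers: int) -> List[T]:
    """Hypothetical helper: return exactly one optimizer per layer.

    A list/tuple is validated against the number of layers; a single
    optimizer object is repeated (the same object) for every layer.
    """
    if isinstance(optimizer, (list, tuple)):
        if len(optimizer) != num_layers:
            raise ValueError(f"expected {num_layers} optimizers, got {len(optimizer)}")
        return list(optimizer)
    return [optimizer] * num_layers


# Usage with placeholder strings instead of real optimizers:
print(optimizers_per_layer("adam", 3))           # ['adam', 'adam', 'adam']
print(optimizers_per_layer(["sgd", "adam"], 2))  # ['sgd', 'adam']
```

Validating the list length up front turns a mismatch between layers and optimizers into an immediate error rather than a silent truncation by zip later on.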
iterate(target_unitary, cost_fn=<function complex_mse>)
Run a gradient update toward a target unitary \(U\).
- Parameters
target_unitary (ndarray) – Target unitary, \(U\).
cost_fn (Callable) – Cost function for the linear model (defaults to complex_mse, the complex mean square error)
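To illustrate what a gradient update toward \(U\) under the complex mean square error means, here is a toy sketch that is not the neurophox implementation: the real method works with MeshLayer models and an OptimizerV2, whereas this version assumes a single layer of phase shifters (a diagonal unitary \(M(\theta) = \mathrm{diag}(e^{i\theta_k})\)), a diagonal target \(U\), the identity matrix as the input batch, and a hand-derived gradient. All function names here are hypothetical.

```python
import cmath
import math
from typing import List


def complex_mse(y_true, y_pred) -> List[float]:
    """Per-example ||V_i - V_hat_i||^2 / N for row-major complex batches."""
    n = len(y_true[0])
    return [sum(abs(t - p) ** 2 for t, p in zip(rt, rp)) / n
            for rt, rp in zip(y_true, y_pred)]


def diag(phases: List[float]) -> List[List[complex]]:
    """Diagonal unitary diag(exp(i * phase_k))."""
    n = len(phases)
    return [[cmath.exp(1j * phases[r]) if r == c else 0j for c in range(n)]
            for r in range(n)]


def loss(thetas: List[float], target_phases: List[float]) -> float:
    # With the identity as the input batch, V = U and V_hat = M(theta),
    # so the batch-mean cost is the mean of complex_mse over the rows.
    errors = complex_mse(diag(target_phases), diag(thetas))
    return sum(errors) / len(errors)


def iterate(thetas: List[float], target_phases: List[float], lr: float) -> List[float]:
    """Hypothetical single gradient-descent step on the phases."""
    n = len(thetas)
    # d/d(theta_k) of |e^{i phi_k} - e^{i theta_k}|^2 / n^2 = -2 sin(phi_k - theta_k) / n^2
    grads = [-2.0 * math.sin(phi - th) / n ** 2 for th, phi in zip(thetas, target_phases)]
    return [th - lr * g for th, g in zip(thetas, grads)]


target = [0.3, -1.2, 2.0, 0.7]  # phases of the target unitary U (illustrative)
thetas = [0.0, 0.0, 0.0, 0.0]   # initial phases of the model
for _ in range(100):
    thetas = iterate(thetas, target, lr=4.0)
print(loss(thetas, target) < 1e-12)  # True
```

In an autodiff framework the gradient line would be computed automatically; writing it out by hand here only makes the direction of each update explicit (each phase moves toward its target by an amount proportional to the sine of the remaining phase error).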
run(num_epochs, target_unitary, pbar=None)
Run the optimization toward a target unitary \(U\) for num_epochs epochs.
- Parameters
num_epochs (int) – Number of epochs, each consisting of iterations_per_epoch iterations
target_unitary (ndarray) – Target unitary, \(U\).
pbar (Optional[Callable]) – Progress bar (tqdm recommended)
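The relationship between num_epochs, iterations_per_epoch, iterations_per_tb_update and pbar can be sketched as a nested loop. This is an assumption about the control flow, not the library's code: the function below is hypothetical, step stands in for one call to iterate, logging is assumed to happen whenever the global iteration index is a multiple of iterations_per_tb_update, and pbar is assumed to wrap the epoch range (tqdm.tqdm accepts any iterable, so it fits this role).

```python
from typing import Callable, Iterable, List, Optional


def run(num_epochs: int,
        step: Callable[[int], float],
        iterations_per_epoch: int = 50,
        iterations_per_tb_update: int = 5,
        pbar: Optional[Callable[[Iterable[int]], Iterable[int]]] = None) -> List[int]:
    """Hypothetical epoch loop: call step() num_epochs * iterations_per_epoch times.

    Returns the global iteration indices at which a TensorBoard-style log
    write would happen, i.e. every iterations_per_tb_update iterations.
    """
    epochs: Iterable[int] = range(num_epochs)
    if pbar is not None:
        epochs = pbar(epochs)  # e.g. tqdm.tqdm wraps any iterable
    logged = []
    iteration = 0
    for _ in epochs:
        for _ in range(iterations_per_epoch):
            loss = step(iteration)
            if iteration % iterations_per_tb_update == 0:
                logged.append(iteration)  # a real runner would record `loss` here
            iteration += 1
    return logged


calls = []
logged = run(2, step=lambda it: calls.append(it) or 0.0,
             iterations_per_epoch=10, iterations_per_tb_update=5)
print(len(calls), logged)  # 20 [0, 5, 10, 15]
```

Wrapping only the outer epoch range keeps the progress bar at epoch granularity, which matches num_epochs being the argument the caller controls.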
neurophox.ml.linear.complex_mse(y_true, y_pred)
- Parameters
y_true (Tensor) – The true labels, \(V \in \mathbb{C}^{B \times N}\)
y_pred (Tensor) – The predicted labels, \(\widehat{V} \in \mathbb{C}^{B \times N}\)
- Returns
The complex mean squared error \(\boldsymbol{e} \in \mathbb{R}^B\), where for each example \(\widehat{V}_i \in \mathbb{C}^N\) we have \(e_i = \frac{\|V_i - \widehat{V}_i\|^2}{N}\).
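A plain-Python sketch of the formula above may help check it by hand. The library function operates on TensorFlow Tensors; this version instead takes nested lists of Python complex numbers (rows are examples \(V_i\), columns are the \(N\) entries), which is an assumption made only to keep the example dependency-free.

```python
from typing import List, Sequence


def complex_mse(y_true: Sequence[Sequence[complex]],
                y_pred: Sequence[Sequence[complex]]) -> List[float]:
    """e_i = ||V_i - V_hat_i||^2 / N for each row i of a B x N batch."""
    errors = []
    for v, v_hat in zip(y_true, y_pred):
        n = len(v)
        # |a - b|^2 is the squared modulus of the complex difference
        errors.append(sum(abs(a - b) ** 2 for a, b in zip(v, v_hat)) / n)
    return errors


V = [[1 + 0j, 1j], [1j, 0j]]
V_hat = [[1 + 0j, 0j], [-1j, 0j]]
print(complex_mse(V, V_hat))  # [0.5, 2.0]
```

In the first row only the second entry differs, by \(|i|^2 = 1\), giving \(1/2\); in the second row the first entry differs by \(|2i|^2 = 4\), giving \(4/2 = 2\). Averaging \(\boldsymbol{e}\) over the batch yields a single scalar loss.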