Source code for neurophox.helpers

from typing import Optional, Callable, Tuple

import numpy as np
import tensorflow as tf
try:
    import torch
except ImportError:
    # if the user did not install pytorch, just do tensorflow stuff
    pass

from scipy.stats import multivariate_normal

from .config import NP_FLOAT


def to_stripe_array(nparray: np.ndarray, units: int):
    """Interleave phase-shift values with zero rows to form a striped array.

    The (transposed) input is written to every other row (rows 0, 2, 4, ...)
    of a zero array with ``units - 1`` rows, and one extra all-zero row is
    appended, so the result has ``units`` rows.

    NOTE(review): the batched branch unpacks the shape as
    (``num_layers``, ``n``, ``batch_size``), i.e. the batch axis is the *last*
    axis of both the input and the output -- confirm this against callers.

    Args:
        nparray: phase shift values, either 2-D (``num_layers``, ``n``) or
            3-D (``num_layers``, ``n``, ``batch_size``). ``n`` must match the
            number of even-indexed rows among the first ``units - 1`` rows.
        units: number of rows of the returned stripe array.

    Returns:
        An array of shape (``units``, ``num_layers``) or
        (``units``, ``num_layers``, ``batch_size``). Because the appended
        zero row is float64, the result dtype is the promotion of
        ``nparray.dtype`` with float64.
    """
    if nparray.ndim == 2:
        num_layers = nparray.shape[0]
        trailing_shape = (num_layers,)
        striped_values = nparray.T
    else:
        num_layers, _, batch_size = nparray.shape
        trailing_shape = (num_layers, batch_size)
        striped_values = nparray.transpose((1, 0, 2))

    stripe_array = np.zeros((units - 1,) + trailing_shape, dtype=nparray.dtype)
    stripe_array[::2] = striped_values
    # Pad with a final all-zero row so the output has exactly ``units`` rows.
    padding_row = np.zeros(shape=(1,) + trailing_shape)
    return np.vstack([stripe_array, padding_row])
def to_absolute_theta(theta: np.ndarray) -> np.ndarray:
    """Fold phases into the range ``[0, pi]``.

    Each value is first wrapped into ``[0, 2 * pi)``; wrapped values above
    ``pi`` are then reflected to ``2 * pi - value``.

    Args:
        theta: phase value(s). Arrays of any shape, 0-d arrays and plain
            scalars are accepted. The input is never modified.

    Returns:
        An array with the same shape as ``theta`` (0-d for scalar input)
        holding the folded phases.
    """
    wrapped = np.mod(theta, 2 * np.pi)
    # np.where (rather than masked item assignment) also works when
    # ``wrapped`` is a NumPy scalar, which does not support item assignment.
    return np.where(wrapped > np.pi, 2 * np.pi - wrapped, wrapped)
def get_haar_diagonal_sequence(diagonal_length, parity_odd: bool = False):
    """Return the ordering of ``1..diagonal_length`` used along a mesh diagonal.

    The sequence starts with ``diagonal_length + 1 - k`` for the odd ``k`` in
    descending order, followed by ``diagonal_length + 1 - k`` for the even
    ``k`` in ascending order. For example, length 4 gives ``[2, 4, 3, 1]``
    and length 5 gives ``[1, 3, 5, 4, 2]``.

    Args:
        diagonal_length: number of entries in the sequence.
        parity_odd: if true, the sequence is returned reversed.

    Returns:
        A 1-D numpy array of length ``diagonal_length``.
    """
    complement = diagonal_length + 1
    descending_odds = np.arange(1, complement, 2)[::-1]
    leading = list(complement - descending_odds)
    # Whatever is not covered by the odd positions is taken by the even ones.
    num_even = diagonal_length - len(leading)
    trailing = list(complement - np.arange(2, 2 * num_even + 1, 2))
    sequence = np.asarray(leading + trailing)
    return sequence[::-1] if parity_odd else sequence
def get_alpha_checkerboard(units: int, num_layers: int, include_off_mesh: bool = False, flipud: bool = False) -> np.ndarray:
    """Build the ``(units - 1, num_layers)`` checkerboard of alpha values.

    Cells where ``i + j`` is even receive either ``1`` (when the diagonal
    length computed for that cell is 1) or a value looked up in the sequence
    returned by ``get_haar_diagonal_sequence`` for that diagonal length.
    Cells where ``i + j`` is odd stay ``0`` unless ``include_off_mesh`` is
    set, in which case they are set to ``1``.

    Args:
        units: sets the number of rows (``units - 1``); must be at least
            ``num_layers``.
        num_layers: number of columns.
        include_off_mesh: if true, odd-parity cells are filled with ``1``.
        flipud: if true, the result is flipped top-to-bottom before returning.

    Returns:
        A float array of shape (``units - 1``, ``num_layers``).

    Raises:
        ValueError: if ``units < num_layers``.
    """
    if units < num_layers:
        raise ValueError("Require units >= num_layers!")
    alpha_checkerboard = np.zeros((units - 1, num_layers))
    # Entry ``k - 1`` holds the sequence for a diagonal of length ``k``; the
    # sequences are requested reversed when ``num_layers`` is odd.
    diagonal_length_to_sequence = [get_haar_diagonal_sequence(i, bool(num_layers % 2))
                                   for i in range(1, num_layers + 1)]
    for i in range(units - 1):
        for j in range(num_layers):
            if (i + j) % 2 == 0:
                # Three regions decide the diagonal length of cell (i, j).
                # NOTE(review): these look like the two corner triangles and
                # the full-length band of the mesh -- confirm.
                if i < num_layers and j > i:
                    diagonal_length = num_layers - np.abs(i - j)
                elif i > units - num_layers and j < i - units + num_layers:
                    # ``1 * (units == num_layers)`` subtracts one only for a
                    # square (units == num_layers) checkerboard.
                    diagonal_length = num_layers - np.abs(i - j - units + num_layers) - 1 * (units == num_layers)
                else:
                    diagonal_length = num_layers - 1 * (units == num_layers)
                # Length-1 diagonals get 1 directly; otherwise the value is
                # the ``min(i, j)``-th element of that length's sequence.
                alpha_checkerboard[i, j] = 1 if diagonal_length == 1 else \
                    diagonal_length_to_sequence[int(diagonal_length) - 1][min(i, j)]
            else:
                if include_off_mesh:
                    alpha_checkerboard[i, j] = 1
    # symmetrize the checkerboard
    # (average with its vertical mirror; skipped for the square case)
    if units != num_layers:
        alpha_checkerboard = (alpha_checkerboard + np.flipud(alpha_checkerboard)) / 2
    return alpha_checkerboard if not flipud else np.flipud(alpha_checkerboard)
def get_alpha_checkerboard_general(units: int, num_layers: int):
    """Tile alpha checkerboards horizontally to cover ``num_layers`` columns.

    ``num_layers // units`` full ``units x units`` checkerboards are built
    with ``get_alpha_checkerboard``, followed by one partial checkerboard for
    the remaining ``num_layers % units`` columns.

    Args:
        units: passed to ``get_alpha_checkerboard`` as ``units`` and, for the
            full blocks, also as ``num_layers``.
        num_layers: total number of columns of the result.

    Returns:
        The horizontal stack (``np.hstack``) of all checkerboards.
    """
    num_full_blocks, extra_layers = divmod(num_layers, units)
    units_odd = units % 2

    # Full blocks are flipped on odd block indices when ``units`` is odd.
    blocks = []
    for block_index in range(num_full_blocks):
        flip_block = bool(block_index % 2 and units_odd)
        blocks.append(get_alpha_checkerboard(units, units, flipud=flip_block))

    if extra_layers < units:  # partial checkerboard
        flip_partial = not num_full_blocks % 2 and units_odd
        blocks.append(get_alpha_checkerboard(units, extra_layers, flipud=flip_partial))
    return np.hstack(blocks)
def get_efficient_coarse_grain_block_sizes(units: int, tunable_layers_per_block: int = 2,
                                           use_cg_sequence: bool = True):
    """Compute constant tunable block sizes and the sampling frequencies.

    With ``num_blocks = round(log2(units))``, both returned arrays have
    ``num_blocks - 1`` entries.

    Args:
        units: dimension used to derive ``num_blocks``.
        tunable_layers_per_block: value repeated for every tunable block.
        use_cg_sequence: if true, frequencies are ``2 ** s`` for ``s`` in
            ``get_haar_diagonal_sequence(num_blocks - 1)``; otherwise they
            are the ascending powers ``2, 4, 8, ...``.

    Returns:
        A tuple ``(tunable_block_sizes, sampling_frequencies)`` of int32
        numpy arrays.
    """
    num_frequencies = int(np.rint(np.log2(units))) - 1
    if use_cg_sequence:
        sampling_frequencies = 2 ** get_haar_diagonal_sequence(num_frequencies)
    else:
        sampling_frequencies = [2 ** (exponent + 1) for exponent in range(num_frequencies)]
    tunable_block_sizes = [tunable_layers_per_block] * num_frequencies
    block_sizes_array = np.asarray(tunable_block_sizes, dtype=np.int32)
    frequencies_array = np.asarray(sampling_frequencies, dtype=np.int32)
    return block_sizes_array, frequencies_array
def get_default_coarse_grain_block_sizes(units: int, use_cg_sequence: bool = True):
    """Compute the default tunable block sizes and sampling frequencies.

    With ``num_blocks = round(log2(units))``, the first ``num_blocks - 1``
    blocks share one size: ``floor(units / num_blocks)`` rounded up to the
    next even number. The last block takes whatever remains so that the
    sizes sum to ``units``.

    Args:
        units: dimension used to derive the number and size of blocks.
        use_cg_sequence: if true, frequencies are ``2 ** s`` for ``s`` in
            ``get_haar_diagonal_sequence(num_blocks - 1)``; otherwise they
            are the ascending powers ``2, 4, 8, ...``.

    Returns:
        A tuple ``(tunable_block_sizes, sampling_frequencies)`` of int32
        numpy arrays with ``num_blocks`` and ``num_blocks - 1`` entries.
    """
    num_blocks = int(np.rint(np.log2(units)))
    num_frequencies = num_blocks - 1
    if use_cg_sequence:
        sampling_frequencies = 2 ** get_haar_diagonal_sequence(num_frequencies)
    else:
        sampling_frequencies = [2 ** (exponent + 1) for exponent in range(num_frequencies)]

    tunable_layer_rank = int(np.floor(units / num_blocks))
    if tunable_layer_rank % 2:
        # Force an even number of layers per block.
        tunable_layer_rank += 1

    tunable_block_sizes = [tunable_layer_rank] * num_frequencies
    tunable_block_sizes.append(units - tunable_layer_rank * num_frequencies)
    block_sizes_array = np.asarray(tunable_block_sizes, dtype=np.int32)
    frequencies_array = np.asarray(sampling_frequencies, dtype=np.int32)
    return block_sizes_array, frequencies_array
def prm_permutation(units: int, tunable_block_sizes: np.ndarray, sampling_frequencies: np.ndarray,
                    butterfly: bool = False):
    """Assemble the stacked permutation indices for a sequence of grid blocks.

    A grid permutation stack is built for every tunable block. Between
    consecutive blocks, the last row of the previous block, a
    frequency-dependent permutation and the first row of the next block are
    glued into one row via ``glue_permutations``.

    Args:
        units: dimension each permutation acts on.
        tunable_block_sizes: number of layers of each grid block.
        sampling_frequencies: one frequency per junction between consecutive
            blocks; junction ``idx`` joins blocks ``idx`` and ``idx + 1``.
        butterfly: if true, junctions use ``butterfly_layer_permutation``;
            otherwise ``rectangular_permutation``.

    Returns:
        The vertical stack (``np.vstack``) of all permutation rows.
    """
    grid_perms = [grid_permutation(units, block_size) for block_size in tunable_block_sizes]
    stacked_rows = [grid_perms[0][0]]
    for idx, frequency in enumerate(sampling_frequencies):
        last_of_current = grid_perms[idx][-1]
        first_of_next = grid_perms[idx + 1][0]
        if butterfly:
            junction = butterfly_layer_permutation(units, frequency)
        else:
            junction = rectangular_permutation(units, frequency)
        # Fold the two block-boundary rows and the junction into one row.
        glued = glue_permutations(glue_permutations(last_of_current, junction), first_of_next)
        stacked_rows.append(grid_perms[idx][1:-1])
        stacked_rows.append(glued)
    stacked_rows.append(grid_perms[-1][1:])
    return np.vstack(stacked_rows)
def butterfly_layer_permutation(units: int, frequency: int):
    """Return the index permutation for one butterfly layer.

    Indices are processed in groups of ``2 * frequency``; inside a group,
    index ``i`` is paired with ``i + frequency``, giving the interleaved
    order ``[0, f, 1, f + 1, ..., f - 1, 2f - 1]`` shifted by the group
    offset.

    Args:
        units: input dimension; must be even.
        frequency: distance between the two indices of each pair.

    Returns:
        An int32 numpy array of permuted indices. Only
        ``(units // frequency) // 2`` complete groups are emitted, followed
        by the indices at or beyond ``(units // frequency) * frequency``.
        NOTE(review): when ``units // frequency`` is odd the output is
        therefore shorter than ``units`` -- confirm callers only pass
        power-of-two combinations.

    Raises:
        NotImplementedError: if ``units`` is odd.
    """
    if units % 2:
        raise NotImplementedError('Odd input dimension case not yet implemented.')
    num_splits = units // frequency
    unpermuted_indices_remainder = np.arange(units)[num_splits * frequency:]

    # Interleaved pattern [0, f, 1, f + 1, ...] for a single group. Building
    # it from arrays avoids passing a generator to np.hstack, which current
    # NumPy rejects with a TypeError.
    base = np.arange(frequency)
    group_pattern = np.vstack((base, base + frequency)).T.ravel()

    groups = [group_pattern + 2 * frequency * split_num for split_num in range(num_splits // 2)]
    permuted_indices = np.hstack(groups + [unpermuted_indices_remainder])
    return permuted_indices.astype(np.int32)
def rectangular_permutation(units: int, frequency: int):
    """Return the rectangular permutation of ``units`` indices.

    Even positions are shifted by ``-frequency`` and odd positions by
    ``+frequency``. Shifted values that leave ``[0, units - 1]`` are
    reflected: a negative value ``p`` becomes ``-1 - p`` and a value above
    ``units - 1`` becomes ``2 * units - 1 - p``.

    Args:
        units: number of indices.
        frequency: magnitude of the shift applied to every index.

    Returns:
        An int32 numpy array of length ``units``.
    """
    # -1 on even positions, +1 on odd positions (float, like the shifted values).
    direction = np.ones((units,))
    direction[::2] = -1
    shifted = np.arange(units) + direction * frequency

    # Reflect at the lower boundary first, then at the upper boundary.
    below = shifted < 0
    shifted[below] = -1 - shifted[below]
    above = shifted > units - 1
    shifted[above] = 2 * units - 1 - shifted[above]
    return shifted.astype(np.int32)
def grid_permutation(units: int, num_layers: int):
    """Return the stacked permutation indices for a grid of ``num_layers`` layers.

    The first row is the identity. It is followed by rows that alternate
    between a cyclic left shift (``np.roll`` by -1) and a cyclic right shift
    (``np.roll`` by +1) of the identity. When ``num_layers`` is odd, the last
    shift row is dropped and an identity row is appended instead.

    Args:
        units: number of indices per row.
        num_layers: number of alternating shift rows that are generated.

    Returns:
        An int32 array of shape (``num_layers + 1``, ``units``).
    """
    identity = np.arange(units)
    shift_rows = np.zeros((num_layers, units))
    # Broadcasting fills every even row with the left shift and every odd
    # row with the right shift.
    shift_rows[::2] = np.roll(identity, -1, axis=0)
    shift_rows[1::2] = np.roll(identity, 1, axis=0)

    identity = identity.astype(np.int32)
    shift_rows = shift_rows.astype(np.int32)
    if num_layers % 2:
        return np.vstack((identity, shift_rows[:-1], identity))
    return np.vstack((identity, shift_rows))
def grid_viz_permutation(units: int, num_layers: int, flip: bool = False):
    """Return stacked permutation indices alternating identity and right shift.

    ``num_layers`` rows are generated that alternate between the identity
    and a cyclic right shift (``np.roll`` by +1). The last generated row is
    dropped, and identity rows are placed before and after the remainder.

    Args:
        units: number of indices per row.
        num_layers: number of alternating rows generated before trimming.
        flip: if true, even rows are the identity and odd rows the right
            shift; otherwise even rows are the right shift and odd rows the
            identity.

    Returns:
        An int32 array whose first and last rows are the identity.
    """
    identity = np.arange(units)
    right_shift = np.roll(identity, 1, axis=0)
    even_row, odd_row = (identity, right_shift) if flip else (right_shift, identity)

    alternating_rows = np.zeros((num_layers, units))
    alternating_rows[::2] = even_row
    alternating_rows[1::2] = odd_row

    identity = identity.astype(np.int32)
    return np.vstack((identity, alternating_rows[:-1].astype(np.int32), identity))
def ordered_viz_permutation(units: int, num_layers: int):
    """Return ``num_layers + 1`` stacked copies of the identity permutation.

    Args:
        units: number of indices per row.
        num_layers: one less than the number of rows returned.

    Returns:
        An int32 array of shape (``num_layers + 1``, ``units``) in which
        every row is ``0, 1, ..., units - 1``.
    """
    identity_rows = np.zeros((num_layers + 1, units), dtype=np.int32)
    # Broadcasting writes the same index row into every layer.
    identity_rows[:] = np.arange(units)
    return identity_rows
def plot_complex_matrix(plt, matrix: np.ndarray):
    """Plot the absolute value, real part and imaginary part of a matrix.

    A new 15x5 figure at 200 dpi is created with three side-by-side
    subplots titled ``Absolute``, ``Real`` and ``Imag``. Each one shows the
    corresponding component with the ``hot`` colormap and a colorbar.

    Args:
        plt: object exposing ``figure``, ``subplot``, ``title``, ``imshow``
            and ``colorbar`` (presumably ``matplotlib.pyplot``).
        matrix: the matrix to visualize.
    """
    panels = (
        (131, 'Absolute', np.abs),
        (132, 'Real', np.real),
        (133, 'Imag', np.imag),
    )
    plt.figure(figsize=(15, 5), dpi=200)
    for position, heading, component in panels:
        plt.subplot(position)
        plt.title(heading)
        plt.imshow(component(matrix), cmap='hot')
        plt.colorbar(shrink=0.7)
def random_gaussian_batch(batch_size: int, units: int, covariance_matrix: Optional[np.ndarray] = None,
                          seed: Optional[int] = None) -> np.ndarray:
    """Sample a batch of complex vectors with Gaussian amplitudes and random phases.

    Amplitudes are drawn from a zero-mean multivariate normal distribution
    and multiplied by ``exp(1j * phase)`` with phases drawn uniformly from
    ``[0, 2 * pi)``.

    Args:
        batch_size: number of vectors to sample.
        units: dimension of each vector.
        covariance_matrix: covariance forwarded to
            ``multivariate_normal.rvs``; when ``None`` a covariance of ``1``
            is used.
        seed: if given, seeds NumPy's global random state first. This has a
            global side effect on ``np.random``.

    Returns:
        A complex array of shape (``batch_size``, ``units``).
    """
    if seed is not None:
        np.random.seed(seed)
    # Compare against None explicitly: the truth value of a multi-element
    # ndarray is ambiguous and raises ValueError.
    cov = 1 if covariance_matrix is None else covariance_matrix
    input_matrix = multivariate_normal.rvs(
        mean=np.zeros(units),
        cov=cov,
        size=batch_size
    )
    # rvs squeezes singleton axes (e.g. units == 1 yields shape
    # (batch_size,)); restore the 2-D layout so the product below cannot
    # broadcast to the wrong shape.
    input_matrix = np.reshape(input_matrix, (batch_size, units))
    random_phase = np.random.rand(batch_size, units).astype(dtype=NP_FLOAT) * 2 * np.pi
    return input_matrix * np.exp(1j * random_phase)
def glue_permutations(perm_idx_1: np.ndarray, perm_idx_2: np.ndarray):
    """Combine two permutation index arrays into one.

    The result ``out`` satisfies ``out[perm_idx_2[k]] = perm_idx_1[k]`` for
    every position ``k``. Positions not addressed by ``perm_idx_2`` stay
    zero.

    Args:
        perm_idx_1: values to scatter.
        perm_idx_2: target positions for those values.

    Returns:
        The combined index array as int32.
    """
    glued = np.zeros_like(perm_idx_1)
    glued[perm_idx_2] = perm_idx_1
    return glued.astype(np.int32)
def inverse_permutation(permuted_indices: np.ndarray):
    """Return the inverse of a permutation given as an index array.

    For each position ``k`` holding value ``v``, the result holds ``k`` at
    position ``v``.

    Args:
        permuted_indices: the permutation to invert.

    Returns:
        An array shaped and typed like ``permuted_indices`` holding the
        inverse permutation.
    """
    inverse = np.zeros_like(permuted_indices)
    for position, target in enumerate(permuted_indices):
        inverse[target] = position
    return inverse
def pairwise_off_diag_permutation(units: int):
    """Return the permutation that swaps neighbouring index pairs.

    Indices ``(0, 1)``, ``(2, 3)``, ... are exchanged. When ``units`` is
    odd, the last index has no partner and maps to itself.

    Args:
        units: number of indices.

    Returns:
        An int32 array of length ``units``, e.g. ``[1, 0, 3, 2, 4]`` for
        ``units = 5``.
    """
    identity = np.arange(units)
    swapped = np.zeros_like(identity)
    is_odd = units % 2
    # Number of leading indices that can be grouped into complete pairs.
    num_paired = units - is_odd
    swapped[:num_paired:2] = identity[1:num_paired:2]
    swapped[1:num_paired:2] = identity[:num_paired:2]
    if is_odd:
        swapped[-1] = identity[-1]
    return swapped.astype(np.int32)
def butterfly_permutation(num_layers: int):
    """Return the stacked permutation indices of a butterfly architecture.

    ``butterfly_layer_permutation`` is evaluated for the frequencies
    ``2 ** 0, ..., 2 ** (num_layers - 1)`` on ``2 ** num_layers`` indices.
    The first of these rows is discarded, and identity rows are placed
    before and after the remaining ones.

    Args:
        num_layers: number of butterfly layers; the permutations act on
            ``2 ** num_layers`` indices.

    Returns:
        An int32 array with ``num_layers + 1`` rows.
    """
    units = 2 ** num_layers
    identity = np.arange(units).astype(np.int32)
    layer_rows = np.vstack(
        [butterfly_layer_permutation(units, 2 ** layer) for layer in range(num_layers)]
    ).astype(np.int32)
    # Row 0 (frequency 1) is replaced by the leading identity row.
    return np.vstack((identity, layer_rows[1:], identity))
def neurophox_matplotlib_setup(plt):
    """Configure plotting defaults: LaTeX text rendering with a serif font.

    Enables ``usetex``, selects the serif font family and sets the LaTeX
    preamble to load the ``siunitx`` and ``amsmath`` packages.

    Args:
        plt: object exposing ``rc`` and ``rcParams`` (presumably
            ``matplotlib.pyplot``).
    """
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    # 'text.latex.preamble' must be a single string; recent Matplotlib
    # releases no longer accept a list of lines.
    latex_packages = (r'\usepackage{siunitx}', r'\usepackage{amsmath}')
    plt.rcParams.update({'text.latex.preamble': '\n'.join(latex_packages)})
# Phase functions
def fix_phase_tf(fixed, mask):
    """Build a function that pins masked-out entries of a tensor to fixed values.

    Args:
        fixed: values used where ``mask`` is 0.
        mask: multiplicative mask; entries equal to 1 keep the incoming
            tensor value, entries equal to 0 take the value from ``fixed``.

    Returns:
        A callable mapping ``tensor`` to
        ``mask * tensor + (1 - mask) * fixed``.
    """
    def apply_fixed_phase(tensor):
        return mask * tensor + (1 - mask) * fixed

    return apply_fixed_phase
def fix_phase_torch(fixed: np.ndarray, mask: np.ndarray,
                    device: "Optional[torch.device]" = None,
                    dtype: "Optional[torch.dtype]" = None):
    """Build a function that pins masked-out entries of a torch tensor to fixed values.

    The defaults are resolved inside the function and the annotations are
    strings, so nothing from ``torch`` is evaluated when this function is
    defined. The module therefore stays importable when PyTorch is not
    installed.

    Args:
        fixed: values used where ``mask`` is 0.
        mask: multiplicative mask; entries equal to 1 keep the incoming
            tensor value, entries equal to 0 take the value from ``fixed``.
        device: device for the mask and fixed tensors; ``None`` selects
            ``torch.device('cpu')``.
        dtype: dtype for the mask and fixed tensors; ``None`` selects
            ``torch.cfloat``.

    Returns:
        A callable mapping ``tensor`` to
        ``tensor * mask + (1 - mask) * fixed``.
    """
    if device is None:
        device = torch.device('cpu')
    if dtype is None:
        dtype = torch.cfloat
    mask = torch.as_tensor(mask, dtype=dtype, device=device)
    fixed = torch.as_tensor(fixed, dtype=dtype, device=device)
    return lambda tensor: tensor * mask + (1 - mask) * fixed
def tri_phase_tf(phase_range: float):
    """Build a function that folds phases into ``[0, phase_range]``.

    The returned function wraps its input modulo ``2 * phase_range`` and
    reflects wrapped values above ``phase_range`` to
    ``2 * phase_range - value``.

    Args:
        phase_range: upper bound of the folded phase.

    Returns:
        A callable ``pcf(phase)`` that applies the folding to a tensor.
    """
    def pcf(phase):
        wrapped = tf.math.mod(phase, 2 * phase_range)
        reflected = 2 * phase_range * tf.ones_like(wrapped) - wrapped
        return tf.where(tf.greater(wrapped, phase_range), reflected, wrapped)

    return pcf
def tri_phase_torch(phase_range: float):
    """Build a function that folds phases into ``[0, phase_range]``.

    The returned function wraps its input modulo ``2 * phase_range`` and
    reflects wrapped values above ``phase_range`` to
    ``2 * phase_range - value``.

    ``torch.remainder`` is used for the wrap because its result is
    non-negative for a positive modulus. This matches ``tf.math.mod`` in
    ``tri_phase_tf``, whereas ``torch.fmod`` keeps the sign of the input and
    would let negative phases through unfolded.

    Args:
        phase_range: upper bound of the folded phase.

    Returns:
        A callable ``pcf(phase)`` that returns a new tensor; the input
        tensor is not modified.
    """
    period = 2 * phase_range

    def pcf(phase):
        wrapped = torch.remainder(phase, period)
        return torch.where(wrapped > phase_range, period - wrapped, wrapped)

    return pcf