utilities
- utilities.flatten_MyNamespace(MyNamespace)
Flattens a MyNamespace into its leaves and auxiliary data. Needed for registering MyNamespace as a pytree.
- utilities.unflatten_MyNamespace(aux_data, leaves)
Rebuilds a MyNamespace from its auxiliary data and leaves. Needed for registering MyNamespace as a pytree.
- class utilities.MyNamespace(**kwargs)
Bases: object
The central pytree. Supports basic arithmetic if shapes/structures are consistent. Does not have a fixed structure at initialization.
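The following is a minimal sketch of how a namespace-like class can be registered as a pytree with a flatten/unflatten pair and given leaf-wise arithmetic. The names NamespaceSketch, flatten_namespace and unflatten_namespace are hypothetical; it assumes the attribute names serve as the auxiliary data and is not the library's MyNamespace implementation.

```python
import jax
import jax.numpy as jnp


class NamespaceSketch:
    """Hypothetical attribute container illustrating the MyNamespace idea."""

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

    def _binary(self, other, op):
        # Same structure: combine leaf by leaf. Otherwise broadcast `other`
        # (a scalar or array) over every leaf.
        if isinstance(other, NamespaceSketch):
            return jax.tree_util.tree_map(op, self, other)
        return jax.tree_util.tree_map(lambda leaf: op(leaf, other), self)

    def __add__(self, other):
        return self._binary(other, lambda a, b: a + b)

    def __mul__(self, other):
        return self._binary(other, lambda a, b: a * b)

    __radd__ = __add__
    __rmul__ = __mul__


def flatten_namespace(ns):
    # Leaves are the attribute values; the sorted (hashable) tuple of
    # attribute names is the auxiliary data needed to rebuild the object.
    names = tuple(sorted(vars(ns)))
    return [getattr(ns, name) for name in names], names


def unflatten_namespace(aux_data, leaves):
    return NamespaceSketch(**dict(zip(aux_data, leaves)))


jax.tree_util.register_pytree_node(NamespaceSketch, flatten_namespace, unflatten_namespace)

# Once registered, the object can pass through jit like any other pytree.
ns = NamespaceSketch(amplitude=jnp.ones(3), phase=jnp.zeros(3))
out = jax.jit(lambda n: 2.0 * n + n)(ns)
```

Because the structure lives in the auxiliary data rather than in a fixed class layout, attributes can be chosen freely at construction time.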
- utilities.run_scan(do_scan, carry, no_iterations)
Run a solver iteratively using lax.scan with jax.jit.
- Parameters:
do_scan (Callable) – the callable to be iterated; it needs to take carry as its argument
carry (Pytree) – the initial state of the iteration
no_iterations (int) – the number of iterations
- Returns:
tuple[Carry, Y], the output of jax.lax.scan
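A minimal sketch of the run_scan pattern, assuming the iteration count is a Python int that is closed over (and therefore static under jit). The names run_scan_sketch and newton_sqrt2_step are hypothetical; the step function follows lax.scan's (carry, x) -> (carry, y) convention.

```python
import jax
import jax.numpy as jnp


def run_scan_sketch(do_scan, carry, no_iterations):
    # Only carry is traced; no_iterations is baked into the compiled loop.
    scan = jax.jit(lambda c: jax.lax.scan(do_scan, c, length=no_iterations))
    return scan(carry)


def newton_sqrt2_step(carry, _):
    # One Newton step for x**2 = 2; the iterate is also emitted as y.
    new = 0.5 * (carry + 2.0 / carry)
    return new, new


final, history = run_scan_sketch(newton_sqrt2_step, jnp.float32(1.0), 6)
```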
- utilities.scan_helper(carry, xs, actual_function, number_of_args, number_of_xs)
jax.lax.scan expects the provided callable to accept two arguments, carry and (possibly) xs. This function wraps the function to be iterated by lax.scan so that its inputs do not have to conform to lax.scan’s requirements: the provided carry and xs are unpacked and passed to actual_function. All arguments except carry and xs have to be fixed via partial, and the resulting callable is then passed to lax.scan. The output of actual_function needs to have the same structure as carry.
- Parameters:
carry (any, tuple) – the initial state of the iteration
xs (any, tuple) – the xs used by jax.lax.scan
actual_function (Callable) – the function that is to be iterated over
number_of_args (int) – the number of individual arguments in carry
number_of_xs (int) – the number of individual arguments in xs
- Returns:
Any, the output of actual_function
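A sketch of the unpacking idea, under stated assumptions: the hypothetical scan_helper_sketch collects no per-step outputs (y is None), a single carry or xs is passed through unchanged, and number_of_xs == 0 means xs is unused. It is not the library's implementation. The extra arguments are fixed with functools.partial, so lax.scan only ever sees (carry, xs).

```python
from functools import partial

import jax
import jax.numpy as jnp


def scan_helper_sketch(carry, xs, actual_function, number_of_args, number_of_xs):
    # Unpack carry into positional arguments (a single carry is passed as-is).
    carry_args = tuple(carry) if number_of_args > 1 else (carry,)
    # Unpack xs the same way.
    if number_of_xs == 0:
        xs_args = ()
    elif number_of_xs == 1:
        xs_args = (xs,)
    else:
        xs_args = tuple(xs)
    new_carry = actual_function(*carry_args, *xs_args)
    # Assumption: no per-step output is collected.
    return new_carry, None


def euler_step(position, velocity, dt):
    # Plain positional signature; returns a carry of the same structure.
    return position + velocity * dt, velocity


step = partial(scan_helper_sketch, actual_function=euler_step, number_of_args=2, number_of_xs=1)
init = (jnp.float32(0.0), jnp.float32(1.0))
(position, velocity), _ = jax.lax.scan(step, init, xs=jnp.full(5, 0.1, dtype=jnp.float32))
```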
- utilities.while_loop_helper(carry, actual_function, number_of_args)
Similar to scan_helper. Unpacks carry so that the inputs of actual_function do not have to conform to lax.while_loop’s requirements.
- Parameters:
carry (any, tuple) – the initial state of the iteration
actual_function (Callable) – the function to be iterated over
number_of_args (int) – the number of individual arguments in carry
- Returns:
Any, the output of actual_function
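The same idea for lax.while_loop, as a hedged sketch: while_loop_helper_sketch and double are hypothetical names, and the loop condition still receives the packed carry, since only the body is wrapped.

```python
from functools import partial

import jax
import jax.numpy as jnp


def while_loop_helper_sketch(carry, actual_function, number_of_args):
    # Unpack carry into positional arguments for actual_function.
    args = tuple(carry) if number_of_args > 1 else (carry,)
    return actual_function(*args)


def double(counter, value):
    return counter + 1, value * 2.0


body = partial(while_loop_helper_sketch, actual_function=double, number_of_args=2)
counter, value = jax.lax.while_loop(
    lambda carry: carry[0] < 5,  # cond_fun sees the packed tuple
    body,
    (jnp.int32(0), jnp.float32(1.0)),
)
```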
- utilities.optimistix_helper_loss_function(input, args, function, no_of_args)
Optimistix’s interactive solver API expects loss functions that take two arguments and return a tuple of the error and auxiliary information. This function wraps function so that it conforms to this interface. function and no_of_args have to be fixed via partial.
- Parameters:
input (any) – the input to function
args (any) – the args of function
function (Callable) – the actual loss function
no_of_args (int) – the number of extra arguments
- Returns:
tuple, a tuple which contains the calculated error twice, since there is no auxiliary information
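A sketch of such a wrapper, assuming args is a single object when no_of_args == 1 and a tuple otherwise. The name optimistix_helper_loss_function_sketch is hypothetical, and the first parameter is called y here to avoid shadowing Python's built-in input. It only shows the calling convention fn(y, args) -> (error, aux); no optimistix solver is invoked.

```python
from functools import partial

import jax.numpy as jnp


def optimistix_helper_loss_function_sketch(y, args, function, no_of_args):
    # Unpack args into the extra positional arguments of function.
    if no_of_args == 0:
        extra = ()
    elif no_of_args == 1:
        extra = (args,)
    else:
        extra = tuple(args)
    error = function(y, *extra)
    # No auxiliary information: reuse the error in the aux slot.
    return error, error


def squared_distance(y, target, weight):
    return weight * jnp.sum((y - target) ** 2)


fn = partial(optimistix_helper_loss_function_sketch, function=squared_distance, no_of_args=2)
loss, aux = fn(jnp.array([1.0, 2.0]), (jnp.zeros(2), 0.5))
```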
- utilities.scan_helper_equinox(carry, xs, step, static)
This function wraps step, which is to be iterated over via lax.scan. In some cases the carry contains static parts that are not JAX-compatible (e.g. some of the optimistix solvers contain a jaxpr). These need to be filtered out to make the carry JAX-compatible, which can be done with equinox. The function merges the static part into carry before calling step and removes it again once step returns. step and static have to be fixed via partial.
- Parameters:
carry (any, tuple) – the carry to be iterated over
xs (any, tuple) – unused but required by lax.scan
step (Callable) – the function to be iterated over
static (any) – a static non-jax-compatible object which is to be merged before calling step and removed afterwards
- Returns:
tuple, the output of step
- utilities.do_fft(signal, sk, rn, axis=-1)
Do a complex-valued 1D FFT. Does not use fftshift; instead, sk and rn obtained from get_sk_rn are applied, which have the same effect and make the FFT work for any frequency range.
- Parameters:
signal (jnp.array) – the signal on which the fft is applied
sk (jnp.array) – corrective values which “shift” the signal to the correct frequencies
rn (jnp.array) – corrective values which “shift” the signal to the correct frequencies
axis (int) – the axis over which the fft is applied (Default is -1)
- Returns:
jnp.array, the Fourier-transformed signal
- utilities.do_ifft(signal, sk, rn, axis=-1)
Do a complex-valued 1D IFFT. Does not use fftshift; instead, sk and rn obtained from get_sk_rn are applied, which have the same effect and make the IFFT work for any frequency range.
- Parameters:
signal (jnp.array) – the signal on which the ifft is applied
sk (jnp.array) – corrective values which “shift” the signal to the correct positions
rn (jnp.array) – corrective values which “shift” the signal to the correct positions
axis (int) – the axis over which the ifft is applied (Default is -1)
- Returns:
jnp.array, the inverse-Fourier-transformed signal
- utilities.get_sk_rn(time, frequency)
The FFT's definition differs from that of a discrete Fourier transform on arbitrary time and frequency axes. To correct for this, the input and output of fft/ifft can be multiplied by the values calculated here, which essentially amounts to the Fourier shift theorem. time and frequency have to fulfill N = 1/(df*dt).
- Parameters:
time (jnp.array) – the time axis
frequency (jnp.array) – the frequency axis
- Returns:
tuple[jnp.array, jnp.array], the corrections used by do_fft/do_ifft
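A sketch of how such corrections can be derived, assuming uniform axes t_n = t0 + n*dt and f_k = f0 + k*df with N*dt*df = 1, and numpy's forward-FFT sign convention exp(-2πi·kn/N). Expanding exp(-2πi·f_k·t_n) splits it into exp(-2πi·f_k·t0) (applied to the output), exp(-2πi·f0·(t_n - t0)) (applied to the input) and the plain FFT kernel. The *_sketch names are hypothetical, the library's sign and normalization conventions may differ, and only the last axis is handled.

```python
import jax.numpy as jnp


def get_sk_rn_sketch(time, frequency):
    t0, f0 = time[0], frequency[0]
    sk = jnp.exp(-2j * jnp.pi * frequency * t0)     # multiplies the FFT output
    rn = jnp.exp(-2j * jnp.pi * f0 * (time - t0))   # multiplies the FFT input
    return sk, rn


def do_fft_sketch(signal, sk, rn):
    # Equals sum_n signal[n] * exp(-2j*pi*f_k*t_n) along the last axis.
    return sk * jnp.fft.fft(rn * signal, axis=-1)


def do_ifft_sketch(spectrum, sk, rn):
    # Exact inverse of do_fft_sketch, since |sk| = |rn| = 1.
    return jnp.conj(rn) * jnp.fft.ifft(jnp.conj(sk) * spectrum, axis=-1)


N, dt = 64, 0.1
time = -3.0 + dt * jnp.arange(N)
frequency = 1.5 + (1.0 / (N * dt)) * jnp.arange(N)   # N = 1/(df*dt)
signal = jnp.exp(-time ** 2) * jnp.exp(2j * jnp.pi * 2.0 * time)
sk, rn = get_sk_rn_sketch(time, frequency)
spectrum = do_fft_sketch(signal, sk, rn)
```

Neither the time axis nor the frequency axis has to start at zero, which is what removes the need for fftshift.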
- utilities.do_interpolation_1d(x_new, x, y, method='cubic', extrap=1e-12, axis=-1)
Wraps around interpax.interp1d and jnp.interp.
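interpax.interp1d accepts a float extrap as the fill value for points outside the data range. The sketch below assumes the jnp.interp branch mirrors this through jnp.interp's left/right arguments. interpolate_linear_sketch is a hypothetical name, and the axis argument is not handled because jnp.interp is strictly 1D.

```python
import jax.numpy as jnp


def interpolate_linear_sketch(x_new, x, y, extrap=1e-12):
    # Points outside [x[0], x[-1]] get the fill value extrap instead of
    # being extrapolated.
    return jnp.interp(x_new, x, y, left=extrap, right=extrap)


x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.0, 10.0, 20.0])
out = interpolate_linear_sketch(jnp.array([-1.0, 0.5, 1.5, 3.0]), x, y)
```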
- utilities.integrate_signal_1D(signal, x, integration_method, integration_order)
Calculates the indefinite integral of a signal using a Riemann sum or the Euler-Maclaurin formula.
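Two cumulative schemes of this kind, as a sketch: a left Riemann sum and the trapezoidal rule, which is the Euler-Maclaurin formula truncated before its derivative correction terms. The function names are hypothetical; the library's integration_method and integration_order arguments (and any higher-order corrections) are not modeled.

```python
import jax.numpy as jnp


def cumulative_riemann(signal, x):
    # Left Riemann sum; the integral is 0 at x[0].
    steps = signal[:-1] * jnp.diff(x)
    return jnp.concatenate([jnp.zeros(1, dtype=steps.dtype), jnp.cumsum(steps)])


def cumulative_trapezoid(signal, x):
    # Trapezoidal rule: lowest-order Euler-Maclaurin approximation.
    steps = 0.5 * (signal[1:] + signal[:-1]) * jnp.diff(x)
    return jnp.concatenate([jnp.zeros(1, dtype=steps.dtype), jnp.cumsum(steps)])


x = jnp.linspace(0.0, 1.0, 101)
trap = cumulative_trapezoid(2.0 * x, x)   # approximates x**2
riem = cumulative_riemann(2.0 * x, x)
```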
- utilities.calculate_gate(pulse_t, method)
Calculate the gate field/signal for the nonlinear process.
- utilities.calculate_gate_with_Real_Fields(pulse_t, method)
Calculate the gate field/signal for the nonlinear process using real input fields. This allows for the description of difference frequency generation.
- utilities.project_onto_intensity(signal_f, measured_intensity)
Project the current complex guess signal onto the measured intensity.
- utilities.project_onto_amplitude(signal_f, measured_amplitude)
Project the current complex guess signal onto the measured amplitude.
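The standard form of these projections keeps the phase of the current guess and replaces its modulus with the measured one. A sketch follows, with hypothetical names and an assumed small eps guarding against division by zero.

```python
import jax.numpy as jnp


def project_onto_amplitude_sketch(signal_f, measured_amplitude, eps=1e-12):
    # Keep the phase of the guess, replace its modulus.
    return measured_amplitude * signal_f / (jnp.abs(signal_f) + eps)


def project_onto_intensity_sketch(signal_f, measured_intensity, eps=1e-12):
    # Intensity = |amplitude|**2, so project onto sqrt(intensity).
    return project_onto_amplitude_sketch(signal_f, jnp.sqrt(measured_intensity), eps)


signal = jnp.array([1 + 1j, -2j, 3 + 0j])
measured = jnp.array([4.0, 9.0, 1.0])
out = project_onto_intensity_sketch(signal, measured)
```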
- utilities.calculate_mu(trace, measured_trace, measurement_info, descent_info, local_or_global)
- utilities.calculate_trace(signal_f, measured_trace, measurement_info, descent_info, local_or_global)
Calculates the intensity from a complex signal, as well as the calibration factor/curve. Needs to be vmapped in order to apply it to a population.
- utilities.calculate_trace_error(mu, trace, measured_trace)
Calculates the mean squared difference between the measured intensity and the intensity of the current guess, with the current guess scaled by mu.
- utilities.calculate_Z_error(signal_t, signal_t_new)
Calculates the squared L2-norm of the difference between the complex time-domain signal fields before and after projection onto the measured signal.
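A sketch of these quantities under a simplifying assumption: mu is a single global scalar chosen by least squares, which minimizes sum((measured - mu*trace)**2) and gives mu = sum(measured*trace)/sum(trace**2). The library's local_or_global option and calibration curves are not modeled, and all function names are hypothetical.

```python
import jax.numpy as jnp


def optimal_mu(trace, measured_trace):
    # Least-squares scale factor between guess and measurement.
    return jnp.sum(measured_trace * trace) / jnp.sum(trace ** 2)


def trace_error(mu, trace, measured_trace):
    # Mean squared difference after scaling the guess by mu.
    return jnp.mean((measured_trace - mu * trace) ** 2)


def z_error(signal_t, signal_t_new):
    # Squared L2-norm of the change caused by the projection.
    return jnp.sum(jnp.abs(signal_t_new - signal_t) ** 2)


trace = jnp.array([[1.0, 2.0], [3.0, 4.0]])
measured = 3.0 * trace
mu = optimal_mu(trace, measured)
```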
- utilities.generate_random_continuous_function(key, no_points, x, minval, maxval, distribution, forced_vals=False, **kwargs)
Generates a 1D array of random but continuous values via a cubic inter-/extrapolation of random points.
- Parameters:
key (jnp.array) – a jax.random.PRNGKey
no_points (int) – the number of random points to use for the interpolation
x (jnp.array) – the x-values from which to choose the location of random values
minval (int, float) – the minimal random y-value; the interpolation may produce lower values
maxval (int, float) – the maximal random y-value; the interpolation may produce higher values
distribution (jnp.array) – a probability distribution for the x-location of the random values.
- Returns:
jnp.array, the interpolated random y-values.
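The procedure can be sketched as follows. Random x-locations are drawn from x with the given weights, uniform y-values are drawn between minval and maxval, and the points are interpolated. Note that this hypothetical random_continuous_sketch swaps in linear interpolation (jnp.interp) for the cubic one the library uses, so unlike the cubic version its output stays within [minval, maxval]; forced_vals and kwargs are not modeled.

```python
import jax
import jax.numpy as jnp


def random_continuous_sketch(key, no_points, x, minval, maxval, distribution):
    key_x, key_y = jax.random.split(key)
    # Distinct x-locations, weighted by the normalized distribution.
    p = distribution / jnp.sum(distribution)
    x_points = jnp.sort(jax.random.choice(key_x, x, shape=(no_points,), replace=False, p=p))
    y_points = jax.random.uniform(key_y, (no_points,), minval=minval, maxval=maxval)
    # Linear interpolation; jnp.interp holds the end values constant outside
    # [x_points[0], x_points[-1]].
    return jnp.interp(x, x_points, y_points)


x = jnp.linspace(0.0, 1.0, 200)
values = random_continuous_sketch(jax.random.PRNGKey(0), 8, x, -1.0, 1.0, jnp.ones(200))
```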
- utilities.solve_linear_system(A, b, x_prev, solver)
Solve a stack of linear equations Ax=b using scipy or lineax.
- Parameters:
A (jnp.array) – stack of 2D-arrays
b (jnp.array) – stack of 1D-arrays
x_prev (jnp.array) – stack of 1D-arrays with approximate solutions.
solver (str,lineax-solver) – which library/method to use
- Returns:
jnp.array, stack of 1D-arrays with the solution to Ax=b
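For a direct solver, a stack of systems can be handled by vmapping over the leading axis. The sketch below uses jnp.linalg.solve in place of the scipy/lineax back ends. solve_stack is a hypothetical name, and x_prev is omitted because it only matters as an initial guess for iterative solvers.

```python
import jax
import jax.numpy as jnp


def solve_stack(A, b):
    # A: (S, n, n), b: (S, n) -> x: (S, n), one direct solve per stack entry.
    return jax.vmap(jnp.linalg.solve)(A, b)


A = jnp.stack([2.0 * jnp.eye(3), 4.0 * jnp.eye(3)])
b = jnp.ones((2, 3))
x = solve_stack(A, b)
```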
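- utilities.calculate_newton_direction(grad_m, hessian_m, lambda_lm, newton_direction_prev, solver, full_or_diagonal)
Calculates the Newton direction given a gradient and a Hessian.
- Parameters:
grad_m (jnp.array) – the gradient
hessian_m (jnp.array) – the Hessian
lambda_lm (float) – the Levenberg-Marquardt damping parameter
newton_direction_prev (jnp.array) – the Newton direction from the previous iteration
solver (str, lineax-solver) – which library/method to use
full_or_diagonal (str) – whether the full Hessian or only its diagonal is used
- Returns:
tuple[jnp.array, Pytree]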
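A sketch of a damped Newton direction, assuming Levenberg-Marquardt damping of the form (H + λI) d = -g. It also assumes that in diagonal mode hessian_m holds only the diagonal entries, so the solve reduces to an element-wise division. newton_direction_sketch is hypothetical; the solver choice and the second return value of the library function are not modeled.

```python
import jax.numpy as jnp


def newton_direction_sketch(grad_m, hessian_m, lambda_lm, full_or_diagonal="full"):
    if full_or_diagonal == "full":
        n = grad_m.shape[-1]
        damped = hessian_m + lambda_lm * jnp.eye(n, dtype=hessian_m.dtype)
        return jnp.linalg.solve(damped, -grad_m)
    # Diagonal approximation: hessian_m is the diagonal of the Hessian.
    return -grad_m / (hessian_m + lambda_lm)


H = jnp.diag(jnp.array([2.0, 4.0]))
g = jnp.array([2.0, 4.0])
d_full = newton_direction_sketch(g, H, 0.0)
d_diag = newton_direction_sketch(g, jnp.array([2.0, 4.0]), 0.0, "diagonal")
```

Larger λ shortens the step and turns it toward the negative gradient, which is the usual motivation for this kind of damping.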
- utilities.get_idx_arr(N, M, key)
Create a stack of size M of randomized index arrays with indices ranging from 0 to N.
- Parameters:
N (int) – the maximum index
M (int) – the number of randomizations
key (jnp.array) – a jax.random.PRNGKey
- Returns:
jnp.array, a stack of 1D-arrays with randomized indices
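One way to build such a stack, assuming each row is an independent random permutation of 0, ..., N-1, is to split the key once per row and vmap jax.random.permutation. get_idx_arr_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def get_idx_arr_sketch(N, M, key):
    # One independent key per row; each row is a permutation of 0..N-1.
    keys = jax.random.split(key, M)
    return jax.vmap(lambda k: jax.random.permutation(k, N))(keys)


idx = get_idx_arr_sketch(5, 3, jax.random.PRNGKey(0))
```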
- utilities.center_signal(signal)
Centers a signal in the middle of an array via its center of mass. This is done in two stages, since periodic boundaries distort the actual center of mass.
- Parameters:
signal (jnp.array) – the signal to be centered.
- Returns:
jnp.array, the signal with its center of mass located at index N/2
- utilities.center_signal_to_max(signal)
Centers a signal in the middle of an array via jnp.argmax.
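A sketch of both centering functions. The two-stage variant assumes that the first stage moves the peak to N//2, so a pulse split across the periodic boundary becomes contiguous, and the second stage refines the position with the center of mass |signal|**2. The library's actual first stage may differ; all names are hypothetical.

```python
import jax.numpy as jnp


def center_of_mass(signal):
    weights = jnp.abs(signal) ** 2
    return jnp.sum(weights * jnp.arange(signal.shape[-1])) / jnp.sum(weights)


def center_signal_to_max_sketch(signal):
    N = signal.shape[-1]
    return jnp.roll(signal, N // 2 - jnp.argmax(jnp.abs(signal)))


def center_signal_sketch(signal):
    N = signal.shape[-1]
    # Stage 1 (assumed): move the peak away from the periodic boundary.
    coarse = center_signal_to_max_sketch(signal)
    # Stage 2: the center of mass is now undistorted, so use it to refine.
    shift = N // 2 - jnp.round(center_of_mass(coarse)).astype(jnp.int32)
    return jnp.roll(coarse, shift)


# A pulse centered at index 1 that wraps around the array boundary.
n = jnp.arange(64)
d = jnp.minimum(jnp.abs(n - 1), 64 - jnp.abs(n - 1))
pulse = jnp.exp(-(d ** 2) / 8.0)
centered = center_signal_sketch(pulse)
```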
- utilities.remove_phase_jumps(phase)
Detects jumps of 2*pi in the phase and subtracts the corresponding multiples of 2*pi to obtain a smooth phase.
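A sketch of this unwrapping step: the jump between neighbouring samples is rounded to the nearest multiple of 2*pi, and the accumulated multiples are subtracted. remove_phase_jumps_sketch is hypothetical; for this input it behaves like jnp.unwrap.

```python
import jax.numpy as jnp


def remove_phase_jumps_sketch(phase):
    # Jumps between neighbours, rounded to the nearest multiple of 2*pi.
    jumps = jnp.round(jnp.diff(phase) / (2 * jnp.pi)) * 2 * jnp.pi
    # Subtract the accumulated multiples from every later sample.
    correction = jnp.concatenate([jnp.zeros(1, dtype=phase.dtype), jnp.cumsum(jumps)])
    return phase - correction


true_phase = jnp.linspace(0.0, 20.0, 200)
wrapped = jnp.angle(jnp.exp(1j * true_phase))
smooth = remove_phase_jumps_sketch(wrapped)
```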
- utilities.get_score_values(output_pulses, input_pulses, gate=False, factor=-1)
Computes several error metrics for a reconstructed pulse when the exact pulse is known and provided. The error metrics are: (1) the maximum of the cross-correlation between the reconstructed and exact pulse, in both the time and the frequency domain; (2) the cross-correlation in the frequency domain without any shift, which evaluates how well the central frequency was retrieved; (3) a weighted and normalized L2-norm of the difference between the GDDs of the reconstructed and exact pulse.
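The two cross-correlation metrics can be sketched with a circular, FFT-based cross-correlation, normalized so that identical signals score 1. The sketch assumes the correlation is taken between the complex fields and that the "without any shifts" variant is the zero-lag value. Both function names are hypothetical, and the GDD metric is not covered.

```python
import jax.numpy as jnp


def _normalized_correlation(a, b):
    # corr[m] = sum_p a[p + m] * conj(b[p]) (circular), divided by ||a|| * ||b||.
    corr = jnp.fft.ifft(jnp.fft.fft(a) * jnp.conj(jnp.fft.fft(b)))
    norm = jnp.sqrt(jnp.sum(jnp.abs(a) ** 2) * jnp.sum(jnp.abs(b) ** 2))
    return jnp.abs(corr) / norm


def max_normalized_cross_correlation(a, b):
    # Insensitive to a relative shift between a and b.
    return jnp.max(_normalized_correlation(a, b))


def zero_shift_overlap(a, b):
    # Only the zero-lag value: penalizes a relative shift.
    return _normalized_correlation(a, b)[0]


n = jnp.arange(64)
pulse = jnp.exp(-((n - 32) ** 2) / 20.0) * jnp.exp(1j * 0.3 * n)
shifted = jnp.roll(pulse, 7)
```

Applied to spectra, the shift-insensitive maximum rates the pulse shape, while the zero-shift value also rewards retrieving the correct central frequency.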