utilities

utilities.flatten_MyNamespace(MyNamespace)

Flattens the pytree. Needed to register MyNamespace as a pytree.

utilities.unflatten_MyNamespace(aux_data, leaves)

Unflattens the pytree. Needed to register MyNamespace as a pytree.

class utilities.MyNamespace(**kwargs)

Bases: object

The central pytree. Supports basic arithmetic if the shapes/structures of the operands are consistent. Has no fixed structure at initialization.

expand(**kwargs)

Returns a new MyNamespace object containing all previous attributes as well as the **kwargs. Can be used to build an arbitrary pytree.
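The pytree registration and the expand method described above can be sketched as follows. This is a minimal illustration, not the library's code: it assumes MyNamespace keeps its attributes in __dict__, that flattening orders the leaves by sorted attribute name, and that arithmetic is applied leaf-wise with jax.tree_util.tree_map (only + and scalar * are shown).

```python
import jax
import jax.numpy as jnp


class MyNamespace:
    """Attribute container whose attribute values are the pytree leaves."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def expand(self, **kwargs):
        # New object with the old attributes plus kwargs; self is unchanged.
        return MyNamespace(**{**self.__dict__, **kwargs})

    def __add__(self, other):
        # Leaf-wise addition; raises if the two pytree structures differ.
        return jax.tree_util.tree_map(lambda a, b: a + b, self, other)

    def __mul__(self, scalar):
        # Scale every leaf by the same scalar.
        return jax.tree_util.tree_map(lambda a: a * scalar, self)


def flatten_MyNamespace(ns):
    # Sorted attribute names give a deterministic leaf order (aux_data).
    keys = tuple(sorted(ns.__dict__))
    return [ns.__dict__[k] for k in keys], keys


def unflatten_MyNamespace(aux_data, leaves):
    return MyNamespace(**dict(zip(aux_data, leaves)))


jax.tree_util.register_pytree_node(
    MyNamespace, flatten_MyNamespace, unflatten_MyNamespace
)

state = MyNamespace(x=jnp.ones(3)).expand(step=jnp.asarray(0.5))
doubled = jax.jit(lambda s: s + s)(state)  # works because MyNamespace is a pytree
scaled = state * 3.0
```

Because the structure lives in aux_data rather than in the class definition, any set of attributes can be added later via expand, which matches the "no fixed structure at initialization" description.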

utilities.run_scan(do_scan, carry, no_iterations)

Run a solver iteratively using lax.scan with jax.jit.

Parameters:
  • do_scan (Callable) – the step function; it needs to take carry as its argument

  • carry (Pytree) – the initial state of the iteration

  • no_iterations (int) – the number of iterations

Returns:

tuple[Carry, Y], the output of jax.lax.scan
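A minimal sketch of such a jitted scan loop, assuming do_scan already has lax.scan's (carry, x) signature (for example after wrapping with scan_helper) and that no xs are scanned over. The step function halve is hypothetical.

```python
import jax
import jax.numpy as jnp


def run_scan(do_scan, carry, no_iterations):
    # no_iterations must be a plain Python int: lax.scan needs a static length.
    @jax.jit
    def _run(c):
        # xs=None, so do_scan receives None as its second argument each step.
        return jax.lax.scan(do_scan, c, xs=None, length=no_iterations)

    return _run(carry)


def halve(carry, _):
    # New carry, plus the old carry as the per-step output y.
    return carry * 0.5, carry


final, history = run_scan(halve, jnp.asarray(8.0), 3)
```

Here final is the last carry and history stacks the per-step outputs, i.e. the tuple[Carry, Y] returned by jax.lax.scan.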

utilities.scan_helper(carry, xs, actual_function, number_of_args, number_of_xs)

jax.lax.scan expects the provided callable to accept two arguments, carry and xs (which may be None). This helper wraps the function to be iterated by lax.scan, so that the function's inputs do not have to conform to lax.scan’s requirements: the provided carry and xs are unpacked and passed to actual_function as individual arguments. All arguments except carry and xs have to be fixed via partial; the resulting callable is then passed to lax.scan. The output of actual_function needs to have the same structure as carry.

Parameters:
  • carry (any, tuple) – the initial state of the iteration

  • xs (any, tuple) – the xs used by jax.lax.scan

  • actual_function (Callable) – the function that is to be iterated over

  • number_of_args (int) – the number of individual arguments in carry

  • number_of_xs (int) – the number of individual arguments in xs

Returns:

Any, the output of actual_function
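The unpacking can be sketched as below. This is an illustration under stated assumptions: a single carry or xs item is passed as-is while several are unpacked from a tuple, actual_function returns only the new carry, and the helper pairs it with None as the per-step output that lax.scan requires. The step function accumulate is hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


def scan_helper(carry, xs, actual_function, number_of_args, number_of_xs):
    # Unpack carry into individual arguments (a single argument is not a tuple).
    carry_args = tuple(carry) if number_of_args > 1 else (carry,)
    if number_of_xs == 0:
        xs_args = ()
    elif number_of_xs == 1:
        xs_args = (xs,)
    else:
        xs_args = tuple(xs)
    new_carry = actual_function(*carry_args, *xs_args)
    # lax.scan expects (new_carry, y); this sketch emits no per-step output.
    return new_carry, None


def accumulate(total, count, x):
    # Hypothetical step: running sum and number of processed elements.
    return total + x, count + 1


step = functools.partial(
    scan_helper, actual_function=accumulate, number_of_args=2, number_of_xs=1
)
init = (jnp.zeros(()), jnp.zeros((), dtype=jnp.int32))
(total, count), _ = jax.lax.scan(step, init, jnp.arange(5.0))
```

Only carry and xs are left unbound after partial, so lax.scan can call step(carry, x) positionally.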

utilities.while_loop_helper(carry, actual_function, number_of_args)

Similar to scan_helper. Unpacks carry so that the input of actual_function does not have to conform to the requirements of lax.while_loop.

Parameters:
  • carry (any, tuple) – the initial state of the iteration

  • actual_function (Callable) – the function to be iterated over

  • number_of_args (int) – the number of individual arguments in carry

Returns:

Any, the output of actual_function
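A minimal sketch of the same idea for lax.while_loop, assuming that several carry items are unpacked from a tuple and that actual_function returns the new carry tuple. The Newton iteration newton_sqrt2_step is a hypothetical example body; the condition function still receives the packed carry.

```python
import functools

import jax
import jax.numpy as jnp


def while_loop_helper(carry, actual_function, number_of_args):
    # Unpack the carry tuple into individual arguments of actual_function.
    args = tuple(carry) if number_of_args > 1 else (carry,)
    return actual_function(*args)


def newton_sqrt2_step(x, i):
    # Hypothetical body: one Newton step for x**2 = 2, plus an iteration counter.
    return 0.5 * (x + 2.0 / x), i + 1


body = functools.partial(
    while_loop_helper, actual_function=newton_sqrt2_step, number_of_args=2
)
# cond_fun sees the packed carry tuple, not the unpacked arguments.
x, n = jax.lax.while_loop(
    lambda c: c[1] < 10, body, (jnp.ones(()), jnp.zeros((), dtype=jnp.int32))
)
```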

utilities.optimistix_helper_loss_function(input, args, function, no_of_args)

Optimistix’s interactive solver API expects loss functions that take two arguments and return a tuple of the error and auxiliary information. This helper wraps function to adhere to this interface. function and no_of_args have to be fixed via partial.

Parameters:
  • input (any) – the input to the function

  • args (any) – the args of function

  • function (Callable) – the actual loss function

  • no_of_args (int) – the number of extra arguments

Returns:

tuple, a tuple containing the calculated error twice, since there is no auxiliary information
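The wrapper can be sketched without importing optimistix. The dispatch on no_of_args (no extra argument, a single one, or a tuple of several) is an assumption, and the loss least_squares is hypothetical; the resulting fn(y, args) has the two-argument, (error, aux) shape described above.

```python
import functools

import jax.numpy as jnp


def optimistix_helper_loss_function(input, args, function, no_of_args):
    # Forward the extra arguments according to how many the loss expects.
    if no_of_args == 0:
        error = function(input)
    elif no_of_args == 1:
        error = function(input, args)
    else:
        error = function(input, *args)
    # No auxiliary information: the error doubles as the aux output.
    return error, error


def least_squares(y, target, weight):
    # Hypothetical loss with two extra arguments.
    return weight * jnp.sum((y - target) ** 2)


fn = functools.partial(
    optimistix_helper_loss_function, function=least_squares, no_of_args=2
)
loss, aux = fn(jnp.array([1.0, 2.0]), (jnp.array([0.0, 0.0]), 0.5))
```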

utilities.scan_helper_equinox(carry, xs, step, static)

This function wraps step, which is to be iterated over via lax.scan. In some cases the carry contains static parts that are not JAX-compatible (e.g. some of the optimistix solvers contain a jaxpr). These need to be filtered out, which can be done with equinox. The function merges the static part into carry before calling step and removes it again once step has returned. step and static have to be fixed via partial.

Parameters:
  • carry (any, tuple) – the carry to be iterated over

  • xs (any, tuple) – unused but required by lax.scan

  • step (Callable) – the function to be iterated over

  • static (any) – a static non-jax-compatible object which is to be merged before calling step and removed afterwards

Returns:

tuple, the output of step
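equinox offers eqx.partition and eqx.combine for this kind of split. To keep the sketch below dependency-free, partition and combine are simplified, hypothetical stand-ins that treat every jax.Array leaf as dynamic and everything else as static. It also assumes that step returns the full new state; the step function double and its string-valued "tag" are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


def partition(tree):
    """Simplified stand-in for eqx.partition(tree, eqx.is_array)."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    is_array = [isinstance(leaf, jax.Array) for leaf in leaves]
    dynamic = [leaf if a else None for leaf, a in zip(leaves, is_array)]
    static_leaves = tuple(None if a else leaf for leaf, a in zip(leaves, is_array))
    return dynamic, (treedef, static_leaves)


def combine(dynamic, static):
    """Simplified stand-in for eqx.combine: refill the non-array leaves."""
    treedef, static_leaves = static
    leaves = [s if d is None else d for d, s in zip(dynamic, static_leaves)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


def scan_helper_equinox(carry, xs, step, static):
    full_state = combine(carry, static)  # merge the static part back in
    new_state = step(full_state)
    new_carry, _ = partition(new_state)  # strip it again for lax.scan
    return new_carry, None


def double(state):
    # Hypothetical step; state["tag"] is a plain string, not a JAX type.
    return {"tag": state["tag"], "x": state["x"] * 2}


state = {"tag": "doubler", "x": jnp.asarray(1.0)}
dynamic, static = partition(state)
step = functools.partial(scan_helper_equinox, step=double, static=static)
dynamic, _ = jax.lax.scan(step, dynamic, None, length=3)
result = combine(dynamic, static)
```

Only the array leaves travel through lax.scan as carry; the string is reinserted on every step and after the loop.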

utilities.do_fft(signal, sk, rn, axis=-1)

Do a complex-valued 1D FFT. Does not use fftshift. Instead, sk and rn obtained from get_sk_rn are applied, which have the same effect and make the FFT work for any frequency range.

Parameters:
  • signal (jnp.array) – the signal on which the fft is applied

  • sk (jnp.array) – corrective values which “shift” the signal to the correct frequencies

  • rn (jnp.array) – corrective values which “shift” the signal to the correct frequencies

  • axis (int) – the axis over which the fft is applied (Default is -1)

Returns:

jnp.array, the fourier transformed signal

utilities.do_ifft(signal, sk, rn, axis=-1)

Do a complex-valued 1D IFFT. Does not use fftshift. Instead, sk and rn obtained from get_sk_rn are applied, which have the same effect and make the IFFT work for any frequency range.

Parameters:
  • signal (jnp.array) – the signal on which the ifft is applied

  • sk (jnp.array) – corrective values which “shift” the signal to the correct positions

  • rn (jnp.array) – corrective values which “shift” the signal to the correct positions

  • axis (int) – the axis over which the ifft is applied (Default is -1)

Returns:

jnp.array, the inverse fourier transformed signal

utilities.get_sk_rn(time, frequency)

The FFT's index-based definition differs from a discrete Fourier transform on arbitrary time and frequency axes. To correct for this, the input and the result of fft/ifft can be multiplied by the values calculated here; in effect, this applies the Fourier shift theorem. time and frequency have to fulfill N = 1/(df*dt), where N is the number of points and dt and df are the time and frequency spacings.

Parameters:
  • time (jnp.array) – the time axis

  • frequency (jnp.array) – the frequency axis

Returns:

tuple[jnp.array, jnp.array], the corrections used by do_fft/do_ifft
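One consistent choice of corrections follows from splitting the DFT kernel on the grids t_n = t0 + n*dt and f_k = f0 + k*df. The sketch below assumes equidistant axes, t0 = time[0], f0 = frequency[0], and a plain sum without dt or df prefactors; the library's actual normalization may differ, and the helper _along_axis is hypothetical.

```python
import jax.numpy as jnp


def get_sk_rn(time, frequency):
    # Assumes N * dt * df == 1. Then the kernel phase splits as
    #   f_k * t_n = f_k * t0 + f0 * (t_n - t0) + k * n / N,
    # and the last term is exactly the one jnp.fft.fft uses.
    t0, f0 = time[0], frequency[0]
    sk = jnp.exp(-2j * jnp.pi * frequency * t0)
    rn = jnp.exp(-2j * jnp.pi * f0 * (time - t0))
    return sk, rn


def _along_axis(factor, axis, ndim):
    # Reshape a 1D correction so it broadcasts along `axis`.
    shape = [1] * ndim
    shape[axis] = -1
    return factor.reshape(shape)


def do_fft(signal, sk, rn, axis=-1):
    # F_k = sum_n s_n * exp(-2j*pi * f_k * t_n)
    rn_b = _along_axis(rn, axis, signal.ndim)
    sk_b = _along_axis(sk, axis, signal.ndim)
    return sk_b * jnp.fft.fft(rn_b * signal, axis=axis)


def do_ifft(signal, sk, rn, axis=-1):
    # Exact inverse of do_fft; jnp.fft.ifft supplies the 1/N factor.
    rn_b = _along_axis(rn, axis, signal.ndim)
    sk_b = _along_axis(sk, axis, signal.ndim)
    return jnp.conj(rn_b) * jnp.fft.ifft(jnp.conj(sk_b) * signal, axis=axis)


N, dt = 16, 0.1
time = -0.8 + dt * jnp.arange(N)
frequency = 3.0 + jnp.arange(N) / (N * dt)  # df = 1 / (N * dt)
sk, rn = get_sk_rn(time, frequency)
signal = jnp.exp(-time**2 / 0.05) * jnp.exp(2j * jnp.pi * 5.0 * time)
spectrum = do_fft(signal, sk, rn)
```

The result equals the direct sum over exp(-2j*pi*f_k*t_n) for arbitrary t0 and f0, so no fftshift is needed, and do_ifft recovers the original signal.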

utilities.do_interpolation_1d(x_new, x, y, method='cubic', extrap=1e-12, axis=-1)

Wraps interpax.interp1d and jnp.interp.

utilities.integrate_signal_1D(signal, x, integration_method, integration_order)

Calculates the indefinite integral of a signal using the Riemann sum or the Euler-Maclaurin formula.
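A cumulative (indefinite) integral can be sketched as follows. The method names "riemann" and "trapezoid" are hypothetical, and integration_order is not modeled: only the lowest-order Euler-Maclaurin endpoint correction is shown, which turns the left Riemann sum into the trapezoidal rule.

```python
import jax.numpy as jnp


def integrate_signal_1D(signal, x, integration_method="riemann"):
    """Indefinite integral with value 0 at x[0]; works on non-uniform grids."""
    dx = jnp.diff(x)
    if integration_method == "riemann":
        # Left Riemann sum: each interval uses the value at its left edge.
        increments = signal[:-1] * dx
    elif integration_method == "trapezoid":
        # Riemann sum plus the lowest-order Euler-Maclaurin endpoint correction.
        increments = 0.5 * (signal[:-1] + signal[1:]) * dx
    else:
        raise ValueError(f"unknown integration_method: {integration_method}")
    return jnp.concatenate([jnp.zeros(1, dtype=increments.dtype), jnp.cumsum(increments)])


x = jnp.linspace(0.0, 1.0, 11)
trapezoid = integrate_signal_1D(x, x, "trapezoid")  # exact for a linear signal: x**2 / 2
riemann = integrate_signal_1D(x, x, "riemann")
```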

utilities.calculate_gate(pulse_t, method)

Calculate the gate field/signal for the nonlinear process.

utilities.calculate_gate_with_Real_Fields(pulse_t, method)

Calculate the gate field/signal for the nonlinear process using real input fields. This allows for the description of difference frequency generation.

utilities.project_onto_intensity(signal_f, measured_intensity)

Project the current complex guess signal onto the measured intensity.

utilities.project_onto_amplitude(signal_f, measured_amplitude)

Project the current complex guess signal onto the measured amplitude.
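The usual form of this projection keeps the phase of the current guess and replaces its modulus by the measured one. In this sketch the eps guard against division by zero and the clipping of negative intensities (e.g. after background subtraction) are assumptions.

```python
import jax.numpy as jnp


def project_onto_amplitude(signal_f, measured_amplitude, eps=1e-12):
    # Keep the phase of the guess, replace its modulus by the measured amplitude.
    return measured_amplitude * signal_f / (jnp.abs(signal_f) + eps)


def project_onto_intensity(signal_f, measured_intensity, eps=1e-12):
    # Intensity is |field|**2, so project onto its (clipped) square root.
    amplitude = jnp.sqrt(jnp.maximum(measured_intensity, 0.0))
    return project_onto_amplitude(signal_f, amplitude, eps)


guess = jnp.array([1.0 + 1.0j, -2.0, 0.5j])
projected = project_onto_intensity(guess, jnp.array([4.0, 9.0, 1.0]))
```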

utilities.initialize_mu(optimizer, measurement_info, descent_info)

utilities.calculate_mu(trace, measured_trace, measurement_info, descent_info, local_or_global)

utilities.calculate_trace(signal_f, measured_trace, measurement_info, descent_info, local_or_global)

Calculates the intensity of a complex signal as well as the calibration factor/curve. Needs to be vmapped in order to apply it to a population.

utilities.calculate_trace_error(mu, trace, measured_trace)

Calculates the mean squared difference between the measured intensity and the intensity of the current guess, where the current guess is scaled by mu.
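The error can be sketched together with one common way of choosing mu: the single global scalar that minimizes the squared residual in the least-squares sense. best_fit_mu is a hypothetical helper; the library's calculate_mu may instead produce a local calibration curve.

```python
import jax.numpy as jnp


def best_fit_mu(trace, measured_trace):
    # Least-squares scalar: minimizes sum((measured - mu * trace)**2) over mu.
    return jnp.sum(trace * measured_trace) / jnp.sum(trace**2)


def calculate_trace_error(mu, trace, measured_trace):
    # Mean squared residual between the measurement and the scaled guess.
    return jnp.mean((measured_trace - mu * trace) ** 2)


trace = jnp.array([[1.0, 2.0], [3.0, 4.0]])
measured = 2.5 * trace
mu = best_fit_mu(trace, measured)
error = calculate_trace_error(mu, trace, measured)
```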

utilities.calculate_Z_error(signal_t, signal_t_new)

Calculates the squared L2-Norm between the complex signal fields in the time domain before and after projection onto the measured signal.

utilities.generate_random_continuous_function(key, no_points, x, minval, maxval, distribution, forced_vals=False, **kwargs)

Generates a 1D array of random but continuous values by cubic interpolation/extrapolation between random points.

Parameters:
  • key (jnp.array) – a jax.random.PRNGKey

  • no_points (int) – the number of random points to use for the interpolation

  • x (jnp.array) – the x-values from which to choose the location of random values

  • minval (int, float) – the minimal random y-value; the interpolation may lead to lower values

  • maxval (int, float) – the maximal random y-value; the interpolation may lead to higher values

  • distribution (jnp.array) – a probability distribution for the x-location of the random values.

Returns:

jnp.array, the interpolated random y-values.
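The sampling scheme can be sketched as below, with one plainly named substitution: linear interpolation via jnp.interp stands in for the documented cubic inter/extrapolation, so this sketch never leaves [minval, maxval]. Drawing distinct x-locations with jax.random.choice weighted by the normalized distribution is an assumption, and forced_vals and **kwargs are omitted.

```python
import jax
import jax.numpy as jnp


def generate_random_continuous_function(key, no_points, x, minval, maxval, distribution):
    key_x, key_y = jax.random.split(key)
    # Draw distinct x-locations, weighted by `distribution`, and sort them.
    p = distribution / jnp.sum(distribution)
    x_nodes = jnp.sort(jax.random.choice(key_x, x, shape=(no_points,), replace=False, p=p))
    # Random y-values at those locations.
    y_nodes = jax.random.uniform(key_y, (no_points,), minval=minval, maxval=maxval)
    # Linear stand-in for the cubic interpolation; jnp.interp holds the end
    # values constant outside [x_nodes[0], x_nodes[-1]].
    return jnp.interp(x, x_nodes, y_nodes)


x = jnp.linspace(-1.0, 1.0, 200)
weights = jnp.exp(-x**2 / 0.2)  # favour locations near the centre
y = generate_random_continuous_function(jax.random.PRNGKey(0), 8, x, -1.0, 1.0, weights)
```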

utilities.solve_linear_system(A, b, x_prev, solver)

Solves a stack of linear systems Ax=b using scipy or lineax.

Parameters:
  • A (jnp.array) – stack of 2D-arrays

  • b (jnp.array) – stack of 1D-arrays

  • x_prev (jnp.array) – stack of 1D-arrays with approximate solutions.

  • solver (str, lineax-solver) – which library/method to use

Returns:

jnp.array, stack of 1D-arrays with the solution to Ax=b
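A dense-solver version of the stacked solve can be sketched with jax.vmap over jnp.linalg.solve. The lineax route is not shown, and x_prev is ignored here because it would only matter as an initial guess for an iterative solver.

```python
import jax
import jax.numpy as jnp


def solve_linear_system(A, b, x_prev=None):
    # One dense solve per stacked system: A[i] @ x[i] = b[i].
    # x_prev is unused; an iterative solver could take it as initial guess.
    return jax.vmap(jnp.linalg.solve)(A, b)


A = jnp.array([[[2.0, 0.0], [0.0, 4.0]], [[1.0, 1.0], [0.0, 1.0]]])
b = jnp.array([[2.0, 8.0], [3.0, 1.0]])
x = solve_linear_system(A, b)
```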

utilities.calculate_newton_direction(grad_m, hessian_m, lambda_lm, newton_direction_prev, solver, full_or_diagonal)

Calculates the Newton direction given a gradient and a Hessian.

Parameters:
  • grad_m (jnp.array)

  • hessian_m (jnp.array)

  • lambda_lm (float)

  • newton_direction_prev (jnp.array)

  • solver (str, lineax-solver)

  • full_or_diagonal (str)

Returns:

tuple[jnp.array, Pytree]
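A damped Newton step can be sketched as below. Everything beyond "Newton direction from gradient and Hessian" is an assumption: the sign convention d = -(H + lambda_lm*I)^(-1) g, Levenberg-Marquardt-style damping with the identity, diagonal Hessians stored as their diagonal for "diagonal", and stacked problems along the leading axis. newton_direction_prev, solver and the second (Pytree) return value are omitted.

```python
import jax
import jax.numpy as jnp


def calculate_newton_direction(grad_m, hessian_m, lambda_lm, full_or_diagonal="full"):
    if full_or_diagonal == "diagonal":
        # hessian_m stores only the diagonal, so the solve is element-wise.
        return -grad_m / (hessian_m + lambda_lm)
    # Damped Newton system (H + lambda_lm * I) d = -g for each stacked problem.
    identity = jnp.eye(grad_m.shape[-1])
    return jax.vmap(jnp.linalg.solve)(hessian_m + lambda_lm * identity, -grad_m)


grad = jnp.array([[2.0, -4.0]])
hess = jnp.array([[[2.0, 0.0], [0.0, 4.0]]])
d_full = calculate_newton_direction(grad, hess, 0.0)
d_diag = calculate_newton_direction(grad, jnp.array([[2.0, 4.0]]), 0.0, "diagonal")
d_damped = calculate_newton_direction(grad, hess, 2.0)
```

Larger lambda_lm shortens the step and tilts it toward the negative gradient, which is the usual motivation for this kind of damping.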

utilities.get_idx_arr(N, M, key)

Creates a stack of M randomized arrays of indices from range(0, N).

Parameters:
  • N (int) – the maximum index

  • M (int) – the number of randomizations

  • key (jnp.array) – a jax.random.PRNGKey

Returns:

jnp.array, a stack of 1D-arrays with randomized indices
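A sketch assuming each row is an independent random permutation of 0, ..., N-1, generated from its own subkey with jax.random.permutation under jax.vmap.

```python
import jax
import jax.numpy as jnp


def get_idx_arr(N, M, key):
    # One subkey per row; each row is a random permutation of 0..N-1.
    keys = jax.random.split(key, M)
    return jax.vmap(lambda k: jax.random.permutation(k, N))(keys)


idx = get_idx_arr(5, 3, jax.random.PRNGKey(42))
```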

utilities.get_com(signal, idx_arr)

Calculate the center of mass of a signal.

utilities.center_signal(signal)

Centers a signal in the middle of an array via its center of mass. This is done in two stages, since periodic boundaries distort the actual center of mass.

Parameters:
  • signal (jnp.array) – the signal to be centered.

Returns:

jnp.array, the signal with its center of mass located at index N/2
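The two-stage centering can be sketched as follows. The concrete stages are assumptions: first the peak is rolled to the middle (as center_signal_to_max does) so the pulse no longer wraps around the periodic boundary, then the center of mass, weighted by |signal|**2, is rolled to index N // 2.

```python
import jax.numpy as jnp


def get_com(signal, idx_arr):
    # Center of mass, weighting each index by |signal|**2 (an assumption).
    weights = jnp.abs(signal) ** 2
    return jnp.sum(idx_arr * weights) / jnp.sum(weights)


def center_signal_to_max(signal):
    N = signal.shape[-1]
    return jnp.roll(signal, N // 2 - jnp.argmax(jnp.abs(signal)), axis=-1)


def center_signal(signal):
    N = signal.shape[-1]
    # Stage 1: move the peak to the middle so the pulse does not wrap around
    # the periodic boundary, which would bias the center of mass.
    signal = center_signal_to_max(signal)
    # Stage 2: move the center of mass to index N // 2.
    com = get_com(signal, jnp.arange(N))
    shift = (N // 2 - jnp.round(com)).astype(jnp.int32)
    return jnp.roll(signal, shift, axis=-1)


n = jnp.arange(64)
gaussian = jnp.exp(-((n - 32) ** 2) / 20.0)
pulse = jnp.roll(gaussian, 30)  # peak at index 62, tail wraps around to the start
centered = center_signal(pulse)
```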

utilities.center_signal_to_max(signal)

Centers a signal in an array by moving its maximum, found via jnp.argmax, to the middle index.

utilities.remove_phase_jumps(phase)

Detects jumps of 2*pi in the phase and subtracts them accordingly to obtain a smooth phase.
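A minimal sketch of the jump removal, assuming every step larger than pi in magnitude between neighbouring samples is a 2*pi wrap. jax.numpy.unwrap implements the same idea and is used in the check below.

```python
import jax.numpy as jnp


def remove_phase_jumps(phase):
    # Number of 2*pi wraps between neighbouring samples (nonzero if |step| > pi).
    jumps = jnp.round(jnp.diff(phase) / (2 * jnp.pi))
    # Accumulate the wraps and subtract them from every subsequent sample.
    correction = 2 * jnp.pi * jnp.concatenate([jnp.zeros(1), jnp.cumsum(jumps)])
    return phase - correction


n = jnp.arange(41.0)
true_phase = 0.05 * (n - 20.0) ** 2
wrapped = jnp.angle(jnp.exp(1j * true_phase))  # values in (-pi, pi]
smooth = remove_phase_jumps(wrapped)
```

The result differs from true_phase only by a constant multiple of 2*pi, fixed by the first sample.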

utilities.get_score_values(output_pulses, input_pulses, gate=False, factor=-1)

Computes several error metrics for a reconstructed pulse, given that the exact pulse is known and provided. The error metrics are:
  • the maximum cross-correlation between the reconstructed and the exact pulse in the time and frequency domains

  • the cross-correlation in the frequency domain without any shift, which evaluates how accurately the central frequency was retrieved

  • a weighted and normalized L2-norm of the GDD difference between the reconstructed and the exact pulse
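The two cross-correlation metrics can be sketched as follows. The choices here are assumptions: a circular cross-correlation computed via the FFT, normalization by the L2 norms of both pulses, and the magnitude of the result. The helper names are hypothetical, and the GDD metric is not sketched.

```python
import jax.numpy as jnp


def _norm(a, b):
    return jnp.sqrt(jnp.sum(jnp.abs(a) ** 2) * jnp.sum(jnp.abs(b) ** 2))


def max_cross_correlation(a, b):
    # Circular cross-correlation sum_p a[m + p] * conj(b[p]) for all shifts m,
    # computed with the FFT; identical pulses up to a shift give 1.
    corr = jnp.fft.ifft(jnp.fft.fft(a) * jnp.conj(jnp.fft.fft(b)))
    return jnp.max(jnp.abs(corr)) / _norm(a, b)


def zero_shift_correlation(a, b):
    # Overlap without any shift: drops if the spectrum is displaced,
    # e.g. by a wrongly retrieved central frequency.
    return jnp.abs(jnp.sum(a * jnp.conj(b))) / _norm(a, b)


n = jnp.arange(64)
exact = jnp.exp(-((n - 32) ** 2) / 20.0).astype(jnp.complex64)
reconstructed = jnp.roll(exact, 5)  # same spectrum, shifted by five bins
best = max_cross_correlation(reconstructed, exact)
unshifted = zero_shift_correlation(reconstructed, exact)
```

A spectrum that is only displaced still scores 1 on the shift-maximized metric but noticeably less on the zero-shift one, which is why the latter is sensitive to the retrieved central frequency.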