core.hessians
Submodules
core.hessians.chirpscan_z_error_pseudo_hessian module
- core.hessians.chirpscan_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg(pulse_t_dispersed, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.chirpscan_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg(pulse_t_dispersed, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.chirpscan_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg(pulse_t_dispersed, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.chirpscan_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg(pulse_t_dispersed, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.chirpscan_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_element(exp_arr_mp, exp_arr_mn, omega_p, omega_n, time, pulse_t_dispersed, deltaS_m, nonlinear_method)
Sums over the time axis via jax.vmap.
- core.hessians.chirpscan_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_one_m(dummy, exp_arr_m, pulse_t_dispersed, deltaS_m, time, omega, nonlinear_method, full_or_diagonal)
Uses jax.vmap over the frequency axis.
- core.hessians.chirpscan_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_all_m(pulse_t_dispersed, deltaS, phase_matrix, measurement_info, full_or_diagonal)
Loops over the shifts and computes the pseudo-Hessian for each one. A loop is used instead of jax.vmap because of memory limits.
- core.hessians.chirpscan_z_error_pseudo_hessian.get_pseudo_newton_direction_Z_error(grad_m, pulse_t_dispersed, signal_t, signal_t_new, phase_matrix, descent_state, measurement_info, descent_info, full_or_diagonal)
Calculates the pseudo-Newton direction for the Z-error of a chirp-scan measurement. The direction is calculated in the frequency domain.
- Parameters:
grad_m (jnp.array) – the current Z-error gradient
pulse_t_dispersed (jnp.array) – the current guess after phase_matrix was applied
signal_t (jnp.array) – the current signal field
signal_t_new (jnp.array) – the current signal field projected onto the measured intensity
phase_matrix (jnp.array) – the applied phases
descent_state (pytree)
measurement_info (pytree) – holds measurement data and parameters
descent_info (pytree) – holds algorithm parameters
full_or_diagonal (str) – whether to use the full or the diagonal pseudo-Hessian
- Returns:
tuple[jnp.array, Pytree] – the pseudo-Newton direction and the updated newton_state
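The full_or_diagonal switch trades accuracy for cost: a full pseudo-Hessian requires solving an N × N linear system, while a diagonal one reduces the step to an elementwise division. A minimal NumPy sketch of that contrast, assuming the gradient and pseudo-Hessian are already available on the frequency grid; the helper `pseudo_newton_direction`, the `damping` term and the random Gauss-Newton-style test matrix are hypothetical and are not part of this package's API.

```python
import numpy as np


def pseudo_newton_direction(grad, hessian, full_or_diagonal="full", damping=1e-9):
    """Hypothetical helper: return d = -(H + damping)^-1 g.

    grad: complex array (N,), gradient w.r.t. N spectral amplitudes.
    hessian: (N, N) Hermitian pseudo-Hessian for "full",
             or its real (N,) diagonal for "diagonal".
    damping: small assumed regularization that keeps the solve well conditioned.
    """
    if full_or_diagonal == "full":
        n = hessian.shape[0]
        # Solve the N x N linear system instead of forming an explicit inverse.
        return -np.linalg.solve(hessian + damping * np.eye(n), grad)
    if full_or_diagonal == "diagonal":
        # Elementwise division: O(N) work and memory.
        return -grad / (hessian + damping)
    raise ValueError(f"unknown option: {full_or_diagonal!r}")


# Usage with a random Gauss-Newton-style pseudo-Hessian H = J^H J.
rng = np.random.default_rng(0)
J = rng.normal(size=(8, 4)) + 1j * rng.normal(size=(8, 4))
H = J.conj().T @ J
g = rng.normal(size=4) + 1j * rng.normal(size=4)

d_full = pseudo_newton_direction(g, H, "full")
d_diag = pseudo_newton_direction(g, np.real(np.diag(H)), "diagonal")
```

The diagonal variant ignores the coupling between frequencies, so it is cheaper but only an approximation of the full step whenever H has significant off-diagonal entries.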
core.hessians.frog_z_error_pseudo_hessian module
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_sd(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_cross_correlation_pulse(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_sd_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg_interferometric(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg_interferometric(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg_interferometric(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg_interferometric(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg_interferometric_cross_correlation_pulse(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg_interferometric_cross_correlation_pulse(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg_interferometric_cross_correlation_pulse(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg_interferometric_cross_correlation_pulse(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg_interferometric_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg_interferometric_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg_interferometric_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg_interferometric_cross_correlation_gate(pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_element_pulse(exp_arr_mp, exp_arr_mn, omega_p, omega_n, time_k, pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, frogmethod, cross_correlation, interferometric)
Sums over the time axis via jax.lax.scan. jax.vmap is not used here because of memory limits.
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_element_gate(exp_arr_mp, exp_arr_mn, omega_p, omega_n, time_k, pulse_t, pulse_t_shifted_m, gate_shifted_m, deltaS_m, frogmethod, cross_correlation, interferometric)
Sums over the time axis via jax.lax.scan.
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_one_m(dummy, exp_arr_m, pulse_t_shifted_m, gate_shifted_m, deltaS_m, pulse_t, time, omega, frogmethod, cross_correlation, interferometric, full_or_diagonal, pulse_or_gate)
Uses jax.vmap over the frequency axis.
- core.hessians.frog_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_all_m(pulse_t, pulse_t_shifted, gate_shifted, deltaS, tau_arr, measurement_info, full_or_diagonal, pulse_or_gate, is_tdp)
Loops over the shifts and computes the pseudo-Hessian for each one. A loop is used instead of jax.vmap because of memory limits.
- core.hessians.frog_z_error_pseudo_hessian.get_pseudo_newton_direction_Z_error(grad_m, pulse_t, pulse_t_shifted, gate_shifted, signal_t, signal_t_new, tau_arr, descent_state, measurement_info, descent_info, full_or_diagonal, pulse_or_gate)
Calculates the pseudo-Newton direction for the Z-error of a FROG measurement. The direction is calculated in the frequency domain.
- Parameters:
grad_m (jnp.array) – the current Z-error gradient
pulse_t (jnp.array) – the current guess
pulse_t_shifted (jnp.array) – the current guess shifted along the time axis
gate_shifted (jnp.array) – the current gate guess shifted along the time axis
signal_t (jnp.array) – the current signal field
signal_t_new (jnp.array) – the current signal field projected onto the measured intensity
tau_arr (jnp.array) – the time delays
descent_state (pytree)
measurement_info (pytree) – holds measurement data and parameters
descent_info (pytree) – holds algorithm parameters
full_or_diagonal (str) – whether to use the full or the diagonal pseudo-Hessian
pulse_or_gate (str) – whether the direction is calculated for the pulse or the gate-pulse
- Returns:
tuple[jnp.array, Pytree] – the pseudo-Newton direction and the updated newton_state
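The element functions above sum over the time axis with jax.lax.scan rather than jax.vmap because of memory limits. A minimal NumPy sketch of why that matters: the vectorized version materializes an (N, N, K) array of summands before reducing, while the scan-style version carries only the (N, N) running sum and adds one time sample per step. The summand w_k exp(-i ω_p t_k) exp(i ω_n t_k) and the weight |gate_k|² are illustrative assumptions, not the package's actual subelement formulas, and the Python loop only stands in for jax.lax.scan.

```python
import numpy as np


def hessian_vectorized(weight_k, omega, time):
    """Builds an (N, N, K) array of summands, then sums over the time axis K."""
    e = np.exp(1j * np.outer(omega, time))                    # e[n, k] = exp(i w_n t_k)
    terms = e.conj()[:, None, :] * e[None, :, :] * weight_k   # (N, N, K)
    return terms.sum(axis=-1)


def hessian_scanned(weight_k, omega, time):
    """Accumulates the same sum one time sample at a time (the lax.scan pattern)."""
    n = omega.size
    carry = np.zeros((n, n), dtype=complex)                   # only (N, N) is kept in memory
    for k in range(time.size):
        e_k = np.exp(1j * omega * time[k])                     # (N,)
        carry += weight_k[k] * np.outer(e_k.conj(), e_k)       # H[p, n] += w_k e^{-i w_p t_k} e^{i w_n t_k}
    return carry


N, K = 16, 64
time = np.linspace(-2.0, 2.0, K)
omega = np.linspace(-6.0, 6.0, N)
gate = np.exp(-time**2)            # hypothetical real-valued gate
weight = np.abs(gate) ** 2         # assumed Gauss-Newton weight |gate_k|^2

H_vec = hessian_vectorized(weight, omega, time)
H_scan = hessian_scanned(weight, omega, time)
```

Both functions return the same matrix; the difference is peak memory, which for the vectorized form grows with the number of time samples K.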
core.hessians.pie_pseudo_hessian module
- core.hessians.pie_pseudo_hessian.PIE_get_full_pseudo_hessian_all_m(probe, subelement, transform_arr, time, omega, Dkn, measurement_info, pulse_or_gate)
Calculates the full pseudo-Hessian using jnp.einsum().
- core.hessians.pie_pseudo_hessian.PIE_get_diagonal_pseudo_hessian_all_m(probe, subelement, transform_arr, time, omega, Dkn, measurement_info, pulse_or_gate)
Calculates the diagonal pseudo-Hessian using jnp.einsum().
- core.hessians.pie_pseudo_hessian.PIE_get_pseudo_hessian_all_m(probe, signal_f, transform_arr, measured_trace, measurement_info, descent_info, full_or_diagonal, pulse_or_gate)
Dispatches to the full or the diagonal pseudo-Hessian calculation, depending on full_or_diagonal.
- core.hessians.pie_pseudo_hessian.PIE_get_pseudo_newton_direction(grad, probe, signal_f, transform_arr, measured_trace, descent_state, measurement_info, descent_info, pulse_or_gate, local_or_global)
Calculates the pseudo-Newton direction for the PIE loss function. The calculation is the same for all methods that can be used with PIE; methods that cannot be used with PIE are not available here. The direction is calculated in the time domain.
- Parameters:
grad (jnp.array) – the current (weighted) gradient
probe (jnp.array) – the PIE probe or modified probe/object for hessian
signal_f (jnp.array) – the signal field in the frequency domain
transform_arr (jnp.array) – the delays or phase matrix
measured_trace (jnp.array) – the measured intensity
descent_state (Pytree)
measurement_info (Pytree) – holds measurement data and parameters
descent_info (Pytree) – holds algorithm parameters
pulse_or_gate (str) – whether the direction is calculated for the pulse or the gate (or chirpscan)
local_or_global (str) – whether the iteration is local or global
- Returns:
tuple[jnp.array, Pytree]
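The full and diagonal PIE pseudo-Hessians are both computed with jnp.einsum(). A minimal NumPy sketch of the idea, assuming a Gauss-Newton form H[k, j] = Σ_{m,n} w[m, n] conj(J[m, n, k]) J[m, n, j] for a signal S[m, n] = Σ_k dft[n, k] probe[m, k] obj[k]; the unitary matrix `dft`, the positive weights `weight` and the object `obj` are hypothetical stand-ins, not the package's `Dkn`, `measured_trace` or actual formula. The diagonal einsum simply drops the second object index, so it never forms the N × N matrix.

```python
import numpy as np

rng = np.random.default_rng(1)
M, N = 5, 8                                    # hypothetical: M shifts, N time samples
probe = rng.normal(size=(M, N)) + 1j * rng.normal(size=(M, N))
k = np.arange(N)
dft = np.exp(-2j * np.pi * np.outer(k, k) / N) / np.sqrt(N)   # dft[n, k], unitary
weight = rng.uniform(0.5, 1.5, size=(M, N))    # hypothetical positive weights w[m, n]

# Jacobian of S[m, n] = sum_k dft[n, k] * probe[m, k] * obj[k] w.r.t. obj[k]:
#   J[m, n, k] = dft[n, k] * probe[m, k]
# Full pseudo-Hessian: H[k, j] = sum_{m, n} w[m, n] * conj(J[m, n, k]) * J[m, n, j]
H_full = np.einsum("mn,mk,nk,mj,nj->kj", weight, probe.conj(), dft.conj(), probe, dft)

# Diagonal only (k == j): sum_{m, n} w[m, n] * |probe[m, k]|^2 * |dft[n, k]|^2
H_diag = np.einsum("mn,mk,nk->k", weight, np.abs(probe) ** 2, np.abs(dft) ** 2)
```

The diagonal of H_full matches H_diag. With the unweighted model (all w = 1) and a unitary transform, the off-diagonal terms would cancel exactly, so the weights are what make the full matrix differ from its diagonal in this toy example.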
core.hessians.streaking_z_error_pseudo_hessian module
- core.hessians.streaking_z_error_pseudo_hessian.Z_pseudo_hessian_diagonal_EUV_pulse(signal_t, signal_t_new, tau_arr, measurement_info)
- core.hessians.streaking_z_error_pseudo_hessian.Z_pseudo_hessian_diagonal_vectorpotential(signal_t, signal_t_new, tau_arr, measurement_info)
The derivation of this diagonal pseudo-Hessian relies heavily on the product and chain rules.
- core.hessians.streaking_z_error_pseudo_hessian.Z_pseudo_hessian_diagonal_DTME(signal_t, signal_t_new, tau_arr, measurement_info)
- core.hessians.streaking_z_error_pseudo_hessian.get_pseudo_newton_direction_Z_error(grad_m, signal_t, signal_t_new, tau_arr, descent_state, measurement_info, descent_info, full_or_diagonal, pulse_or_gate)
Calculates the pseudo-Newton direction for the Z-error of a streaking measurement. The direction is calculated in the frequency domain.
- Parameters:
grad_m (jnp.array) – the current Z-error gradient
signal_t (jnp.array) – the current signal field
signal_t_new (jnp.array) – the current signal field projected onto the measured intensity
tau_arr (jnp.array) – the time delays
descent_state (pytree)
measurement_info (pytree) – holds measurement data and parameters
descent_info (pytree) – holds algorithm parameters
full_or_diagonal (str) – whether to use the full or the diagonal pseudo-Hessian
pulse_or_gate (str) – whether the direction is calculated for the pulse or the gate-pulse
- Returns:
tuple[jnp.array, Pytree] – the pseudo-Newton direction and the updated newton_state
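Hand-derived diagonal pseudo-Hessians such as the streaking terms above combine the product and chain rules, and one way to sanity-check such a derivation is to compare it against a finite-difference Jacobian. A minimal NumPy sketch of that check, using a hypothetical real-valued model s_k = a_k b_k sin(b_k) in place of the actual streaking model (which involves complex fields and would need Wirtinger derivatives); all function names here are illustrative.

```python
import numpy as np


def model(a, b):
    # Hypothetical signal: s_k = a_k * b_k * sin(b_k)
    return a * b * np.sin(b)


def diag_pseudo_hessian_analytic(a, b):
    # Product rule: d/db [b sin b] = sin b + b cos b, so J[k, k] = a_k (sin b_k + b_k cos b_k)
    jac_diag = a * (np.sin(b) + b * np.cos(b))
    return jac_diag ** 2                 # diagonal of J^T J (real-valued case)


def diag_pseudo_hessian_fd(a, b, h=1e-5):
    n = b.size
    jac = np.empty((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = h
        # Central difference for column j of the Jacobian
        jac[:, j] = (model(a, b + step) - model(a, b - step)) / (2 * h)
    return np.sum(jac ** 2, axis=0)      # (J^T J)[j, j] = sum_k J[k, j]^2


rng = np.random.default_rng(2)
a = rng.normal(size=6)
b = rng.uniform(-2.0, 2.0, size=6)
analytic = diag_pseudo_hessian_analytic(a, b)
numeric = diag_pseudo_hessian_fd(a, b)
```

The finite-difference version costs one pair of model evaluations per parameter, so it is only practical as a test on small grids, not inside the descent loop.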
core.hessians.tdp_z_error_pseudo_hessian module
- core.hessians.tdp_z_error_pseudo_hessian.get_pseudo_newton_direction_Z_error(grad_m, pulse_t, pulse_t_shifted, gate_shifted, signal_t, signal_t_new, tau_arr, descent_state, measurement_info, descent_info, full_or_diagonal, pulse_or_gate)
Calculates the pseudo-Newton direction for the Z-error of a time-domain ptychography (TDP) measurement. The direction is calculated in the frequency domain.
- Parameters:
grad_m (jnp.array) – the current Z-error gradient
pulse_t (jnp.array) – the current guess
pulse_t_shifted (jnp.array) – the current guess shifted along the time axis
gate_shifted (jnp.array) – the current gate guess shifted along the time axis
signal_t (jnp.array) – the current signal field
signal_t_new (jnp.array) – the current signal field projected onto the measured intensity
tau_arr (jnp.array) – the time delays
descent_state (pytree)
measurement_info (pytree) – holds measurement data and parameters
descent_info (pytree) – holds algorithm parameters
full_or_diagonal (str) – whether to use the full or the diagonal pseudo-Hessian
pulse_or_gate (str) – whether the direction is calculated for the pulse or the gate-pulse
- Returns:
tuple[jnp.array, Pytree] – the pseudo-Newton direction and the updated newton_state
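These Z-error directions are calculated in the frequency domain. A minimal NumPy sketch of one way such a step can be taken, assuming the gradient is given with respect to the time-domain field and the pseudo-Hessian diagonal is given on the frequency grid: with the unitary ("ortho") FFT, the gradient with respect to the spectrum is the FFT of the time-domain gradient, so the step can be preconditioned per frequency and transformed back. The function name, the damping term and the flat test Hessian are hypothetical; the package's own function takes grad_m and its state pytrees instead.

```python
import numpy as np


def frequency_domain_direction(grad_t, hessian_diag_f, damping=1e-9):
    """Hypothetical: precondition a time-domain gradient in the frequency domain.

    For a unitary transform, the gradient w.r.t. the spectrum equals the
    transform of the gradient w.r.t. the time-domain field.
    """
    grad_f = np.fft.fft(grad_t, norm="ortho")
    direction_f = -grad_f / (hessian_diag_f + damping)   # diagonal pseudo-Newton step
    return direction_f, np.fft.ifft(direction_f, norm="ortho")


rng = np.random.default_rng(3)
grad_t = rng.normal(size=32) + 1j * rng.normal(size=32)
h_f = np.full(32, 2.0)            # hypothetical flat pseudo-Hessian diagonal
d_f, d_t = frequency_domain_direction(grad_t, h_f)
```

With a flat pseudo-Hessian the step reduces to a rescaled gradient descent step; any frequency dependence of the diagonal is what the preconditioning adds.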
core.hessians.twodsi_z_error_pseudo_hessian module
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg_pulse(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg_pulse(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg_pulse(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_sd_pulse(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg_pulse(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_cross_correlation_pulse(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_shg_cross_correlation_gate(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_thg_cross_correlation_gate(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_pg_cross_correlation_gate(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_sd_cross_correlation_gate(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_subelement_nhg_cross_correlation_gate(pulse_t, gate_pulses_m, gate_m, deltaS_m, D_arr_pn, exp_arr_mn, exp_arr_mp, n)
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_element_pulse(exp_arr_mp, exp_arr_mn, omega_p, omega_n, time_k, pulse_t, gate_pulses_m, gate_m, deltaS_m, nonlinear_method, cross_correlation)
Sums over the time axis via jax.vmap.
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_element_gate(exp_arr_mp, exp_arr_mn, omega_p, omega_n, time_k, pulse_t, gate_pulses_m, gate_m, deltaS_m, nonlinear_method, cross_correlation)
Sums over the time axis via jax.vmap.
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_one_m(dummy, exp_arr_m, gate_pulses_m, gate_m, deltaS_m, pulse_t, time, omega, nonlinear_method, cross_correlation, full_or_diagonal, pulse_or_gate)
Uses jax.vmap over the frequency axis.
- core.hessians.twodsi_z_error_pseudo_hessian.calc_Z_error_pseudo_hessian_all_m(pulse_t, gate_pulses, gate, deltaS, tau_arr, measurement_info, full_or_diagonal, pulse_or_gate, is_vampire)
Loops over the shifts and computes the pseudo-Hessian for each one. A loop is used instead of jax.vmap because of memory limits.
- core.hessians.twodsi_z_error_pseudo_hessian.get_pseudo_newton_direction_Z_error(grad_m, pulse_t, gate_pulses, gate, signal_t, signal_t_new, tau_arr, descent_state, measurement_info, descent_info, full_or_diagonal, pulse_or_gate)
Calculates the pseudo-Newton direction for the Z-error of a 2DSI measurement. The direction is calculated in the frequency domain.
- Parameters:
grad_m (jnp.array) – the current Z-error gradient
pulse_t (jnp.array) – the current guess
gate_pulses (jnp.array) – the currently guessed gate-pulses
gate (jnp.array) – the current gate
signal_t (jnp.array) – the current signal field
signal_t_new (jnp.array) – the current signal field projected onto the measured intensity
tau_arr (jnp.array) – the applied delays
descent_state (pytree)
measurement_info (pytree) – holds measurement data and parameters
descent_info (pytree) – holds algorithm parameters
full_or_diagonal (str) – whether to use the full or the diagonal pseudo-Hessian
pulse_or_gate (str) – whether the direction is calculated for the pulse or the gate-pulse
- Returns:
tuple[jnp.array, Pytree] – the pseudo-Newton direction and the updated newton_state
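The 2DSI functions split the work into three levels: an element function for one (p, n) frequency pair, a per-shift function mapped over the frequency axis with a full or diagonal option, and a sequential loop over shifts. A minimal NumPy sketch of that structure, with Python comprehensions standing in for jax.vmap; the element formula w_k exp(-i ω_p t_k) exp(i ω_n t_k) and the per-shift weights are illustrative assumptions rather than the package's subelement functions.

```python
import numpy as np


def hessian_element(p, n, weight_k, omega, time):
    # Hypothetical element H[p, n] = sum_k w_k exp(-i w_p t_k) exp(+i w_n t_k)
    return np.sum(weight_k * np.exp(1j * (omega[n] - omega[p]) * time))


def hessian_one_m(weight_k, omega, time, full_or_diagonal):
    idx = range(omega.size)
    if full_or_diagonal == "full":
        # All N**2 (p, n) pairs; a JAX version would map over the frequency axis with jax.vmap
        return np.array([[hessian_element(p, n, weight_k, omega, time) for n in idx] for p in idx])
    if full_or_diagonal == "diagonal":
        # Only the N elements with p == n
        return np.array([hessian_element(n, n, weight_k, omega, time) for n in idx])
    raise ValueError(f"unknown option: {full_or_diagonal!r}")


def hessian_all_m(weights_mk, omega, time, full_or_diagonal):
    # Sequential loop over shifts m: per-shift intermediates exist for one m at a time
    return np.stack([hessian_one_m(w, omega, time, full_or_diagonal) for w in weights_mk])


rng = np.random.default_rng(4)
M, N, K = 3, 6, 20
time = np.linspace(-1.0, 1.0, K)
omega = np.linspace(-4.0, 4.0, N)
weights = rng.uniform(0.0, 1.0, size=(M, K))   # hypothetical per-shift weights

H_full = hessian_all_m(weights, omega, time, "full")      # (M, N, N)
H_diag = hessian_all_m(weights, omega, time, "diagonal")  # (M, N)
```

The diagonal option evaluates N elements per shift instead of N², which is the main saving when only a diagonal pseudo-Newton step is needed.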
core.hessians.vampire_z_error_pseudo_hessian module
- core.hessians.vampire_z_error_pseudo_hessian.get_pseudo_newton_direction_Z_error(grad_m, pulse_t, gate_pulses, gate, signal_t, signal_t_new, tau_arr, descent_state, measurement_info, descent_info, full_or_diagonal, pulse_or_gate)
Calculates the pseudo-Newton direction for the Z-error of a VAMPIRE measurement. The direction is calculated in the frequency domain.
- Parameters:
grad_m (jnp.array) – the current Z-error gradient
pulse_t (jnp.array) – the current guess
gate_pulses (jnp.array) – the currently guessed gate-pulses
gate (jnp.array) – the current gate
signal_t (jnp.array) – the current signal field
signal_t_new (jnp.array) – the current signal field projected onto the measured intensity
tau_arr (jnp.array) – the applied delays
descent_state (pytree)
measurement_info (pytree) – holds measurement data and parameters
descent_info (pytree) – holds algorithm parameters
full_or_diagonal (str) – whether to use the full or the diagonal pseudo-Hessian
pulse_or_gate (str) – whether the direction is calculated for the pulse or the gate-pulse
- Returns:
tuple[jnp.array, Pytree] – the pseudo-Newton direction and the updated newton_state