Advanced Usage
As the title suggests, this section describes a number of use cases that are rarely needed but supported.
Real-Fields
Generally, in ultrafast pulse retrieval it is sufficient (and convenient) to describe the pulses as complex fields. However, this comes with a drawback: unlike the spectrum of a real field, the spectrum of a complex-valued field is not Hermitian. The pulses described this way therefore contain no negative frequencies and cannot describe nonlinear processes involving difference-frequency generation (DFG). This includes explicit DFG traces as well as measurement techniques such as TREX [35], which measure multiple nonlinear signals simultaneously.
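To make the difference concrete, the sketch below (plain NumPy, not part of pulsedjax) squares a two-color field, as a stand-in for a second-order nonlinear signal, once as a real field and once as a complex field containing only positive frequencies. Only the real field produces components at 0 and at the difference frequency f2 - f1. The carrier frequencies and the helper `spectral_lines` are illustrative assumptions.

```python
import numpy as np

n = 1024
t = np.arange(n) / n              # one time unit sampled with n points
f1, f2 = 100.0, 130.0             # integer cycles per window -> exact FFT bins

# Real field: each cosine contains a +f and a -f component
real_field = np.cos(2 * np.pi * f1 * t) + np.cos(2 * np.pi * f2 * t)
# Complex field: only positive frequencies
complex_field = np.exp(2j * np.pi * f1 * t) + np.exp(2j * np.pi * f2 * t)

freq = np.fft.fftfreq(n, d=1 / n)  # frequency of each FFT bin (here integers)

def spectral_lines(signal, tol=1e-6):
    """Return the sorted absolute frequencies that carry non-negligible amplitude."""
    spectrum = np.abs(np.fft.fft(signal)) / n
    return sorted({abs(int(round(f))) for f, a in zip(freq, spectrum) if a > tol})

# Square of the field as a toy chi(2) signal
print(spectral_lines(real_field**2))     # [0, 30, 200, 230, 260]
print(spectral_lines(complex_field**2))  # [200, 230, 260]
```

The lines at 200, 230 and 260 are SHG and sum-frequency terms; 0 (rectification) and 30 = f2 - f1 (DFG) only appear for the real field.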
To solve this, the pulsedjax.real_fields module explicitly uses real-valued fields to calculate nonlinear signals. Unfortunately, the presence of negative frequencies in real-valued signals requires a large frequency axis, which increases the computational demand. To avoid convergence issues, the pulses are only defined on a user-specified frequency range. However, the repeated interpolation between different frequency axes consumes additional computational resources.
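The idea of parametrizing the pulse only on a sub-range and mapping it onto the large axis needed for a real field can be sketched as follows. This is a NumPy illustration under our own assumptions (linear interpolation, zero outside the sub-range, a chirped Gaussian test spectrum); it does not reproduce how pulsedjax.real_fields actually implements the interpolation. Using `np.fft.irfft` enforces Hermitian symmetry, so the negative-frequency half is implied and the time-domain field is real.

```python
import numpy as np

n = 2048
dt = 1.0                                  # time step (arbitrary units)
f_full = np.fft.rfftfreq(n, d=dt)         # non-negative half of the real-field axis, 0 ... 0.5

# Pulse parametrized only on a sub-range, analogous to f_range_pulse=(0.1, 0.25)
f_min, f_max = 0.1, 0.25
f_pulse = np.linspace(f_min, f_max, 64)
f0, width, chirp = 0.175, 0.03, 200.0     # illustrative Gaussian spectrum with quadratic phase
spectrum_pulse = np.exp(-((f_pulse - f0) / width) ** 2) * np.exp(1j * chirp * (f_pulse - f0) ** 2)

# Interpolate real and imaginary parts onto the full axis, zero outside the sub-range
re = np.interp(f_full, f_pulse, spectrum_pulse.real, left=0.0, right=0.0)
im = np.interp(f_full, f_pulse, spectrum_pulse.imag, left=0.0, right=0.0)
spectrum_full = re + 1j * im

# irfft assumes E(-f) = conj(E(f)), so the resulting field in time is real-valued
field_t = np.fft.irfft(spectrum_full, n=n)
print(f_pulse.size, f_full.size, field_t.dtype)   # 64 1025 float64
```

The 64 parameters versus the 1025-point axis hint at why restricting the pulse to a sub-range pays off, and why each interpolation step adds some cost.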
Below is a usage example of an SHG-FROG retrieval with real fields and the AutoDiff solver. Compared to the complex-field case, the only differences are the additional user-specified inputs f_range_pulse = (fmin, fmax) and f_range_fields = (fmin, fmax), and optionally f_max_all_fields = fmax. Here f_range_pulse is the frequency interval in which the pulse (and the gate pulse) is located, f_range_fields is the frequency interval in which the relevant nonlinear signals are located, and f_max_all_fields is the largest frequency considered when the nonlinear signals are calculated.
from pulsedjax.real_fields import frog
import optax
# delay, frequency and trace are the measured data
ad = frog.AutoDiff(delay, frequency, trace, "shg", f_range_pulse=(0.1,0.25), f_range_fields=(0.25,0.5),
                   solver=optax.adam(learning_rate=0.1))
population = ad.create_initial_population(5, "continuous", "continuous")
final_result = ad.run(population, 50)
Mixing Algorithms
All algorithms are built around a function algorithm.step(descent_state, measurement_info, descent_info), which always takes these three inputs. In principle, composite algorithms can therefore be created simply by chaining the step functions of different algorithms. However, descent_state and descent_info differ in structure and content between algorithms, and a composite algorithm has to account for these algorithm-dependent containers. For descent_info this is not an issue, since it is static. For descent_state, the user has to specify how the population, or individuals of the population, are transferred between the different algorithms.
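Stripped of pulsedjax specifics, the pattern is: two step functions, each with its own state container, plus glue code that moves individuals between the containers. The toy below is a hypothetical illustration (plain NumPy, a quadratic objective, made-up state layouts for a random-search population and a single gradient-descent individual), not pulsedjax code. The composite step refines the best population member with a few gradient steps and writes it back.

```python
import numpy as np

def loss(x):
    # Toy objective with its minimum at x = 3 in every coordinate
    return np.sum((x - 3.0) ** 2, axis=-1)

def step_random_search(pop_state):
    # Population algorithm: perturb every member, keep the perturbation if it is better
    rng, pop = pop_state["rng"], pop_state["population"]
    trial = pop + 0.5 * rng.standard_normal(pop.shape)
    better = loss(trial) < loss(pop)
    pop = np.where(better[:, None], trial, pop)
    return {"rng": rng, "population": pop}, loss(pop)

def step_gradient(local_state, lr=0.1):
    # Local algorithm: one gradient-descent step on a single individual
    x = local_state["x"]
    x = x - lr * 2.0 * (x - 3.0)
    return {"x": x}, loss(x)

def step_composite(pop_state, n_local=20):
    pop_state, errors = step_random_search(pop_state)
    idx = int(np.argmin(errors))                       # fittest individual
    local_state = {"x": pop_state["population"][idx]}  # transfer into the local container
    for _ in range(n_local):
        local_state, _ = step_gradient(local_state)
    population = pop_state["population"].copy()
    population[idx] = local_state["x"]                 # transfer back into the population
    pop_state = {"rng": pop_state["rng"], "population": population}
    return pop_state, loss(population)

rng = np.random.default_rng(0)
state = {"rng": rng, "population": rng.uniform(-5, 5, size=(8, 4))}
for _ in range(10):
    state, errors = step_composite(state)
print(errors.min() < 1e-3)   # True
```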
Below is an example of an algorithm composed of DifferentialEvolution and AutoDiff. The goal is to refine the fittest individual with AutoDiff in each iteration of DifferentialEvolution.
from pulsedjax.frog import DifferentialEvolution, AutoDiff
import optax
from pulsedjax.utilities import scan_helper, run_scan
from equinox import tree_at
import jax.numpy as jnp
import jax
from jax.tree_util import Partial
# delay, frequency_trace and trace are the measured data
# instantiate and initialize DifferentialEvolution and AutoDiff
# for the population, the types and number of basis functions must be the same,
# otherwise transferring individuals between the descent_states becomes "difficult"
de = DifferentialEvolution(delay, frequency_trace, trace, "shg", strategy="best2_exp")
population_de = de.create_initial_population(10,
                                             amp_type="bsplines_5", phase_type="bsplines_5",
                                             no_funcs_amp=15, no_funcs_phase=15)
descent_state_de, step_de = de.initialize_run(population_de)
ad = AutoDiff(delay, frequency_trace, trace, "shg", solver=optax.lbfgs(learning_rate=0.1))
population_ad = ad.create_initial_population(1,
                                             amp_type="bsplines_5", phase_type="bsplines_5",
                                             no_funcs_amp=15, no_funcs_phase=15)
descent_state_ad, step_ad = ad.initialize_run(population_ad)
# descent_state_ad is kept static in order to avoid cross-talk between optimization runs of different individuals,
# so there is no need to pass it into _step_composite explicitly
def _step_composite(descent_state_de):
    # make one DifferentialEvolution step
    descent_state_de, error_de = step_de(descent_state_de, None)
    # extract the fittest individual
    idx = de.get_idx_best_individual(descent_state_de)
    fittest_individual = de.get_individual_from_idx(idx, descent_state_de.population)
    # insert the fittest individual into the AutoDiff descent_state
    descent_state_ad_new = tree_at(lambda x: x.population, descent_state_ad, fittest_individual)
    # run an AutoDiff optimization of that individual with 50 iterations
    descent_state_ad_new, error_ad = jax.lax.scan(step_ad, descent_state_ad_new, length=50)
    # return this individual into the population of DifferentialEvolution
    population_de = jax.tree.map(lambda x, y: x.at[idx].set(y[0]), descent_state_de.population, descent_state_ad_new.population)
    descent_state_de = tree_at(lambda x: x.population, descent_state_de, population_de)
    return descent_state_de, jnp.concatenate([error_de, error_ad[-1]])
# convert _step_composite into a lax.scan-compatible form
step_composite = Partial(scan_helper, actual_function=_step_composite, number_of_args=1, number_of_xs=0)
# run
descent_state_de, error_arr = run_scan(step_composite, descent_state_de, 100)
final_result = de.post_process(descent_state_de, error_arr)
Differentiability and JAX/Equinox-Transformations
Even though the algorithms were not designed to be differentiable, they appear to be differentiable when they are part of a pure function. In principle, this can be used to optimize the parameters of an algorithm via automatic differentiation. An example of this can be found here.
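As a minimal illustration of the principle (not of pulsedjax itself), the sketch below writes a plain gradient-descent loop with `jax.lax.scan` as a pure function of its learning rate and differentiates the final loss with `jax.grad`. The quadratic objective f(x) = a·x²/2 is an assumption chosen so that the result can be checked against the closed form: after N steps the loss is a·x0²·(1 − η·a)^(2N)/2.

```python
import jax
import jax.numpy as jnp

a, x0, n_steps = 2.0, 1.5, 10   # toy problem: f(x) = 0.5 * a * x**2

def final_loss(learning_rate):
    """Run n_steps of gradient descent and return the final loss (a pure function)."""
    def step(x, _):
        x = x - learning_rate * a * x   # the gradient of 0.5*a*x**2 is a*x
        return x, 0.5 * a * x**2
    x_final, _ = jax.lax.scan(step, jnp.asarray(x0, dtype=jnp.float32), xs=None, length=n_steps)
    return 0.5 * a * x_final**2

# Derivative of the final loss with respect to the learning rate, via autodiff
lr = 0.1
dloss_dlr = jax.grad(final_loss)(lr)

# Closed form of d/d(lr) [0.5*a*x0**2*(1 - lr*a)**(2N)]
closed_form = 0.5 * a * x0**2 * 2 * n_steps * (1 - lr * a) ** (2 * n_steps - 1) * (-a)
print(jnp.allclose(dloss_dlr, closed_form, rtol=1e-4))   # True
```

The negative derivative says that a slightly larger learning rate would lower the final loss, which is the kind of signal one could feed into an outer optimizer.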
Furthermore, algorithms that are part of pure functions should be compatible with all JAX/Equinox transformations (e.g. jax.jit or jax.vmap).
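Using the same kind of toy setup (again an illustrative assumption, not pulsedjax code), a whole run that is a pure function can be vectorized with `jax.vmap`, here over several learning rates at once, and compiled with `jax.jit`; the gradient from the previous idea can be batched the same way.

```python
import jax
import jax.numpy as jnp

a, x0, n_steps = 2.0, 1.5, 10   # toy problem: f(x) = 0.5 * a * x**2

def final_loss(learning_rate):
    """Pure function: n_steps of gradient descent, returns the final loss."""
    def step(x, _):
        return x - learning_rate * a * x, None
    x_final, _ = jax.lax.scan(step, jnp.asarray(x0, dtype=jnp.float32), xs=None, length=n_steps)
    return 0.5 * a * x_final**2

learning_rates = jnp.array([0.05, 0.1, 0.2, 0.4])

# vmap runs one optimization per learning rate; jit compiles the batched run once
batched_loss = jax.jit(jax.vmap(final_loss))(learning_rates)
batched_grad = jax.jit(jax.vmap(jax.grad(final_loss)))(learning_rates)

print(batched_loss.shape, batched_grad.shape)   # (4,) (4,)
```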
Adding new Methods
One idea behind pulsedjax is modularity, which should make it relatively easy to add new methods as well as new algorithms. All methods are implemented as classes RetrievePulsesMETHOD(RetrievePulses) in pulsedjax.core.base_classes_methods.py. Methods define how data is preprocessed and added to measurement_info; in some cases, specific post-processing is needed as well. Most importantly, a method class needs to define how the nonlinear signal is calculated, via the method RetrievePulsesMETHOD.calculate_signal_t(self, individual, transform_arr, measurement_info).
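For orientation, the kind of computation calculate_signal_t performs can be sketched for an SHG-FROG, whose signal field is E_sig(t, τ) = E(t)·E(t − τ). The function below is a hypothetical standalone NumPy version: its argument names (`pulse_t`, `delays`, `time`) and the Fourier-shift implementation of the delay are our assumptions, and it does not use pulsedjax's individual, transform_arr or measurement_info containers or its own FFT helper.

```python
import numpy as np

def shg_frog_signal_t(pulse_t, delays, time):
    """Return E(t) * E(t - tau) for every delay tau, shape (len(delays), len(time))."""
    dt = time[1] - time[0]
    freq = np.fft.fftfreq(time.size, d=dt)
    pulse_f = np.fft.fft(pulse_t)
    # Fourier shift theorem: E(t - tau) <-> E(f) * exp(-2j*pi*f*tau)
    phase = np.exp(-2j * np.pi * freq[None, :] * delays[:, None])
    gate_t = np.fft.ifft(pulse_f[None, :] * phase, axis=-1)
    return pulse_t[None, :] * gate_t

time = np.linspace(-10, 10, 256, endpoint=False)
pulse_t = np.exp(-time**2) * np.exp(1j * 0.5 * time**2)   # chirped Gaussian test pulse
delays = np.array([0.0, 3 * (time[1] - time[0])])         # zero delay and a 3-sample delay

signal_t = shg_frog_signal_t(pulse_t, delays, time)
signal_f = np.fft.fft(signal_t, axis=-1)                   # spectrally resolved signal
trace = np.abs(signal_f) ** 2                              # FROG trace (up to centering/scaling)
print(trace.shape)                                         # (2, 256)
```

For delays that are integer multiples of the time step, the Fourier shift reduces to a circular shift, which makes the function easy to check.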
Below is a bare-bones implementation of a RetrievePulsesMETHOD class.
class RetrievePulsesMETHOD(RetrievePulses):
    def __init__(self, theta, frequency, measured_trace, nonlinear_method, *args, **kwargs):
        super().__init__(nonlinear_method, *args, **kwargs)
        self.theta, self.time, self.frequency, self.measured_trace = self.get_data(theta, frequency, measured_trace)
        self.measurement_info = self.measurement_info.expand(theta=self.theta,
                                                             measured_trace=self.measured_trace)
        # in delay-based methods, theta and tau_arr are the same and also identical to transform_arr
        # for chirp scans, transform_arr is the same as phase_matrix, and tau_arr doesn't exist
        self.measurement_info = self.measurement_info.expand(transform_arr=self.transform_arr)
    def calculate_signal_t(self, individual, transform_arr, measurement_info):
        signal_t = ...
        signal_f = self.fft(signal_t, measurement_info.sk, measurement_info.rn)
        # add further method-specific quantities to the namespace as needed (...)
        signal_t = MyNamespace(signal_t=signal_t, signal_f=signal_f)
        return signal_t
    def post_process_get_pulse_and_gate(self, descent_state, measurement_info, descent_info, idx=None):
        ...
        return pulse_t, gate_t, pulse_f, gate_f
To create a new final algorithm, multiple inheritance is used: Algorithm(AlgorithmBASE, RetrievePulsesMETHOD). Usually, a few final method-specific methods have to be defined as well.
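The multiple-inheritance pattern can be illustrated with plain Python classes. All class names below are hypothetical stand-ins: the algorithm base only relies on calculate_signal_t and a signal derivative, the method class supplies both, and cooperative `super().__init__` calls let each parent consume its own constructor arguments.

```python
class ToyAlgorithmBASE:
    """Hypothetical algorithm base: knows how to iterate, not how the signal is formed."""

    def __init__(self, *args, learning_rate=0.05, **kwargs):
        super().__init__(*args, **kwargs)          # forwards the remaining arguments along the MRO
        self.learning_rate = learning_rate

    def step(self, state):
        residual = self.calculate_signal_t(state) - self.measured
        # gradient step on 0.5 * residual**2, using the method's signal derivative
        state = state - self.learning_rate * residual * self.signal_derivative(state)
        return state, 0.5 * residual**2


class ToyMethodSQUARE:
    """Hypothetical method class: the 'nonlinear signal' of a state x is x**2."""

    def __init__(self, measured):
        self.measured = measured

    def calculate_signal_t(self, state):
        return state**2

    def signal_derivative(self, state):
        return 2 * state


class ToyAlgorithm(ToyAlgorithmBASE, ToyMethodSQUARE):
    """Final class: algorithm base first, method class second."""


algo = ToyAlgorithm(measured=4.0, learning_rate=0.05)
state = 1.0
for _ in range(200):
    state, error = algo.step(state)
print(round(state, 6))                                  # 2.0
print([cls.__name__ for cls in ToyAlgorithm.__mro__])
# ['ToyAlgorithm', 'ToyAlgorithmBASE', 'ToyMethodSQUARE', 'object']
```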
Adding new Algorithms
As stated above, it should be relatively easy to add new algorithms. Usually, all algorithms inherit from ClassicalAlgorithmsBASE or GeneralAlgorithmsBASE and are defined in either pulsedjax.core.base_classic_algorithms.py or pulsedjax.core.base_general_optimization.py.
All algorithm classes need to provide an algorithm.step(self, descent_state, measurement_info, descent_info) function as well as an algorithm.initialize_run(self, population) function, where step() performs one iteration of the algorithm and initialize_run() prepares all provided data for the retrieval. Exceptions are PytchographicIterativeEngine and COPRA, which have a local_step() and a global_step() instead of step().
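A minimal, hypothetical illustration of this contract outside of pulsedjax: initialize_run builds the three containers, binds the static ones to step, and returns the initial state together with a function that `jax.lax.scan` can iterate. The toy algorithm (gradient descent towards a fixed target), the dict-based containers and the hand-written scan adapter are all assumptions; pulsedjax uses `Partial` from `jax.tree_util` and its own `scan_helper` for these steps, whose internals are not reproduced here.

```python
from functools import partial

import jax
import jax.numpy as jnp

class ToyGradientDescent:
    """Hypothetical algorithm following the step / initialize_run contract."""

    def __init__(self, target, learning_rate=0.2):
        self.target = jnp.asarray(target)          # stands in for the measured data
        self.learning_rate = learning_rate         # stands in for a static setting

    def step(self, descent_state, measurement_info, descent_info):
        # One iteration: gradient descent on 0.5 * ||individual - target||**2 for every individual
        population = descent_state["population"]
        residual = population - measurement_info["target"]
        population = population - descent_info["learning_rate"] * residual
        error = 0.5 * jnp.sum(residual**2, axis=-1)
        return {"population": population}, error.reshape(-1, 1)

    def initialize_run(self, population):
        measurement_info = {"target": self.target}
        descent_info = {"learning_rate": self.learning_rate}
        descent_state = {"population": jnp.asarray(population)}
        step = partial(self.step, measurement_info=measurement_info, descent_info=descent_info)

        def do_step(carry, _):
            # hand-written lax.scan adapter: (carry, xs) -> (new carry, per-iteration output)
            return step(carry)

        return descent_state, do_step

algo = ToyGradientDescent(target=[1.0, -2.0])
descent_state, do_step = algo.initialize_run(jnp.zeros((3, 2)))
descent_state, errors = jax.lax.scan(do_step, descent_state, xs=None, length=50)
print(errors.shape)   # (50, 3, 1)
```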
Below is a bare-bones implementation of an AlgorithmBASE class.
class MyAlgorithmBASE(Classic_or_GeneralAlgorithmsBASE):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._name = "MyAlgorithm"
    def step(self, descent_state, measurement_info, descent_info):
        """
        Performs one iteration of MyAlgorithm.
        Args:
            descent_state (Pytree): the current state of the retrieval, e.g. the population and prng keys
            measurement_info (Pytree): the measured data and measurement parameters
            descent_info (Pytree): the static settings of the algorithm
        Returns:
            tuple[Pytree, jnp.array], the updated descent state and the current trace errors of the population.
        """
        ...
        return descent_state, trace_error.reshape(-1,1)
    def initialize_run(self, population):
        """
        Prepares all provided data and parameters for the reconstruction.
        Here the final shape/structure of descent_state, measurement_info and descent_info is determined.
        Args:
            population (Pytree): the initial guess as created by self.create_initial_population()
        Returns:
            tuple[Pytree, Callable], the initial descent state and the step-function of the algorithm.
        """
        # usually the final structure of measurement_info is defined through RetrievePulsesMETHOD
        measurement_info = self.measurement_info
        # add the setting attributes into descent_info
        self.descent_info = self.descent_info.expand( ... )
        descent_info = self.descent_info
        # add the initial guess and other changing variables like prng keys into descent_state (...)
        self.descent_state = self.descent_state.expand(population=population)
        descent_state = self.descent_state
        # bind the static containers to step() and convert it into a lax.scan-compatible form
        do_step = Partial(self.step, measurement_info=measurement_info, descent_info=descent_info)
        do_step = Partial(scan_helper, actual_function=do_step, number_of_args=1, number_of_xs=0)
        return descent_state, do_step