Advanced Usage

As the title suggests, this section describes a number of use cases that are rarely required but nonetheless supported.

Real-Fields

Generally, in ultrafast pulse retrieval it is sufficient (and convenient) to describe pulses as complex fields. However, this comes with a drawback: the spectrum of a complex-valued field is not Hermitian, i.e. it does not satisfy E(-ω) = E*(ω). The pulses it describes therefore do not possess negative frequencies and cannot describe nonlinear processes involving difference-frequency generation (DFG). This includes explicit DFG traces as well as measurement techniques such as TREX [35], which measure multiple nonlinear signals simultaneously.
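The sketch below illustrates this numerically with plain jax.numpy, independent of pulsedjax. It builds one Gaussian pulse both as a complex field and as a real field, checks the Hermitian symmetry of both spectra, and compares the second-order product of each field with itself: only the real field yields a difference-frequency (near-zero-frequency) component in addition to the SHG component. The pulse parameters and the helper names are arbitrary illustrative choices.

```python
import jax.numpy as jnp

# Arbitrary illustrative parameters (time and frequency in reciprocal units)
n = 1024
dt = 100.0 / n
t = -50.0 + dt * jnp.arange(n)
f = jnp.fft.fftfreq(n, d=dt)
f0 = 0.2                                   # carrier frequency
envelope = jnp.exp(-t**2 / (2 * 5.0**2))   # Gaussian envelope

field_complex = envelope * jnp.exp(2j * jnp.pi * f0 * t)  # complex description
field_real = jnp.real(field_complex)                       # physical real field


def hermitian_asymmetry(signal):
    """Relative deviation of the spectrum from E(-f) = conj(E(f))."""
    spec = jnp.fft.fft(signal)
    # reversing and rolling by one maps index k onto index (-k) mod n
    mirrored = jnp.conj(jnp.roll(spec[::-1], 1))
    return float(jnp.max(jnp.abs(spec - mirrored)) / jnp.max(jnp.abs(spec)))


def fraction_near_zero_frequency(signal, width=0.15):
    """Fraction of the spectral power located at |f| < width."""
    power = jnp.abs(jnp.fft.fft(signal)) ** 2
    return float(jnp.sum(jnp.where(jnp.abs(f) < width, power, 0.0)) / jnp.sum(power))


print(hermitian_asymmetry(field_real))     # close to 0: Hermitian spectrum
print(hermitian_asymmetry(field_complex))  # of order 1: only positive frequencies

# Second-order mixing of the pulse with itself:
# the real field yields SHG at +-2 f0 plus a DFG (rectification) term near f = 0,
# the complex field only yields SHG at +2 f0.
print(fraction_near_zero_frequency(field_real**2))     # roughly 2/3
print(fraction_near_zero_frequency(field_complex**2))  # essentially 0
```

The value of roughly 2/3 follows from cos²(2πf0t) = 1/2 + cos(4πf0t)/2: the zero-frequency term carries twice the power of the two SHG sidebands combined.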
To describe difference-frequency processes, the pulsedjax.real_fields module explicitly uses real-valued fields to calculate the nonlinear signals. Unfortunately, the presence of negative frequencies in real-valued signals requires a large frequency axis, which increases the computational demand. To avoid convergence issues, the pulses are only defined on a user-specified frequency range. However, the repeated interpolation between the different frequency axes consumes additional computational resources.
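A minimal sketch of the idea of restricting the pulse to a frequency range, assuming nothing about the actual pulsedjax internals: the spectrum is only stored on a coarse grid inside a hypothetical f_range, interpolated onto a full non-negative frequency axis up to a hypothetical f_max, and converted into a real-valued field with an inverse real FFT. The names f_range, f_max and spectrum_to_real_field are illustrative only.

```python
import jax.numpy as jnp

# Illustrative values, not pulsedjax defaults
f_range = (0.1, 0.25)   # interval in which the pulse spectrum is parameterized
f_max = 1.0             # largest frequency of the full axis
n_t = 512               # number of samples of the real-valued field

# Full non-negative frequency axis used by rfft/irfft; the Nyquist frequency equals f_max
dt = 1.0 / (2.0 * f_max)
f_full = jnp.fft.rfftfreq(n_t, d=dt)       # n_t // 2 + 1 points from 0 to f_max

# Coarse grid on which the pulse is actually stored (amplitude and spectral phase)
f_pulse = jnp.linspace(f_range[0], f_range[1], 32)
center = 0.5 * (f_range[0] + f_range[1])
amp = jnp.exp(-(((f_pulse - center) / 0.03) ** 2))
phase = 50.0 * (f_pulse - center) ** 2    # quadratic spectral phase (chirp)


def spectrum_to_real_field(f_pulse, amp, phase, f_full, n_t):
    """Interpolate the pulse onto f_full and return the real-valued time-domain field."""
    # jnp.interp works on real data, so amplitude and phase are interpolated separately;
    # the spectrum is set to zero outside the parameterized interval
    inside = (f_full >= f_pulse[0]) & (f_full <= f_pulse[-1])
    amp_full = jnp.where(inside, jnp.interp(f_full, f_pulse, amp), 0.0)
    phase_full = jnp.interp(f_full, f_pulse, phase)
    spec_full = amp_full * jnp.exp(1j * phase_full)
    # irfft implicitly supplies the negative frequencies via E(-f) = conj(E(f))
    return jnp.fft.irfft(spec_full, n=n_t)


field_t = spectrum_to_real_field(f_pulse, amp, phase, f_full, n_t)
print(field_t.shape, jnp.isrealobj(field_t))  # (512,) True
```

Any nonlinear signal computed from field_t (e.g. field_t**2) contains sum frequencies up to twice the pulse frequencies and difference frequencies near zero, which is why the full axis has to extend well beyond the interval in which the pulse itself lives.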

Below is a usage example of an SHG-FROG retrieval with real fields and the AutoDiff solver. Compared to the complex-field version, the only difference is the additional user-specified inputs f_range_pulse = (fmin, fmax) and f_range_fields = (fmin, fmax), and optionally f_max_all_fields = fmax. Here f_range_pulse describes the frequency interval in which the pulse (and gate pulse) is located, f_range_fields describes the frequency interval in which the relevant nonlinear signals are located, and f_max_all_fields is the largest frequency considered when the nonlinear signals are calculated.

from pulsedjax.real_fields import frog
import optax

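# delay, frequency and trace hold the measured SHG-FROG data
ad = frog.AutoDiff(delay, frequency, trace, "shg", f_range_pulse=(0.1,0.25), f_range_fields=(0.25,0.5),
                   solver=optax.adam(learning_rate=0.1))

# 5 initial guesses with "continuous" amplitude and phase
population = ad.create_initial_population(5, "continuous", "continuous")
# run the retrieval for 50 iterations
final_result = ad.run(population, 50)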

Mixing Algorithms

All algorithms are built around a function algorithm.step(descent_state, measurement_info, descent_info), which always takes these three inputs. In principle, composite algorithms can therefore be created by chaining the step functions of different algorithms. However, the structure and contents of descent_state and descent_info differ between algorithms, and a composite algorithm has to account for these algorithm-dependent containers. For descent_info this is not an issue, since it is static. For descent_state, the user has to specify how the population, or individuals of it, are transferred between the different algorithms.
Below is an example of an algorithm composed of DifferentialEvolution and AutoDiff, in which the fittest individual is refined with AutoDiff after every DifferentialEvolution iteration.

from pulsedjax.frog import DifferentialEvolution, AutoDiff
import optax

from pulsedjax.utilities import scan_helper, run_scan
from equinox import tree_at
import jax.numpy as jnp
import jax
from jax.tree_util import Partial


# instantiate and initialize DifferentialEvolution and AutoDiff
# both populations must use the same amplitude/phase types and the same number of basis functions,
# otherwise transferring individuals between the descent_states becomes difficult

de = DifferentialEvolution(delay, frequency_trace, trace, "shg", strategy="best2_exp")
population_de = de.create_initial_population(10, 
                                             amp_type="bsplines_5", phase_type="bsplines_5", 
                                             no_funcs_amp=15, no_funcs_phase=15)
descent_state_de, step_de = de.initialize_run(population_de)


ad = AutoDiff(delay, frequency_trace, trace, "shg", solver=optax.lbfgs(learning_rate=0.1))
population_ad = ad.create_initial_population(1, 
                                             amp_type="bsplines_5", phase_type="bsplines_5", 
                                             no_funcs_amp=15, no_funcs_phase=15)
descent_state_ad, step_ad = ad.initialize_run(population_ad)



# descent_state_ad is kept static (captured from the enclosing scope) to avoid cross-talk between
# the optimization runs of different individuals, so there is no need to pass it into _step_composite explicitly
def _step_composite(descent_state_de):

    # make one DifferentialEvolution-Step
    descent_state_de, error_de = step_de(descent_state_de, None)

    # Extract the fittest individual
    idx = de.get_idx_best_individual(descent_state_de)
    fittest_individual = de.get_individual_from_idx(idx, descent_state_de.population)

    # insert the fittest individual into the AutoDiff descent_state
    descent_state_ad_new = tree_at(lambda x: x.population, descent_state_ad, fittest_individual)

    # run an AutoDiff optimization of that individual with 50 iterations
    descent_state_ad_new, error_ad = jax.lax.scan(step_ad, descent_state_ad_new, length=50)

    # return this individual into the population of DifferentialEvolution
    population_de = jax.tree.map(lambda x,y: x.at[idx].set(y[0]), descent_state_de.population, descent_state_ad_new.population)
    descent_state_de = tree_at(lambda x: x.population, descent_state_de, population_de)
    return descent_state_de, jnp.concatenate([error_de, error_ad[-1]])


# convert _step_composite into a lax.scan-compatible form
step_composite = Partial(scan_helper, actual_function=_step_composite, number_of_args=1, number_of_xs=0)

# run
descent_state_de, error_arr = run_scan(step_composite, descent_state_de, 100)
final_result = de.post_process(descent_state_de, error_arr)
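The same pattern can be reproduced without pulsedjax. The sketch below uses two hypothetical toy optimizers with different state containers (a random-perturbation population search and plain gradient descent) on a quadratic test function. After every population step, the fittest individual is transferred into the gradient-descent state, refined inside jax.lax.scan and written back into the population, mirroring _step_composite above. None of the names correspond to pulsedjax classes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


def loss(x):
    # toy objective with its minimum at x = 1
    return jnp.sum((x - 1.0) ** 2)


class PopState(NamedTuple):
    population: jax.Array  # shape (pop_size, dim)
    key: jax.Array         # PRNG key of the population optimizer


class GDState(NamedTuple):
    x: jax.Array           # a single individual, shape (dim,)


def step_pop(state, sigma=0.3):
    """Perturb every individual randomly and keep improvements (toy population step)."""
    key, subkey = jax.random.split(state.key)
    trial = state.population + sigma * jax.random.normal(subkey, state.population.shape)
    better = jax.vmap(loss)(trial) < jax.vmap(loss)(state.population)
    population = jnp.where(better[:, None], trial, state.population)
    return PopState(population, key), jax.vmap(loss)(population)


def step_gd(state, _, lr=0.1):
    """One gradient-descent step in lax.scan form: (carry, x) -> (carry, y)."""
    x = state.x - lr * jax.grad(loss)(state.x)
    return GDState(x), loss(x)


def step_composite(state, _):
    # one step of the population optimizer
    state, errors = step_pop(state)
    # transfer the fittest individual into the state of the gradient-descent optimizer
    idx = jnp.argmin(errors)
    gd_state, _ = jax.lax.scan(step_gd, GDState(state.population[idx]), length=20)
    # write the refined individual back into the population
    state = state._replace(population=state.population.at[idx].set(gd_state.x))
    return state, jnp.min(jax.vmap(loss)(state.population))


key, init_key = jax.random.split(jax.random.PRNGKey(0))
init_state = PopState(5.0 * jax.random.normal(init_key, (8, 3)), key)

run = jax.jit(lambda s: jax.lax.scan(step_composite, s, length=10))
final_state, best_errors = run(init_state)
print(best_errors)  # non-increasing, approaching 0
```

Because step_pop only accepts improvements and gradient descent on a quadratic with a small step always lowers the loss, the best error per composite iteration can never increase.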

Differentiability and JAX/Equinox-Transformations

Although the algorithms were not designed to be differentiable, they appear to be differentiable when they are part of a pure function. In principle, this can be used to optimize the parameters of an algorithm via automatic differentiation. An example of this can be found here.
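As a toy illustration, independent of the linked example and of pulsedjax, the sketch below unrolls a fixed number of gradient-descent iterations inside jax.lax.scan and differentiates the final loss with respect to the learning rate. That learning rate is then itself updated by gradient descent. The objective, iteration count and step sizes are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp


def objective(x):
    # toy objective with its minimum at x = 2
    return jnp.sum((x - 2.0) ** 2)


def run_algorithm(learning_rate, x0, n_iter=10):
    """Pure function: n_iter gradient-descent steps, returns the final loss."""
    def step(x, _):
        return x - learning_rate * jax.grad(objective)(x), None

    x_final, _ = jax.lax.scan(step, x0, length=n_iter)
    return objective(x_final)


x0 = jnp.zeros(3)
# gradient of the final loss with respect to the algorithm's learning rate
meta_grad = jax.jit(jax.grad(run_algorithm))
print(meta_grad(0.05, x0))  # negative: a larger learning rate lowers the final loss

# optimize the learning rate itself by gradient descent
lr = 0.05
for _ in range(20):
    lr = lr - 1e-3 * meta_grad(lr, x0)
print(run_algorithm(0.05, x0), run_algorithm(lr, x0))  # the second loss is smaller
```

For this quadratic the result can be checked by hand: starting at x0 = 0, the final loss is 12(1 - 2η)^20 for learning rate η, so its derivative is -480(1 - 2η)^19, about -64.8 at η = 0.05.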

Furthermore, this means that, as part of pure functions, the algorithms should be compatible with all JAX/Equinox transformations, such as jax.jit, jax.vmap and jax.grad.
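For instance, a run that is written as a pure function can be batched with jax.vmap and compiled with jax.jit. The sketch below, again using a hypothetical toy gradient-descent "algorithm" rather than pulsedjax, runs the same algorithm for several learning rates in parallel.

```python
import jax
import jax.numpy as jnp


def objective(x):
    # toy objective with its minimum at x = 2
    return jnp.sum((x - 2.0) ** 2)


def run_algorithm(learning_rate, x0, n_iter=10):
    """Pure function: n_iter gradient-descent steps, returns the final loss."""
    def step(x, _):
        return x - learning_rate * jax.grad(objective)(x), None

    x_final, _ = jax.lax.scan(step, x0, length=n_iter)
    return objective(x_final)


learning_rates = jnp.array([0.01, 0.05, 0.1, 0.25, 0.5])
x0 = jnp.zeros(3)

# vmap over the learning rate (x0 is shared); jit compiles the batched run once
batched_run = jax.jit(jax.vmap(run_algorithm, in_axes=(0, None)))
final_losses = batched_run(learning_rates, x0)
print(final_losses)  # decreasing with the learning rate; essentially 0 for 0.5
```

A learning rate of 0.5 is the exact Newton step for this quadratic, which is why the last entry vanishes.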

Adding new Methods

One design goal of pulsedjax is modularity, which should make it relatively easy to add new methods as well as new algorithms. All methods are implemented as classes RetrievePulsesMETHOD(RetrievePulses) in pulsedjax.core.base_classes_methods.py. A method defines how the data is preprocessed and added to measurement_info; in some cases, method-specific post-processing is needed as well. Most importantly, a method class needs to define how the nonlinear signal is calculated, via the method RetrievePulsesMETHOD.calculate_signal_t(self, individual, transform_arr, measurement_info). Below is a bare-bones implementation of a RetrievePulsesMETHOD class.

class RetrievePulsesMETHOD(RetrievePulses):

    def __init__(self, theta, frequency, measured_trace, nonlinear_method, *args, **kwargs):
        super().__init__(nonlinear_method, *args, **kwargs)

        self.theta, self.time, self.frequency, self.measured_trace = self.get_data(theta, frequency, measured_trace)

        self.measurement_info = self.measurement_info.expand(theta = self.theta,
                                                             measured_trace = self.measured_trace)
        
        # in delay-based methods theta, tau_arr and transform_arr are identical;
        # for chirp scans transform_arr equals phase_matrix and tau_arr does not exist
        self.measurement_info = self.measurement_info.expand(transform_arr = self.transform_arr)



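    def calculate_signal_t(self, individual, transform_arr, measurement_info):
        # method-specific calculation of the nonlinear signal in the time domain
        signal_t = ...

        signal_f = self.fft(signal_t, measurement_info.sk, measurement_info.rn)
        # further method-specific quantities can be stored in the same namespace
        signal_t = MyNamespace(signal_t=signal_t, signal_f=signal_f)
        return signal_t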
    


    def post_process_get_pulse_and_gate(self, descent_state, measurement_info, descent_info, idx=None):
        ...
        return pulse_t, gate_t, pulse_f, gate_f
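As an illustration of what the body of calculate_signal_t might compute, the standalone sketch below evaluates the standard SHG-FROG signal E_sig(t, τ) = E(t) E(t - τ) for an array of delays, implementing the delay as a linear spectral phase (Fourier shift theorem). It assumes a periodic, uniformly sampled time axis and does not use pulsedjax's individual, transform_arr or measurement_info containers; the names delay_field and shg_frog_signal_t are hypothetical.

```python
import jax
import jax.numpy as jnp


def delay_field(field_t, tau, frequency):
    """Return E(t - tau) via the Fourier shift theorem (periodic time axis)."""
    field_f = jnp.fft.fft(field_t)
    return jnp.fft.ifft(field_f * jnp.exp(-2j * jnp.pi * frequency * tau))


def shg_frog_signal_t(field_t, delays, frequency):
    """SHG-FROG signal E(t) * E(t - tau) for every delay, shape (n_delays, n_t)."""
    gate = jax.vmap(delay_field, in_axes=(None, 0, None))(field_t, delays, frequency)
    return field_t[None, :] * gate


n = 256
dt = 0.5
t = dt * (jnp.arange(n) - n // 2)
frequency = jnp.fft.fftfreq(n, d=dt)
field_t = jnp.exp(-t**2 / 20.0) * jnp.exp(1j * 0.05 * t**2)  # chirped Gaussian pulse
delays = dt * jnp.arange(-8, 9)                             # integer multiples of dt

signal_t = shg_frog_signal_t(field_t, delays, frequency)
trace = jnp.abs(jnp.fft.fft(signal_t, axis=-1)) ** 2         # unshifted frequency order
print(signal_t.shape)  # (17, 256)
```

For delays that are integer multiples of dt, the spectral phase shift coincides with a circular shift of the samples, which makes the implementation easy to check.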

To create a new final algorithm, multiple inheritance is used, i.e. Algorithm(AlgorithmBASE, RetrievePulsesMETHOD). Usually, some final method-specific methods have to be defined in this class as well.
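The plain-Python sketch below illustrates how this kind of multiple inheritance composes with cooperative super().__init__() calls. All class names and attributes are hypothetical stand-ins, not pulsedjax classes: the method class contributes calculate_signal_t, the algorithm base contributes a (much simplified) step, and the final class combines both.

```python
class RetrievePulsesToy:
    """Stand-in for the common base class."""
    def __init__(self, nonlinear_method, **kwargs):
        super().__init__(**kwargs)
        self.nonlinear_method = nonlinear_method


class RetrievePulsesToyFROG(RetrievePulsesToy):
    """Stand-in for a method class: knows how to compute the signal."""
    def __init__(self, delay, measured_trace, nonlinear_method, **kwargs):
        super().__init__(nonlinear_method, **kwargs)
        self.delay = delay
        self.measured_trace = measured_trace

    def calculate_signal_t(self, individual):
        return [individual * d for d in self.delay]


class ToyAlgorithmBASE:
    """Stand-in for an algorithm base: relies on calculate_signal_t."""
    def __init__(self, *args, learning_rate=0.1, **kwargs):
        # positional arguments are forwarded to the method class via the MRO
        super().__init__(*args, **kwargs)
        self.learning_rate = learning_rate

    def step(self, individual):
        signal = self.calculate_signal_t(individual)  # provided by the method class
        return individual - self.learning_rate * sum(signal)


class ToyAlgorithm(ToyAlgorithmBASE, RetrievePulsesToyFROG):
    """Final class: algorithm base first, method class second."""
    pass


algo = ToyAlgorithm([1.0, 2.0], [0.0, 0.0], "shg", learning_rate=0.5)
print([c.__name__ for c in ToyAlgorithm.__mro__])
# ['ToyAlgorithm', 'ToyAlgorithmBASE', 'RetrievePulsesToyFROG', 'RetrievePulsesToy', 'object']
print(algo.step(1.0))  # 1.0 - 0.5 * (1.0 + 2.0) = -0.5
```

Listing the algorithm base first places it before the method class in the method resolution order, so its __init__ runs first and hands the remaining arguments on to the method class.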

Adding new Algorithms

As stated above, it should be relatively easy to add new algorithms. Usually, algorithms inherit from either ClassicalAlgorithmsBASE or GeneralAlgorithmsBASE and are defined in pulsedjax.core.base_classic_algorithms.py or pulsedjax.core.base_general_optimization.py.
Every algorithm class needs a method algorithm.step(self, descent_state, measurement_info, descent_info) and a method algorithm.initialize_run(self, population): step() performs one iteration of the algorithm, and initialize_run() prepares all provided data for the retrieval. The exceptions are PytchographicIterativeEngine and COPRA, which have local_step() and global_step() instead of step().
Below is a bare-bones implementation of an AlgorithmBASE class.

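from jax.tree_util import Partial
from pulsedjax.utilities import scan_helper


# replace Classic_or_GeneralAlgorithmsBASE with ClassicalAlgorithmsBASE or GeneralAlgorithmsBASE
class MyAlgorithmBASE(Classic_or_GeneralAlgorithmsBASE):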

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._name = "MyAlgorithm"



    def step(self, descent_state, measurement_info, descent_info):
        """
        Performs one iteration of the Generalized Projection Algorithm.
        
        Args:
            descent_state (Pytree):
            measurement_info (Pytree):
            descent_info (Pytree):

        Returns:
            tuple[Pytree, jnp.array], the updated descent state and the current trace errors of the population.
        """
        
        ... 

        return descent_state, trace_error.reshape(-1,1)
    


    def initialize_run(self, population):
        """
        Prepares all provided data and parameters for the reconstruction. 
        Here the final shape/structure of descent_state, measurement_info and descent_info are determined. 

        Args:
            population (Pytree): the initial guess as created by self.create_initial_population()
        
        Returns:
            tuple[Pytree, Callable], the initial descent state and the step-function of the algorithm.

        """

        # usually the final structure of measurement info is defined through RetrievePulsesMETHOD
        measurement_info = self.measurement_info

        # add the setting attributes into descent_info
        self.descent_info = self.descent_info.expand( ... )
        descent_info = self.descent_info

        # add the initial guess and other changing variables like prng keys into descent_state
        self.descent_state = self.descent_state.expand(population = population, ... )
        descent_state = self.descent_state

        do_step = Partial(self.step, measurement_info=measurement_info, descent_info=descent_info)
        do_step = Partial(scan_helper, actual_function=do_step, number_of_args=1, number_of_xs=0)
        return descent_state, do_step
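To make the skeleton concrete, the following self-contained sketch implements a toy "algorithm" (plain gradient descent on a population of independent individuals) that follows the same step()/initialize_run() contract. It deliberately avoids pulsedjax: NamedTuples stand in for the descent_state, measurement_info and descent_info containers, and the hypothetical adapter to_scan_step plays the role that scan_helper appears to play above, turning step(descent_state) into the (carry, x) -> (carry, y) form expected by jax.lax.scan. The toy trace error (matching a measured vector) is an assumption for illustration.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MeasurementInfo(NamedTuple):
    measured: jax.Array           # toy "measured trace"


class DescentInfo(NamedTuple):
    learning_rate: float          # static algorithm setting


class DescentState(NamedTuple):
    population: jax.Array         # shape (pop_size, dim)


def to_scan_step(step):
    """Hypothetical adapter: step(state) -> (state, err) into lax.scan form."""
    def scan_step(carry, _):
        return step(carry)
    return scan_step


class ToyGradientDescent:
    def __init__(self, measured):
        self.measurement_info = MeasurementInfo(measured=jnp.asarray(measured))
        self.descent_info = DescentInfo(learning_rate=0.1)

    @staticmethod
    def trace_error(individual, measurement_info):
        return jnp.mean((individual - measurement_info.measured) ** 2)

    def step(self, descent_state, measurement_info, descent_info):
        """One iteration for every individual; returns the state and errors of shape (pop_size, 1)."""
        grad_fn = jax.vmap(jax.grad(self.trace_error), in_axes=(0, None))
        grads = grad_fn(descent_state.population, measurement_info)
        population = descent_state.population - descent_info.learning_rate * grads
        errors = jax.vmap(self.trace_error, in_axes=(0, None))(population, measurement_info)
        return DescentState(population), errors.reshape(-1, 1)

    def initialize_run(self, population):
        descent_state = DescentState(population=population)
        do_step = partial(self.step, measurement_info=self.measurement_info,
                          descent_info=self.descent_info)
        return descent_state, to_scan_step(do_step)


algo = ToyGradientDescent(measured=[1.0, -2.0, 0.5])
population = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
descent_state, do_step = algo.initialize_run(population)
descent_state, errors = jax.lax.scan(do_step, descent_state, length=200)
print(errors.shape)  # (200, 4, 1)
```

Binding measurement_info and descent_info once in initialize_run keeps the returned step function a function of the descent state alone, which is what allows it to be driven by jax.lax.scan.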