{ "cells": [ { "cell_type": "markdown", "id": "22662945", "metadata": {}, "source": [ "# Advanced Usage \n", "\n", "As the title suggests, this section describes a series of use cases that are rarely required but available." ] },
{ "cell_type": "markdown", "id": "112b1f29", "metadata": {}, "source": [ "## Real Fields \n", "\n", "In ultrafast pulse retrieval it is generally sufficient (and convenient) to describe the pulses as complex fields. However, this comes with a drawback: unlike the spectrum of a real-valued field, the spectrum of a complex-valued field is not Hermitian. Thus, the pulses described this way possess no negative frequencies and cannot describe nonlinear processes involving difference-frequency generation (DFG). This includes explicit DFG traces as well as measurement techniques like [TREX](https://github.com/matillda123/Pulse-Retrieval-with-JAX/blob/main/examples/6_simulate_and_retrieve_TREX.py) {cite}`trexpaper`, which measure multiple nonlinear signals simultaneously. \n", "To solve this, the `pulsedjax.real_fields` module explicitly uses real-valued fields to calculate the nonlinear signals. Unfortunately, the presence of negative frequencies in real-valued signals requires a large frequency axis, which increases the computational cost. To avoid convergence issues, the pulses are only defined on a user-specified frequency range. However, the repeated interpolation between the different frequency axes consumes additional computational resources. \n", "\n", "Below is a usage example of an SHG-FROG with real fields and the `AutoDiff` solver. Compared to the complex-field version, the only differences are the user-specified inputs `f_range_pulse = (fmin, fmax)` and `f_range_fields = (fmin, fmax)`, and optionally `f_max_all_fields = fmax`. Here, `f_range_pulse` is the frequency interval in which the pulse (and gate pulse) is located, `f_range_fields` is the frequency interval in which the relevant nonlinear signals are located, and `f_max_all_fields` is the largest frequency considered when the nonlinear signals are calculated.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "a0bdf0db", "metadata": {}, "outputs": [], "source": [ "from pulsedjax.real_fields import frog\n", "import optax\n", "\n", "# delay, frequency and trace hold the measured data (not defined in this snippet)\n", "ad = frog.AutoDiff(delay, frequency, trace, \"shg\", f_range_pulse=(0.1,0.25), f_range_fields=(0.25,0.5), \n", "                   solver=optax.adam(learning_rate=0.1))\n", "\n", "population = ad.create_initial_population(5, \"continuous\", \"continuous\")\n", "final_result = ad.run(population, 50)" ] },
{ "cell_type": "markdown", "id": "0998588f", "metadata": {}, "source": [ "## Mixing Algorithms \n", "\n", "All algorithms are built around an `algorithm.step(descent_state, measurement_info, descent_info)` function, which always takes these three inputs. Thus, in principle, composite algorithms can be created by chaining the step functions of different algorithms. However, the structure and content of `descent_state` and `descent_info` differ between algorithms, so a composite algorithm needs to account for these algorithm-dependent containers. For `descent_info` this is not an issue, since it is static. For `descent_state`, the user has to specify how the population (or individuals of the population) is transferred between the different algorithms. \n", "Below is an example of an algorithm composed of `DifferentialEvolution` and `AutoDiff`. The goal is to refine the fittest individual with `AutoDiff` at each iteration of `DifferentialEvolution`."
] }, { "cell_type": "code", "execution_count": null, "id": "ca7cf14d", "metadata": {}, "outputs": [], "source": [ "from pulsedjax.frog import DifferentialEvolution, AutoDiff\n", "import optax\n", "\n", "from pulsedjax.utilities import scan_helper, run_scan\n", "from equinox import tree_at\n", "import jax.numpy as jnp\n", "import jax\n", "from jax.tree_util import Partial\n", "\n", "\n", "# instantiate and initialize DifferentialEvolution and AutoDiff\n", "# delay, frequency_trace and trace hold the measured data (not defined in this snippet)\n", "# the amplitude/phase types and numbers of basis functions must be identical for both populations,\n", "# otherwise transferring individuals between the two descent_states becomes \"difficult\"\n", "\n", "de = DifferentialEvolution(delay, frequency_trace, trace, \"shg\", strategy=\"best2_exp\")\n", "population_de = de.create_initial_population(10, \n", "                                             amp_type=\"bsplines_5\", phase_type=\"bsplines_5\", \n", "                                             no_funcs_amp=15, no_funcs_phase=15)\n", "descent_state_de, step_de = de.initialize_run(population_de)\n", "\n", "\n", "ad = AutoDiff(delay, frequency_trace, trace, \"shg\", solver=optax.lbfgs(learning_rate=0.1))\n", "population_ad = ad.create_initial_population(1, \n", "                                             amp_type=\"bsplines_5\", phase_type=\"bsplines_5\", \n", "                                             no_funcs_amp=15, no_funcs_phase=15)\n", "descent_state_ad, step_ad = ad.initialize_run(population_ad)\n", "\n", "\n", "\n", "# descent_state_ad is kept static in order to avoid cross-talk between optimization runs of different individuals,\n", "# so there is no need to pass it into _step_composite explicitly\n", "def _step_composite(descent_state_de):\n", "\n", "    # perform one DifferentialEvolution step\n", "    descent_state_de, error_de = step_de(descent_state_de, None)\n", "\n", "    # extract the fittest individual\n", "    idx = de.get_idx_best_individual(descent_state_de)\n", "    fittest_individual = de.get_individual_from_idx(idx, descent_state_de.population)\n", "\n", "    # insert the fittest individual into the AutoDiff descent_state\n", "    descent_state_ad_new = tree_at(lambda x: x.population, descent_state_ad, fittest_individual)\n", "\n", "    # run an AutoDiff optimization of that individual with 50 iterations\n", "    descent_state_ad_new, error_ad = jax.lax.scan(step_ad, descent_state_ad_new, length=50)\n", "\n", "    # return the refined individual to the DifferentialEvolution population\n", "    population_de = jax.tree.map(lambda x,y: x.at[idx].set(y[0]), descent_state_de.population, descent_state_ad_new.population)\n", "    descent_state_de = tree_at(lambda x: x.population, descent_state_de, population_de)\n", "    return descent_state_de, jnp.concatenate([error_de, error_ad[-1]])\n", "\n", "\n", "# convert _step_composite into a lax.scan compatible form\n", "step_composite = Partial(scan_helper, actual_function=_step_composite, number_of_args=1, number_of_xs=0)\n", "\n", "# run\n", "descent_state_de, error_arr = run_scan(step_composite, descent_state_de, 100)\n", "final_result = de.post_process(descent_state_de, error_arr)" ] },
{ "cell_type": "markdown", "id": "5bb0361a", "metadata": {}, "source": [ "## Differentiability and JAX/Equinox Transformations\n", "\n", "Even though the algorithms were not designed to be differentiable, they appear to be differentiable when they are part of a pure function. In principle, this can be used to optimize the parameters of an algorithm via automatic differentiation. An example of this can be found [here](https://github.com/matillda123/Pulse-Retrieval-with-JAX/blob/main/examples/9_optimize_stepsize.py). \n", "\n", "Furthermore, as long as the algorithms are part of pure functions, they should be compatible with all JAX/Equinox transformations." ] },
{ "cell_type": "markdown", "id": "5e98601d", "metadata": {}, "source": [ "## Adding new Methods \n", "\n", "One idea behind `pulsedjax` is modularity, which should make it relatively easy to add new methods as well as new algorithms. Each method is implemented as a class `RetrievePulsesMETHOD(RetrievePulses)` in `pulsedjax.core.base_classes_methods.py`. A method defines how the data is preprocessed and added to `measurement_info`. In some cases, method-specific post-processing is needed as well. Most importantly, a method class needs to define how the nonlinear signal is calculated, via the method `RetrievePulsesMETHOD.calculate_signal_t(self, individual, transform_arr, measurement_info)`.\n", "Below is a bare-bones implementation of a `RetrievePulsesMETHOD` class." ] },
{ "cell_type": "code", "execution_count": null, "id": "544c0743", "metadata": {}, "outputs": [], "source": [ "class RetrievePulsesMETHOD(RetrievePulses):\n", "\n", "    def __init__(self, theta, frequency, measured_trace, nonlinear_method, *args, **kwargs):\n", "        super().__init__(nonlinear_method, *args, **kwargs)\n", "\n", "        self.theta, self.time, self.frequency, self.measured_trace = self.get_data(theta, frequency, measured_trace)\n", "\n", "        self.measurement_info = self.measurement_info.expand(theta = self.theta,\n", "                                                             measured_trace = self.measured_trace)\n", "        \n", "        # in delay-based methods, theta and tau_arr are identical and equal to transform_arr\n", "        # for chirp scans, transform_arr is the phase_matrix, and tau_arr doesn't exist\n", "        self.measurement_info = self.measurement_info.expand(transform_arr = self.transform_arr)\n", "\n", "\n", "\n", "    def calculate_signal_t(self, individual, transform_arr, measurement_info):\n", "        signal_t = ...  # method-specific calculation of the nonlinear signal in the time domain\n", "\n", "        signal_f = self.fft(signal_t, measurement_info.sk, measurement_info.rn)\n", "        signal_t = MyNamespace(signal_t=signal_t, signal_f=signal_f)  # add further method-specific fields if needed\n", "        return signal_t\n", "    \n", "\n", "\n", "    def post_process_get_pulse_and_gate(self, descent_state, measurement_info, descent_info, idx=None):\n", "        ...\n", "        return pulse_t, gate_t, pulse_f, gate_f" ] },
{ "cell_type": "markdown", "id": "2b11091a", "metadata": {}, "source": [ "To create a new final algorithm, multiple inheritance is used via `Algorithm(AlgorithmBASE, RetrievePulsesMETHOD)`. Usually, a few additional method-specific methods still have to be defined in this final class. " ] },
{ "cell_type": "markdown", "id": "43c67e18", "metadata": {}, "source": [ "## Adding new Algorithms \n", "\n", "As stated above, it should be relatively easy to add new algorithms. Usually, all algorithms inherit from `ClassicalAlgorithmsBASE` or `GeneralAlgorithmsBASE` and are defined in either `pulsedjax.core.base_classic_algorithms.py` or `pulsedjax.core.base_general_optimization.py`. \n", "All algorithm classes need an `algorithm.step(self, descent_state, measurement_info, descent_info)` method as well as an `algorithm.initialize_run(self, population)` method, where `step()` performs one iteration of the algorithm and `initialize_run()` prepares all provided data for the retrieval. The exceptions are `PytchographicIterativeEngine` and `COPRA`, which have `local_step()` and `global_step()` instead of `step()`. \n", "Below is a bare-bones implementation of an `AlgorithmBASE` class." ] },
{ "cell_type": "code", "execution_count": null, "id": "30fdc3ac", "metadata": {}, "outputs": [], "source": [ "from jax.tree_util import Partial\n", "from pulsedjax.utilities import scan_helper\n", "\n", "\n", "class MyAlgorithmBASE(Classic_or_GeneralAlgorithmsBASE):\n", "\n", "    def __init__(self, *args, **kwargs):\n", "        super().__init__(*args, **kwargs)\n", "\n", "        self._name = \"MyAlgorithm\"\n", "\n", "\n", "\n", "    def step(self, descent_state, measurement_info, descent_info):\n", "        \"\"\"\n", "        Performs one iteration of MyAlgorithm.\n", "        \n", "        Args:\n", "            descent_state (Pytree):\n", "            measurement_info (Pytree):\n", "            descent_info (Pytree):\n", "\n", "        Returns:\n", "            tuple[Pytree, jnp.array], the updated descent state and the current trace errors of the population.\n", "        \"\"\"\n", "        \n", "        ... \n", "\n", "        return descent_state, trace_error.reshape(-1,1)\n", "    \n", "\n", "\n", "    def initialize_run(self, population):\n", "        \"\"\"\n", "        Prepares all provided data and parameters for the reconstruction. \n", "        \n", "        Here the final shape/structure of descent_state, measurement_info and descent_info is determined. \n", "\n", "        Args:\n", "            population (Pytree): the initial guess as created by self.create_initial_population()\n", "            \n", "        Returns:\n", "            tuple[Pytree, Callable], the initial descent state and the step function of the algorithm.\n", "\n", "        \"\"\"\n", "\n", "        # usually the final structure of measurement_info is defined through RetrievePulsesMETHOD\n", "        measurement_info = self.measurement_info\n", "\n", "        # add the setting attributes into descent_info\n", "        self.descent_info = self.descent_info.expand( ... )\n", "        descent_info = self.descent_info\n", "\n", "        # add the initial guess and other changing variables like PRNG keys into descent_state\n", "        self.descent_state = self.descent_state.expand(population = population)  # add further changing variables here\n", "        descent_state = self.descent_state\n", "\n", "        do_step = Partial(self.step, measurement_info=measurement_info, descent_info=descent_info)\n", "        do_step = Partial(scan_helper, actual_function=do_step, number_of_args=1, number_of_xs=0)\n", "        return descent_state, do_step" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }