Add ITG turbulence proxy and neural network objectives #2080

byoungj wants to merge 31 commits into PlasmaControl:master from
Conversation
Implement gx_B_reference, gx_L_reference, and gx_bmag following GX gyrokinetics conventions for ITG turbulence prediction.
- Implement gx_gds2, gx_gds21_over_shat, gx_gds22_over_shat_squared
- Follow GX conventions for ITG turbulence prediction

- Implement gx_gbdrift, gx_cvdrift, gx_gbdrift0_over_shat
- Simplify tests to focus on finiteness and physical constraints

- Implement gx_gradpar for flux tube arclength resampling

- Implement ITG proxy integrand and scalar proxy from Landreman et al.
- Use smooth sigmoid for differentiability in optimization
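For reference, the smooth integrand amounts to something like this minimal JAX sketch (function and argument names are illustrative, not the actual compute-function keys):

    import jax.numpy as jnp
    from jax.nn import sigmoid

    def itg_proxy_integrand(cvdrift, grad_x_norm, bmag):
        # Smooth sigmoid stands in for the Heaviside step Theta(cvdrift)
        # so the proxy stays differentiable for gradient-based optimization.
        return (sigmoid(cvdrift) + 0.2) * grad_x_norm**3 / bmag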
- Implement ITGProxy for optimizing ITG turbulence proxy
- Use field-aligned grid to evaluate proxy on flux surfaces

- Implement compute_arclength_via_gradpar utility for NNITGProxy
- Integrate 1/|gradpar| along field line using trapezoidal rule
- Support both single (1D) and multiple (2D) field lines
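A minimal sketch of that trapezoidal integration for a single field line (assuming a uniform grid spacing dz; names are hypothetical):

    import jax.numpy as jnp

    def arclength_via_gradpar(gradpar, dz):
        # dl/dz = 1/|gradpar|; accumulate trapezoid segments along the
        # field line, anchoring l = 0 at the first grid point.
        dl_dz = 1.0 / jnp.abs(gradpar)
        segments = 0.5 * (dl_dz[1:] + dl_dz[:-1]) * dz
        return jnp.concatenate([jnp.zeros(1), jnp.cumsum(segments)])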
- Implement resample_to_uniform_arclength utility for NNITGProxy
- Map non-uniform arclength to uniform z in [-pi, pi) using cubic interpolation
- Support single and multiple field lines via interpax

- Implement compute_flux_tube_length (wrapper for total length)
- Implement solve_poloidal_turns_for_length (Brent's method solver)
- Match GX conventions for target flux tube length

- Implement _circular_pad_1d for periodic boundary padding
- Implement _conv1d_circular for 1D convolution with circular padding
- Implement _max_pool_1d and _global_avg_pool_1d
- Match PyTorch CyclicInvariantNet architecture from Landreman et al. 2025
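The circular convolution primitive is roughly the following sketch (not the exact implementation; the asymmetric pad split mirrors PyTorch's padding='same' rule for even kernel sizes):

    import jax.numpy as jnp
    from jax import lax

    def conv1d_circular(x, kernel):
        # x: (batch, channels_in, width); kernel: (channels_out, channels_in, k)
        k = kernel.shape[-1]
        left, right = (k - 1) // 2, k // 2  # extra pad on the right for even k
        # Wrap-around padding makes the convolution periodic along the field line.
        x = jnp.pad(x, ((0, 0), (0, 0), (left, right)), mode="wrap")
        return lax.conv_general_dilated(
            x, kernel, window_strides=(1,), padding="VALID",
            dimension_numbers=("NCH", "OIH", "NCH"),
        )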
- Implement _cyclic_invariant_forward (5 conv blocks + FC layers)
- Implement _ensemble_forward (average predictions over ensemble)
- Match Landreman et al. 2025 ensemble models (no BatchNorm)
- Add test verifying JAX matches PyTorch output within 1e-4 tolerance

- Support both BatchNorm and non-BatchNorm CyclicInvariantNet models
- Auto-detect model type from weight keys
- Apply BatchNorm in inference mode using running statistics
- Extract helper functions for cleaner code

- Load PyTorch model weights and convert to JAX arrays for inference
- Support both single model and ensemble loading from DeepHyper results
- Use module-level caching to avoid reloading weights from disk

Implement NNITGProxy objective class for NN-based ITG turbulence prediction. The build() method solves for poloidal_turns to achieve the target flux tube length, using uniform theta_PEST parameterization with an explicit PEST->DESC coordinate mapping.
- Add __init__ with ensemble/single model support
- Add build() with length solver using map_coordinates
- Add _load_model_weights helper for ensemble/single loading
- Refactor duplicate GX coefficient tests into a parameterized test
- Remove compute_flux_tube_length (trivial wrapper, now unused)

- Compute 7 GX features along field lines using uniform theta_PEST
- Resample to uniform arclength spacing (96 points)
- Run CNN forward pass (single model or ensemble)
- Average Q over field lines (alphas)
- Add return_signals parameter for debugging/validation
- Add tests for compute and return_signals

- Replace double loop over (rho, alpha) with batched approach
- Single map_coordinates call for all nz x num_rho x num_alpha points
- Single compute_fun call for all GX features
- Batch alpha dimension for CNN inference (still loop over rho for resampling)
- Add PyTorch dependency check in __init__ with clear error message
- Skip NNITGProxy tests when torch unavailable

Store toroidal_turns instead of poloidal_turns so the flux tube length adapts as iota changes during optimization:
- build(): convert poloidal_turns to toroidal_turns = poloidal_turns / |iota|
- compute(): parameterize by zeta (fixed toroidal extent) instead of theta_pest
- theta_pest = alpha + iota * zeta now varies by rho
- Per-rho theta_pest_offset for correct arclength computation
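In pseudocode, the reparameterization looks roughly like this (hypothetical names):

    import jax.numpy as jnp

    def field_line_coords(alpha, iota, poloidal_turns, nz):
        # Fixing the toroidal extent lets the flux tube length track iota:
        # toroidal_turns is set once in build(), then zeta spans it symmetrically.
        toroidal_turns = poloidal_turns / jnp.abs(iota)
        zeta = jnp.linspace(-jnp.pi * toroidal_turns, jnp.pi * toroidal_turns, nz)
        return zeta, alpha + iota * zeta  # theta_pest varies with rho via iota(rho)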
Add optional exact flux tube length solving at compute time:
- New solve_length_at_compute parameter (default False)
- When True, re-solve for exact poloidal_turns per rho using root_scalar
- More expensive, but maintains the exact target length during optimization

Refactor to reduce code duplication with 4 helper methods:
- _map_pest_to_desc_coords: PEST to DESC coordinate transform
- _compute_gx_features_on_grid: grid creation and compute_fun
- _run_cnn_inference_for_rho: arclength, resample, and CNN
- _build_return_value: return value formatting

- Add return_per_alpha parameter to NNITGProxy.compute()
- Return Q with shape (num_rho, num_alpha) for per-field-line heat flux analysis
- Enable robustness analysis across field lines
- Fix asymmetric padding in _conv1d_circular
- Match PyTorch padding='same' behavior for even kernel sizes

- Remove 8 redundant CNN primitive tests (covered by JAX vs PyTorch test)
- Add BatchNorm coverage
- Merge return_per_alpha tests

Add three features to improve ensemble model handling in NNITGProxy:
- Verbose progress reporting prints progress every 10 models during ensemble loading (activated when verbose > 0)
- Auto pre_method detection finds the most common pre_method among top-k models from the DeepHyper results CSV
- Ensemble std output via new return_std parameter in compute() returns ensemble uncertainty (Q_std); uses the delta method to convert log-space std to Q-space
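The delta method here is a one-liner: for y = log(Q) with ensemble standard deviation sigma_y, std(Q) ≈ |dQ/dy| * sigma_y = Q * sigma_y. A sketch with hypothetical names:

    import jax.numpy as jnp

    def ensemble_q_stats(log_q_preds):
        # log_q_preds: (n_models, ...) per-model predictions of log(Q).
        mu = jnp.mean(log_q_preds, axis=0)
        sigma = jnp.std(log_q_preds, axis=0)
        Q = jnp.exp(mu)
        return Q, Q * sigma  # delta method: log-space std mapped to Q-space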
Add JIT-compiled closures for CNN inference to improve performance:
- _make_jit_forward() creates a JIT-compiled forward with weights captured
- Weight loading functions now return (weights, jit_forward) tuples
- _run_cnn_inference_for_rho uses the JIT functions when available
- Backward compatible: falls back to non-JIT if the functions are not present
- Expected 3-10x speedup after first compilation
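The closure pattern is roughly (hypothetical names):

    import jax

    def make_jit_forward(weights, forward_fn):
        # Capturing the weights in the closure lets each model compile once;
        # subsequent calls with same-shaped features reuse the cached executable.
        @jax.jit
        def jit_forward(features):
            return forward_fn(weights, features)

        return jit_forward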
- Add _SIGMA_BXY module constant for the GX sign convention
- Use gx_B_reference and gx_L_reference instead of recalculating
- Add Timer to ITGProxy.build() and NNITGProxy.build()
- Update description fields to plain English (formulas in label)
- Simplify module docstring

- Consolidate the 4-tuple (nn_weights, jit_forward, ensemble_weights, jit_forwards) into a single "models" list of JIT-compiled callables
- Remove dead non-JIT fallback code paths in _run_cnn_inference_for_rho
- Infer conv/FC layer counts from weight dict keys in _cyclic_invariant_forward instead of hardcoding range(5) and indices
- Redesign NNITGProxy tests with module-scoped fixtures and nz_internal=101 for speed (15 min -> 8 min on CPU)

- Move JIT-compiled forward functions from _constants to self._models
- Add _models to _static_attrs so JAX treats them as static aux_data
- Fixes TypeError when PjitFunction objects hit pytree flattening during optimization with use_jit=True

- Use csv.DictReader instead of pd.read_csv in both files
- pandas was not used anywhere else in the DESC codebase

- gds21/shat: use unsigned B_reference, matching gx_geometry
- cvdrift: remove extra psi_sign from pressure term
- gbdrift0/shat: add toroidal_flux_sign and document formula
- Apply sign(iota) correction in the batched compute path for anti-symmetric features (gbdrift0/shat, gds21/shat) when the zeta parameterization reverses the field line direction
- Update test assertions for _models attribute
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 1.05 % | 4.015e+03 | 4.057e+03 | 42.08 | 37.55 | 35.35 |
| test_proximal_jac_w7x_with_eq_update | -2.55 % | 6.575e+03 | 6.407e+03 | -167.96 | 159.05 | 158.44 |
| test_proximal_freeb_jac | 0.40 % | 1.317e+04 | 1.322e+04 | 52.09 | 80.60 | 82.12 |
| test_proximal_freeb_jac_blocked | 1.21 % | 7.482e+03 | 7.573e+03 | 90.89 | 70.88 | 70.77 |
| test_proximal_freeb_jac_batched | 0.91 % | 7.414e+03 | 7.482e+03 | 67.38 | 71.07 | 70.77 |
| test_proximal_jac_ripple | 1.13 % | 3.474e+03 | 3.513e+03 | 39.30 | 63.57 | 63.58 |
| test_proximal_jac_ripple_bounce1d | 1.24 % | 3.525e+03 | 3.569e+03 | 43.63 | 74.30 | 73.92 |
| test_eq_solve | 4.73 % | 1.967e+03 | 2.060e+03 | 93.04 | 91.85 | 91.65 |

For the memory plots, go to the summary of
- Black, isort, flake8 fixes across all 3 files
- Remove test_jax_vs_pytorch_forward_pass (moving away from PyTorch deps)
- Add @pytest.mark.unit marker
- Add CHANGELOG.md entry

- Use safediv for gx_gds22_over_shat_squared to handle rho=0
- Default rho=0.5 in ITGProxy and NNITGProxy for generic-test compatibility
YigitElma left a comment
Some questions before reviewing the rest
    # Check for torch dependency (required for loading .pth model weights)
    try:
        import torch  # noqa: F401
I couldn't check the whole code yet, but as I understand it, the user needs to supply the model weights as an extra file to be able to use this objective, right? If the objective requires extra information that DESC doesn't provide, I would vote to move it to desc.external. Currently, this extra dependency causes some generic tests to fail, and it is easier to prevent that for things in desc.external.
Alternatively, it looks like the model parameters are publicly available at https://zenodo.org/records/14867777. Is it feasible to include them in DESC in a way that JAX can import directly? The data .tar file is pretty large, but how much space do the model parameters take in total?
Totally! My plan is to remove this before merging, to avoid adding more dependencies within DESC.
        f_Q = mean((sigmoid(cvdrift) + 0.2) * |grad_x|^3 / B)
        """
        data["ITG proxy"] = jnp.mean(data["ITG proxy integrand"])
Is a simple mean the correct thing here? I.e., should it really be a flux surface average? And should this return a value per surface, per field line, or one for the whole volume, etc.?
I think in Matt's paper it's the mean over a field line, right? But in theory it could also be computed on a surface, which would eliminate the need to map to field-aligned coordinates. I think that's what's usually done for the other "grad(r) in bad curvature" type objectives.
    # Handle both 1D and 2D cases
    if dl_dtheta.ndim == 1:
        integrand_half = 0.5 * (dl_dtheta[1:] + dl_dtheta[:-1])
        arclength = dtheta * jnp.concatenate(
can use https://quadax.readthedocs.io/en/latest/_api/quadax.cumulative_trapezoid.html#quadax.cumulative_trapezoid for this, or cumulative simpson if you want a bit more accuracy
can also vmap to handle the multiple fieldline case
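For example, a sketch assuming uniform dtheta spacing (quadax mirrors the scipy.integrate API here):

    import jax
    from quadax import cumulative_trapezoid

    def arclength(dl_dtheta, dtheta):
        # Cumulative arclength along one field line, with l = 0 at the start.
        return cumulative_trapezoid(dl_dtheta, dx=dtheta, initial=0.0)

    # vmap over the leading axis handles the multiple-field-line (2D) case.
    arclength_batched = jax.vmap(arclength, in_axes=(0, None))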
    nfeatures = data.shape[0]
    data_uniform = jnp.stack(
        [
            interp1d(z_uniform, z_orig, data[i], method="cubic")
interp1d can handle extra trailing dimensions in the data, so you don't need this loop here
            for i in range(nfeatures)
        ]
    )
    else:
I think it should also be possible to use vmap here to get rid of these loops and conditionals
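For instance, since interpax's interp1d interpolates along the first axis of its f argument and broadcasts over trailing axes, the feature loop collapses to a single call (sketch):

    from interpax import interp1d

    def resample_features(z_uniform, z_orig, data):
        # data: (nfeatures, nz); put the interpolation axis first, then
        # transpose back, so all features are interpolated in one call.
        return interp1d(z_uniform, z_orig, data.T, method="cubic").T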
    target_length : float
        Target flux tube length in units of L_ref.
    x0_guess : float, optional
        Initial guess for poloidal turns. Default 1.0. The solver uses a
Wasn't most of Matt's original data using ~2 turns? Would that be a better default?
    Q : ndarray
        Predicted heat flux for each flux surface.
        Shape (num_rho,) if return_per_alpha=False (default).
        Shape (num_rho, num_alpha) if return_per_alpha=True.
This should always return a 1d array for it to play nice with other desc objectives and optimizers, so this may need to be flattened
        -jnp.pi * toroidal_turns, jnp.pi * toroidal_turns, nz
    )

    # Create meshgrid: shapes will be (nz, num_rho, num_alpha)
recommend just using jnp.meshgrid for this
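E.g. a one-call sketch:

    import jax.numpy as jnp

    def make_grids(zeta, rho, alpha):
        # Each returned array has shape (nz, num_rho, num_alpha).
        return jnp.meshgrid(zeta, rho, alpha, indexing="ij")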
    Q_std_all_per_alpha = [] if return_std else None
    all_signals = [] if return_signals else None

    for i_rho in range(num_rho):
    )
    theta_desc_flat = desc_coords[:, 1]
    # Fix theta wrapping for continuity
    theta_desc_flat = theta_desc_flat + 2 * jnp.pi * jnp.round(
I think you can just do theta_desc_flat = theta_desc_flat % (2*jnp.pi)?
        return_per_alpha,
        return_std=False,
    ):
        """Compute with exact flux tube length solving per rho.
I think there might be a more clever way of doing this.
We can probably get a reasonable upper bound on the number of turns for a given length based on e.g. a, R0, etc., or by doing this root finding once during build() and then taking 2x that or something. We would compute all the data along this longer field line, and then, when resampling to a uniform grid, only use up to the length we need, so no extra root finding is necessary every time the objective gets called. The only downside, for a fixed resolution, is slightly worse spacing between points, but assuming the bound is reasonably tight that shouldn't be too bad; plus, extra points are relatively cheap compared to iterative root finding.
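A rough sketch of that idea, reusing the resampling utilities (all names hypothetical, and the arclength is assumed monotone along the over-long field line):

    import jax.numpy as jnp
    from interpax import interp1d

    def resample_to_target_length(ell, data, target_length, nz):
        # ell: arclength along an intentionally over-long field line;
        # data: (nfeatures, npts) of GX features evaluated along it.
        # Interpolate only up to the length actually needed, so no
        # root finding is required inside compute().
        z_uniform = jnp.linspace(0.0, target_length, nz)
        return interp1d(z_uniform, ell, data.T, method="cubic").T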
Related to #1196
Adds two new objective functions for ITG turbulence optimization: an analytical proxy (ITGProxy) and a neural network ensemble proxy (NNITGProxy), based on Landreman et al. 2025 ("How does ITG turbulence depend on magnetic geometry?").

GX Geometric Features
- Reference quantities (gx_B_reference, gx_L_reference), magnetic field strength (gx_bmag), gradient features (gx_gds2, gx_gds21_over_shat, gx_gds22_over_shat_squared), drift features (gx_gbdrift, gx_cvdrift, gx_gbdrift0_over_shat), and parallel gradient (gx_gradpar)

ITGProxy Objective
- f_Q = mean([Theta(cvdrift) + 0.2] * |grad(x)|^3 / B)

Flux Tube Geometry Utilities
- Arclength computation via gradpar with uniform resampling to z in [-pi, pi)
- Length solver for poloidal_turns; stores fixed toroidal turns so the flux tube length adapts as iota changes during optimization

NNITGProxy Objective
- Loads PyTorch .pth weights and converts to JAX arrays (TODO: will have to make this not Torch dependent)
- top_k model selection from DeepHyper results
- Single map_coordinates + compute call for all (rho, alpha, zeta) points
- return_per_alpha for per-field-line analysis, return_std for ensemble uncertainty, return_signals for intermediate GX feature debugging

Known Limitations/TODO
- Loading model weights currently requires torch to deserialize .pth files. Plan is to serialize weights to a torch-free format (e.g. .npz) in a separate conversion repo, so DESC only needs JAX at runtime.