Add ITG turbulence proxy and neural network objectives #2080

Open

byoungj wants to merge 31 commits into PlasmaControl:master from byoungj:cj/itg-proxy

Conversation

@byoungj (Collaborator) commented Feb 6, 2026

Related to #1196

Adds two new objective functions for ITG turbulence optimization: an analytical proxy (ITGProxy) and a neural network ensemble proxy (NNITGProxy), based on Landreman et al. 2025 ("How does ITG turbulence depend on magnetic geometry?").

GX Geometric Features

  • Reference quantities (gx_B_reference, gx_L_reference), magnetic field strength (gx_bmag), gradient features (gx_gds2, gx_gds21_over_shat, gx_gds22_over_shat_squared), drift features (gx_gbdrift, gx_cvdrift, gx_gbdrift0_over_shat), and parallel gradient (gx_gradpar)
  • Sign convention corrections for toroidal flux direction and pressure gradient terms

ITGProxy Objective

  • Analytical proxy: f_Q = mean([Theta(cvdrift) + 0.2] * |grad(x)|^3 / B), where Theta is the Heaviside step
  • Uses a smooth sigmoid approximation of Theta for differentiability during optimization (see the sketch below)
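For illustration, a minimal sketch of the proxy with the Heaviside step replaced by a sigmoid (the helper names and the sharpness constant are assumptions, not the PR's exact implementation):

```python
import jax.numpy as jnp


def itg_proxy_integrand(cvdrift, grad_x_norm, B, sharpness=10.0):
    # smooth stand-in for Theta(cvdrift): ~1 in bad curvature, ~0 otherwise
    step = 1.0 / (1.0 + jnp.exp(-sharpness * cvdrift))
    return (step + 0.2) * grad_x_norm**3 / B


def itg_proxy(cvdrift, grad_x_norm, B):
    # f_Q = mean of the integrand over the field-aligned grid
    return jnp.mean(itg_proxy_integrand(cvdrift, grad_x_norm, B))
```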

Flux Tube Geometry Utilities

  • Arclength computation via gradpar with uniform resampling to z in [-pi, pi)
  • Brent's method solver for poloidal turns matching a target flux tube length (I tried making it differentiable via desc.backend.root_scalar within .compute())
  • Dynamic poloidal_turns (via fixed toroidal turns) that adapt as iota changes during optimization; a sketch of these utilities follows this list
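For illustration, a minimal sketch of the two utilities for a single field line (shapes and the root bracket are assumptions; the PR's actual functions are compute_arclength_via_gradpar and solve_poloidal_turns_for_length):

```python
import jax.numpy as jnp
from scipy.optimize import brentq  # Brent's method (non-differentiable path)


def arclength_via_gradpar(theta, gradpar):
    # cumulative arclength l(theta) = integral of dtheta / |gradpar|,
    # trapezoidal rule, with l(theta[0]) = 0
    dl = 1.0 / jnp.abs(gradpar)
    seg = 0.5 * (dl[1:] + dl[:-1]) * jnp.diff(theta)
    return jnp.concatenate([jnp.zeros(1), jnp.cumsum(seg)])


def solve_poloidal_turns(length_fn, target_length, lo=0.1, hi=10.0):
    # length_fn(turns) -> flux tube length in units of L_ref (assumed);
    # the bracket [lo, hi] must contain the root
    return brentq(lambda t: length_fn(t) - target_length, lo, hi)
```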

NNITGProxy Objective

  • Pure-JAX re-implementation of the CyclicInvariantNet CNN architecture (circular convolutions, BatchNorm inference, max/global-avg pooling); a sketch of the circular convolution follows this list
  • Loads pre-trained PyTorch .pth weights and converts them to JAX arrays (TODO: make this not Torch-dependent)
  • Ensemble inference with configurable top_k model selection from DeepHyper results
  • JIT-compiled forward pass
  • Batched grid processing: single map_coordinates + compute call for all (rho, alpha, zeta) points
  • Optional features: return_per_alpha for per-field-line analysis, return_std for ensemble uncertainty, return_signals for intermediate GX feature debugging
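For illustration, a minimal sketch of a circular-padding 1D convolution in JAX (the shape conventions are assumptions; the PR's helpers are _circular_pad_1d and _conv1d_circular):

```python
import jax.numpy as jnp
from jax import lax


def circular_pad_1d(x, pad_left, pad_right):
    # periodic padding along the last axis; x: (channels, nz)
    return jnp.pad(x, ((0, 0), (pad_left, pad_right)), mode="wrap")


def conv1d_circular(x, kernel):
    # x: (C_in, nz), kernel: (C_out, C_in, k). For even k, PyTorch
    # padding="same" pads asymmetrically ((k - 1) // 2 left, k // 2 right),
    # which is the behavior a later commit in this PR matches.
    k = kernel.shape[-1]
    xp = circular_pad_1d(x, (k - 1) // 2, k // 2)
    out = lax.conv_general_dilated(
        xp[None],  # add batch dim: (1, C_in, nz + k - 1)
        kernel,
        window_strides=(1,),
        padding="VALID",
    )  # -> (1, C_out, nz)
    return out[0]
```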

Known Limitations/TODO

  • The quantities calculated here are the same as those in _stability up to normalization / sign factors, so they can be consolidated; for debugging purposes I have kept separate compute() quantities for now.
  • PyTorch dependency for weight loading: currently requires torch to deserialize .pth files. The plan is to serialize weights to a torch-free format (e.g. .npz) in a separate conversion repo, so DESC only needs JAX at runtime (see the sketch below).
  • TEM metrics (Proll 2015, Mackenbach 2022/2023) are not included in this PR; they depend on bounce integral infrastructure and will be addressed separately under Turbulence metrics for ITG/TEM #1196.
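For illustration, a sketch of the planned conversion, assuming the .pth files hold a flat state_dict of tensors:

```python
# one-time conversion (outside DESC, in the separate conversion repo):
import numpy as np
import torch

state = torch.load("model.pth", map_location="cpu")
np.savez("model.npz", **{k: v.numpy() for k, v in state.items()})

# at runtime DESC would then only need numpy/JAX:
import jax.numpy as jnp

weights = {k: jnp.asarray(v) for k, v in np.load("model.npz").items()}
```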

Chris J and others added 29 commits February 5, 2026 03:15
Implement gx_B_reference, gx_L_reference, and gx_bmag following
GX gyrokinetics conventions for ITG turbulence prediction.
- Implement gx_gds2, gx_gds21_over_shat, gx_gds22_over_shat_squared
- Follow GX conventions for ITG turbulence prediction
- Implement gx_gbdrift, gx_cvdrift, gx_gbdrift0_over_shat
- Simplify tests to focus on finiteness and physical constraints
- Implement gx_gradpar for flux tube arclength resampling
- Implement ITG proxy integrand and scalar proxy from Landreman et al.
- Use smooth sigmoid for differentiability in optimization
- Implement ITGProxy for optimizing ITG turbulence proxy
- Use field-aligned grid to evaluate proxy on flux surfaces
- Implement compute_arclength_via_gradpar utility for NNITGProxy
- Integrate 1/|gradpar| along field line using trapezoidal rule
- Support both single (1D) and multiple (2D) field lines
- Implement resample_to_uniform_arclength utility for NNITGProxy
- Map non-uniform arclength to uniform z in [-pi, pi) using cubic interpolation
- Support single and multiple field lines via interpax
- Implement compute_flux_tube_length (wrapper for total length)
- Implement solve_poloidal_turns_for_length (Brent's method solver)
- Match GX conventions for target flux tube length
- Implement _circular_pad_1d for periodic boundary padding
- Implement _conv1d_circular for 1D convolution with circular padding
- Implement _max_pool_1d and _global_avg_pool_1d
- Match PyTorch CyclicInvariantNet architecture from Landreman et al. 2025
- Implement _cyclic_invariant_forward (5 conv blocks + FC layers)
- Implement _ensemble_forward (average predictions over ensemble)
- Match Landreman et al. 2025 ensemble models (no BatchNorm)
- Add test verifying JAX matches PyTorch output within 1e-4 tolerance
- Support both BatchNorm and non-BatchNorm CyclicInvariantNet models
- Auto-detect model type from weight keys
- Apply BatchNorm in inference mode using running statistics
- Extract helper functions for cleaner code
- Load PyTorch model weights and convert to JAX arrays for inference
- Support both single model and ensemble loading from DeepHyper results
- Use module-level caching to avoid reloading weights from disk
Implement NNITGProxy objective class for NN-based ITG turbulence
prediction. The build() method solves for poloidal_turns to achieve
target flux tube length using uniform theta_PEST parameterization with
explicit PEST->DESC coordinate mapping.

- Add __init__ with ensemble/single model support
- Add build() with length solver using map_coordinates
- Add _load_model_weights helper for ensemble/single loading
- Refactor duplicate GX coefficient tests into parameterized test
- Remove compute_flux_tube_length (trivial wrapper, now unused)
- Compute 7 GX features along field lines using uniform theta_PEST
- Resample to uniform arclength spacing (96 points)
- Run CNN forward pass (single model or ensemble)
- Average Q over field lines (alphas)
- Add return_signals parameter for debugging/validation
- Add tests for compute and return_signals
- Replace double loop over (rho, alpha) with batched approach
- Single map_coordinates call for all nz x num_rho x num_alpha points
- Single compute_fun call for all GX features
- Batch alpha dimension for CNN inference (still loop over rho for resampling)
- Add PyTorch dependency check in __init__ with clear error message
- Skip NNITGProxy tests when torch unavailable
Store toroidal_turns instead of poloidal_turns so flux tube length
adapts as iota changes during optimization:
- build(): convert poloidal_turns to toroidal_turns = poloidal_turns / |iota|
- compute(): parameterize by zeta (fixed toroidal extent) instead of theta_pest
- theta_pest = alpha + iota * zeta now varies by rho
- Per-rho theta_pest_offset for correct arclength computation
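For illustration, a sketch of this reparameterization; the variable names and example values are assumptions:

```python
import jax.numpy as jnp

poloidal_turns, iota_at_build, iota, alpha, nz = 2.0, 0.42, 0.45, 0.0, 96

# build(): fix the toroidal extent from the requested poloidal turns
toroidal_turns = poloidal_turns / jnp.abs(jnp.asarray(iota_at_build))

# compute(): zeta spans a fixed toroidal range, while theta_pest follows
# the current iota, so the flux tube length adapts as iota changes
zeta = jnp.linspace(-jnp.pi * toroidal_turns, jnp.pi * toroidal_turns, nz)
theta_pest = alpha + iota * zeta  # field-line label alpha, per-rho iota
```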
Add optional exact flux tube length solving at compute time:
- New solve_length_at_compute parameter (default False)
- When True, re-solve for exact poloidal_turns per rho using root_scalar
- More expensive but maintains exact target length during optimization

Refactor to reduce code duplication with 4 helper methods:
- _map_pest_to_desc_coords: PEST to DESC coordinate transform
- _compute_gx_features_on_grid: Grid creation and compute_fun
- _run_cnn_inference_for_rho: Arclength, resample, and CNN
- _build_return_value: Return value formatting
- Add return_per_alpha parameter to NNITGProxy.compute()
- Return Q shape (num_rho, num_alpha) for per-field-line heat flux analysis
- Enable robustness analysis across field lines
- Fix asymmetric padding in _conv1d_circular
- Match PyTorch padding='same' behavior for even kernel sizes
- Remove 8 redundant CNN primitive tests (covered by JAX vs PyTorch test)
- Add BatchNorm coverage
- Merge return_per_alpha tests
Add three features to improve ensemble model handling in NNITGProxy:
- Verbose progress reporting prints progress every 10 models during
  ensemble loading (activated when verbose > 0).
- Auto pre_method detection finds the most common pre_method among
  top-k models from the DeepHyper results CSV.
- Ensemble std output via new return_std parameter in compute() returns
  ensemble uncertainty (Q_std). Uses delta method to convert log-space
  std to Q-space.
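For illustration, a sketch of the delta-method conversion, assuming the models predict log Q: for Q = exp(log Q), std(Q) ≈ Q_mean * std(log Q) to first order.

```python
import jax.numpy as jnp


def ensemble_q_stats(log_q_preds):
    # log_q_preds: (n_models, ...) ensemble predictions in log-space
    log_mean = jnp.mean(log_q_preds, axis=0)
    log_std = jnp.std(log_q_preds, axis=0)
    q_mean = jnp.exp(log_mean)
    return q_mean, q_mean * log_std  # delta method: std(Q) ≈ Q * std(logQ)
```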
Add JIT-compiled closures for CNN inference to improve performance.
- _make_jit_forward() creates JIT-compiled forward with weights captured
- Weight loading functions now return (weights, jit_forward) tuples
- _run_cnn_inference_for_rho uses JIT functions when available
- Backward compatible: falls back to non-JIT if functions not present
- Expected 3-10x speedup after first compilation
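For illustration, a sketch of this closure pattern, with forward_fn(weights, x) assumed:

```python
import jax


def make_jit_forward(weights, forward_fn):
    # capture the weights as closure constants and JIT once per model
    @jax.jit
    def jit_forward(x):
        return forward_fn(weights, x)

    return jit_forward


# models = [make_jit_forward(w, cyclic_invariant_forward) for w in ensemble]
```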
- Add _SIGMA_BXY module constant for GX sign convention
- Use gx_B_reference and gx_L_reference instead of recalculating
- Add Timer to ITGProxy.build() and NNITGProxy.build()
- Update description fields to plain English (formulas in label)
- Simplify module docstring
- Consolidate 4-tuple (nn_weights, jit_forward, ensemble_weights,
  jit_forwards) into single "models" list of JIT-compiled callables
- Remove dead non-JIT fallback code paths in _run_cnn_inference_for_rho
- Infer conv/FC layer counts from weight dict keys in
  _cyclic_invariant_forward instead of hardcoding range(5) and indices
- Redesign NNITGProxy tests with module-scoped fixtures and
  nz_internal=101 for speed (15 min -> 8 min on CPU)
- Move JIT-compiled forward functions from _constants to self._models
- Add _models to _static_attrs so JAX treats them as static aux_data
- Fixes TypeError when PjitFunction objects hit pytree flattening
  during optimization with use_jit=True
- Use csv.DictReader instead of pd.read_csv in both files
- pandas was not used anywhere else in the DESC codebase
- gds21/shat: use unsigned B_reference matching gx_geometry
- cvdrift: remove extra psi_sign from pressure term
- gbdrift0/shat: add toroidal_flux_sign and document formula
- Apply sign(iota) correction in batched compute path for
  anti-symmetric features (gbdrift0/shat, gds21/shat) when
  zeta parameterization reverses field line direction
- Update test assertions for _models attribute
@byoungj byoungj requested a review from rahulgaur104 February 6, 2026 22:17
@github-actions bot (Contributor) commented Feb 6, 2026

Memory benchmark result

| Test Name                             |   %Δ    | Master (MB) | PR (MB)   | Δ (MB)  | Time PR (s) | Time Master (s) |
| ------------------------------------- | ------- | ----------- | --------- | ------- | ----------- | --------------- |
| test_objective_jac_w7x                |  1.05 % | 4.015e+03   | 4.057e+03 |   42.08 |       37.55 |           35.35 |
| test_proximal_jac_w7x_with_eq_update  | -2.55 % | 6.575e+03   | 6.407e+03 | -167.96 |      159.05 |          158.44 |
| test_proximal_freeb_jac               |  0.40 % | 1.317e+04   | 1.322e+04 |   52.09 |       80.60 |           82.12 |
| test_proximal_freeb_jac_blocked       |  1.21 % | 7.482e+03   | 7.573e+03 |   90.89 |       70.88 |           70.77 |
| test_proximal_freeb_jac_batched       |  0.91 % | 7.414e+03   | 7.482e+03 |   67.38 |       71.07 |           70.77 |
| test_proximal_jac_ripple              |  1.13 % | 3.474e+03   | 3.513e+03 |   39.30 |       63.57 |           63.58 |
| test_proximal_jac_ripple_bounce1d     |  1.24 % | 3.525e+03   | 3.569e+03 |   43.63 |       74.30 |           73.92 |
| test_eq_solve                         |  4.73 % | 1.967e+03   | 2.060e+03 |   93.04 |       91.85 |           91.65 |

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.

Chris J added 2 commits February 6, 2026 14:39
- Black, isort, flake8 fixes across all 3 files
- Remove test_jax_vs_pytorch_forward_pass (moving away from PyTorch deps)
- Add @pytest.mark.unit marker
- Add CHANGELOG.md entry
- Use safediv for gx_gds22_over_shat_squared to handle rho=0
- Default rho=0.5 in ITGProxy and NNITGProxy for generic test compatibility
@YigitElma (Collaborator) left a comment

Some questions before reviewing the rest


# Check for torch dependency (required for loading .pth model weights)
try:
    import torch  # noqa: F401
Collaborator:

I couldn't check the whole code yet, but as I understand it, the user needs to provide the model weights as an extra file to be able to use this objective, right? If the objective requires this extra information, which DESC doesn't ship, I would vote to move this objective to desc.external. Currently, this extra dependency causes some generic tests to fail, but it is easier to prevent that for things in desc.external.

Collaborator:
Alternatively, it looks like the model parameters are publicly available at https://zenodo.org/records/14867777. Is it feasible to have them in DESC in a way that JAX can import directly? The data .tar file is pretty large, but how much do the model parameters take up in total?

Collaborator (author):
Totally! My plan is to remove this before merging to avoid adding more dependencies within DESC.


f_Q = mean((sigmoid(cvdrift) + 0.2) * |grad_x|^3 / B)
"""
data["ITG proxy"] = jnp.mean(data["ITG proxy integrand"])
Member:
Is a simple mean the correct thing here? I.e., should it really be a flux surface average? And should this return a value per surface, per field line, or one for the whole volume, etc.?

Member:
I think in Matt's paper it's the mean over a field line, right? But in theory it could also be computed on a surface, which would eliminate the need to map to field-aligned coordinates. I think that's what's usually done for the other "grad(r) in bad curvature" type objectives.
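For reference, the flux surface average under discussion weights by the Jacobian rather than taking a plain mean:

```latex
\langle f \rangle
  = \frac{\int_0^{2\pi}\!\int_0^{2\pi} f \,\sqrt{g}\; d\theta\, d\zeta}
         {\int_0^{2\pi}\!\int_0^{2\pi} \sqrt{g}\; d\theta\, d\zeta}
```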

# Handle both 1D and 2D cases
if dl_dtheta.ndim == 1:
    integrand_half = 0.5 * (dl_dtheta[1:] + dl_dtheta[:-1])
    arclength = dtheta * jnp.concatenate(
Member:
can use https://quadax.readthedocs.io/en/latest/_api/quadax.cumulative_trapezoid.html#quadax.cumulative_trapezoid for this, or cumulative simpson if you want a bit more accuracy

Member:
can also vmap to handle the multiple fieldline case
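For illustration, a sketch combining both suggestions; the shapes are assumptions:

```python
import jax
from quadax import cumulative_trapezoid


def arclength(theta, dl_dtheta):
    # cumulative trapezoid of dl/dtheta over theta, with l(theta[0]) = 0
    return cumulative_trapezoid(dl_dtheta, x=theta, initial=0.0)


# multiple field lines: vmap over the leading axis, assuming
# dl_dtheta has shape (num_alpha, nz) and theta has shape (nz,)
arclength_2d = jax.vmap(arclength, in_axes=(None, 0))
```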

nfeatures = data.shape[0]
data_uniform = jnp.stack(
    [
        interp1d(z_uniform, z_orig, data[i], method="cubic")
Member:
interp1d can handle extra trailing dimensions in the data, so you don't need this loop here

        for i in range(nfeatures)
    ]
)
else:
Member:
I think it should also be possible to use vmap here to get rid of these loops and conditionals
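For illustration, a sketch using both suggestions; the shapes and dummy arrays are assumptions, and z_uniform / z_orig are as in the snippet above:

```python
import jax
import jax.numpy as jnp
from interpax import interp1d

nz, nfeatures, num_alpha = 96, 7, 5
z_orig = jnp.linspace(-jnp.pi, jnp.pi, nz)   # non-uniform in general
z_uniform = jnp.linspace(-jnp.pi, jnp.pi, nz)
data = jnp.ones((nfeatures, nz))             # single field line

# interp1d interpolates along the first axis of f and carries trailing
# axes through, so transposing to (nz, nfeatures) replaces the loop:
data_uniform = interp1d(z_uniform, z_orig, data.T, method="cubic").T

# the 1D/2D branch can then become a vmap over field lines:
data_2d = jnp.ones((num_alpha, nfeatures, nz))
z_orig_2d = jnp.tile(z_orig, (num_alpha, 1))
resample = lambda d, z: interp1d(z_uniform, z, d.T, method="cubic").T
data_uniform_2d = jax.vmap(resample)(data_2d, z_orig_2d)
```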

target_length : float
    Target flux tube length in units of L_ref.
x0_guess : float, optional
    Initial guess for poloidal turns. Default 1.0. The solver uses a
Member:
wasn't most of Matt's original data using ~2 turns? would that be a better default?

Q : ndarray
    Predicted heat flux for each flux surface.
    Shape (num_rho,) if return_per_alpha=False (default).
    Shape (num_rho, num_alpha) if return_per_alpha=True.
Member:
This should always return a 1d array for it to play nice with other desc objectives and optimizers, so this may need to be flattened

-jnp.pi * toroidal_turns, jnp.pi * toroidal_turns, nz
)

# Create meshgrid: shapes will be (nz, num_rho, num_alpha)
Member:
recommend just using jnp.meshgrid for this

Q_std_all_per_alpha = [] if return_std else None
all_signals = [] if return_signals else None

for i_rho in range(num_rho):
Member:
can this be vmapped?

)
theta_desc_flat = desc_coords[:, 1]
# Fix theta wrapping for continuity
theta_desc_flat = theta_desc_flat + 2 * jnp.pi * jnp.round(
Member:
I think you can just do theta_desc_flat = theta_desc_flat % (2*jnp.pi)?

return_per_alpha,
return_std=False,
):
"""Compute with exact flux tube length solving per rho.
Member:
I think there might be a more clever way of doing this.

We can probably get a reasonable upper bound on the number of turns for a given length just based on, e.g., a, R0, etc., or by doing this root finding once during build and then taking 2x that or something. We compute all the data along this longer fieldline, then when we resample to a uniform grid we only use up to the length we need, so no extra root finding is necessary every time the objective gets called. The only downside would be, for a fixed resolution, slightly worse spacing between points, but assuming the bound is reasonably tight that shouldn't be too bad, plus extra points are relatively cheap compared to iterative root finding.
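For illustration, a sketch of this scheme; the function name and the 2x safety factor are assumptions:

```python
import jax.numpy as jnp
from interpax import interp1d


def resample_to_target_length(features, l_cumulative, target_length, nz=96):
    # features: (nfeatures, nz_long) computed along the padded field line;
    # l_cumulative: (nz_long,) cumulative arclength along that line.
    # Keep only the [0, target_length] window by resampling onto a uniform
    # arclength grid that stops at target_length; surplus points are dropped.
    l_uniform = jnp.linspace(0.0, target_length, nz)
    return interp1d(l_uniform, l_cumulative, features.T, method="cubic").T


# build() would root-find once (e.g. via solve_poloidal_turns_for_length)
# and pad the result, say turns_max = 2.0 * solved_turns, so the padded
# field line is always at least target_length long at compute time.
```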
