diff --git a/adept/_vlasov1d/datamodel.py b/adept/_vlasov1d/datamodel.py index b735c49..f60005d 100644 --- a/adept/_vlasov1d/datamodel.py +++ b/adept/_vlasov1d/datamodel.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class SpaceProfileModel(BaseModel): @@ -23,8 +23,12 @@ class SpeciesBackgroundModel(BaseModel): class DensityModel(BaseModel): + model_config = ConfigDict(extra="allow") + quasineutrality: bool - species_background: SpeciesBackgroundModel + # Note: Allow arbitrary species-* fields to be defined dynamically + # e.g., species_background, species_beam, etc. + # These are validated as SpeciesBackgroundModel when accessed class UnitsModel(BaseModel): @@ -37,11 +41,11 @@ class UnitsModel(BaseModel): class GridModel(BaseModel): dt: float - nv: int + nv: int | None = None # Optional: for backward compatibility with single-species config files nx: int tmin: float tmax: float - vmax: float + vmax: float | None = None # Optional: for backward compatibility with single-species config files xmax: float xmin: float @@ -53,8 +57,11 @@ class TimeSaveModel(BaseModel): class SaveModel(BaseModel): + model_config = ConfigDict(extra="allow") + fields: dict[str, TimeSaveModel] - electron: dict[str, TimeSaveModel] + # Note: Allow arbitrary species save sections (electron, ion, etc.) + # These are validated as dict[str, TimeSaveModel] when accessed class ExDriverModel(BaseModel): @@ -112,10 +119,35 @@ class KrookModel(BaseModel): space: SpaceTermModel +class SpeciesConfig(BaseModel): + """Configuration for a physical species in multi-species simulations. + + Each species has its own charge-to-mass ratio and velocity grid, allowing + for electron-ion physics and other multi-species interactions. + + Attributes: + name: Species name (e.g., 'electron', 'ion') + charge: Charge in units of fundamental charge (e.g., -1.0 for electrons, 10.0 for Z=10 ions) + mass: Mass in units of electron mass (e.g., 1.0 for electrons, 1836.0 for protons) + vmax: Velocity grid maximum for this species + nv: Number of velocity grid points for this species + density_components: List of density component names from the 'density:' section + that contribute to this species' distribution function + """ + + name: str + charge: float + mass: float + vmax: float + nv: int + density_components: list[str] + + class TermsModel(BaseModel): field: str edfdv: str time: str + species: list[SpeciesConfig] | None = None fokker_planck: FokkerPlanckModel krook: KrookModel diff --git a/adept/_vlasov1d/helpers.py b/adept/_vlasov1d/helpers.py index afb9c44..f6ec574 100644 --- a/adept/_vlasov1d/helpers.py +++ b/adept/_vlasov1d/helpers.py @@ -1,6 +1,7 @@ # Copyright (c) Ergodic LLC 2023 # research@ergodic.io import os +import warnings from time import time import numpy as np @@ -92,101 +93,147 @@ def _initialize_distribution_( def _initialize_total_distribution_(cfg, cfg_grid): - params = cfg["density"] - n_prof_total = np.zeros([cfg_grid["nx"]]) - f = np.zeros([cfg_grid["nx"], cfg_grid["nv"]]) - species_found = False - for name, species_params in cfg["density"].items(): - if name.startswith("species-"): - v0 = species_params["v0"] - T0 = species_params["T0"] - m = species_params["m"] - if name in params: - if "v0" in params[name]: - v0 = params[name]["v0"] + """ + Initialize distribution functions for all species. - if "T0" in params[name]: - T0 = params[name]["T0"] + The species config is normalized in modules.py:get_derived_quantities() so that + a species config always exists (for backward compatibility with single-species + config files, a default electron species is generated). - if "m" in params[name]: - m = params[name]["m"] + Returns: + dict mapping species_name -> (n_prof, f_s, v_ax) + """ + species_configs = cfg["terms"]["species"] + species_distributions = {} + species_found = False - if species_params["basis"] == "uniform": - nprof = np.ones_like(n_prof_total) + for species_cfg in species_configs: + species_name = species_cfg["name"] + density_components = species_cfg["density_components"] + vmax = species_cfg["vmax"] + nv = species_cfg["nv"] - elif species_params["basis"] == "linear": - left = species_params["center"] - species_params["width"] * 0.5 - right = species_params["center"] + species_params["width"] * 0.5 - rise = species_params["rise"] - mask = get_envelope(rise, rise, left, right, cfg_grid["x"]) + # Initialize arrays for this species + n_prof_species = np.zeros([cfg_grid["nx"]]) + dv = 2.0 * vmax / nv + vax = np.linspace(-vmax + dv / 2.0, vmax - dv / 2.0, nv) + f_species = np.zeros([cfg_grid["nx"], nv]) - ureg = pint.UnitRegistry() - _Q = ureg.Quantity + # Sum contributions from all density components + for component_name in density_components: + if component_name not in cfg["density"]: + raise ValueError(f"Density component '{component_name}' not found in config") - L = ( - _Q(species_params["gradient scale length"]).to("nm").magnitude - / cfg["units"]["derived"]["x0"].to("nm").magnitude - ) - nprof = species_params["val at center"] + (cfg_grid["x"] - species_params["center"]) / L - nprof = mask * nprof - elif species_params["basis"] == "exponential": - left = species_params["center"] - species_params["width"] * 0.5 - right = species_params["center"] + species_params["width"] * 0.5 - rise = species_params["rise"] - mask = get_envelope(rise, rise, left, right, cfg_grid["x"]) - - ureg = pint.UnitRegistry() - _Q = ureg.Quantity - - L = ( - _Q(species_params["gradient scale length"]).to("nm").magnitude - / cfg["units"]["derived"]["x0"].to("nm").magnitude + species_params = cfg["density"][component_name] + v0 = species_params["v0"] + T0 = species_params["T0"] + + # Handle mass parameter with deprecation + if "m" in species_params: + warnings.warn( + f"Density component '{component_name}': The 'm' parameter in density config is deprecated. " + f"Please use the mass from the species config instead.", + DeprecationWarning, + stacklevel=2, ) - nprof = species_params["val at center"] * np.exp((cfg_grid["x"] - species_params["center"]) / L) - nprof = mask * nprof - - elif species_params["basis"] == "tanh": - left = species_params["center"] - species_params["width"] * 0.5 - right = species_params["center"] + species_params["width"] * 0.5 - rise = species_params["rise"] - nprof = get_envelope(rise, rise, left, right, cfg_grid["x"]) - - if species_params["bump_or_trough"] == "trough": - nprof = 1 - nprof - nprof = species_params["baseline"] + species_params["bump_height"] * nprof - - elif species_params["basis"] == "sine": - baseline = species_params["baseline"] - amp = species_params["amplitude"] - kk = species_params["wavenumber"] - nprof = baseline * (1.0 + amp * jnp.sin(kk * cfg_grid["x"])) + m = species_params["m"] + # Check if it matches the species config mass + if abs(m - species_cfg["mass"]) > 1e-10: + warnings.warn( + f"Density component '{component_name}': Mass mismatch! " + f"Using m={m} from density config, but species '{species_name}' " + f"has mass={species_cfg['mass']}. This may lead to inconsistent physics.", + UserWarning, + stacklevel=2, + ) else: - raise NotImplementedError + # Use mass from species config + m = species_cfg["mass"] - n_prof_total += nprof + # Get density profile + nprof = _get_density_profile(species_params, cfg, cfg_grid) + n_prof_species += nprof - # Distribution function + # Distribution function for this component temp_f, _ = _initialize_distribution_( nx=int(cfg_grid["nx"]), - nv=int(cfg_grid["nv"]), + nv=nv, v0=v0, m=m, T0=T0, - vmax=cfg_grid["vmax"], + vmax=vmax, n_prof=nprof, noise_val=species_params["noise_val"], noise_seed=int(species_params["noise_seed"]), noise_type=species_params["noise_type"], ) - f += temp_f + f_species += temp_f species_found = True - else: - pass + + species_distributions[species_name] = (n_prof_species, f_species, vax) if not species_found: raise ValueError("No species found! Check the config") - return n_prof_total, f + return species_distributions + + +def _get_density_profile(species_params, cfg, cfg_grid): + """Extract density profile generation logic into a helper function.""" + if species_params["basis"] == "uniform": + nprof = np.ones([cfg_grid["nx"]]) + + elif species_params["basis"] == "linear": + left = species_params["center"] - species_params["width"] * 0.5 + right = species_params["center"] + species_params["width"] * 0.5 + rise = species_params["rise"] + mask = get_envelope(rise, rise, left, right, cfg_grid["x"]) + + ureg = pint.UnitRegistry() + _Q = ureg.Quantity + + L = ( + _Q(species_params["gradient scale length"]).to("nm").magnitude + / cfg["units"]["derived"]["x0"].to("nm").magnitude + ) + nprof = species_params["val at center"] + (cfg_grid["x"] - species_params["center"]) / L + nprof = mask * nprof + + elif species_params["basis"] == "exponential": + left = species_params["center"] - species_params["width"] * 0.5 + right = species_params["center"] + species_params["width"] * 0.5 + rise = species_params["rise"] + mask = get_envelope(rise, rise, left, right, cfg_grid["x"]) + + ureg = pint.UnitRegistry() + _Q = ureg.Quantity + + L = ( + _Q(species_params["gradient scale length"]).to("nm").magnitude + / cfg["units"]["derived"]["x0"].to("nm").magnitude + ) + nprof = species_params["val at center"] * np.exp((cfg_grid["x"] - species_params["center"]) / L) + nprof = mask * nprof + + elif species_params["basis"] == "tanh": + left = species_params["center"] - species_params["width"] * 0.5 + right = species_params["center"] + species_params["width"] * 0.5 + rise = species_params["rise"] + nprof = get_envelope(rise, rise, left, right, cfg_grid["x"]) + + if species_params["bump_or_trough"] == "trough": + nprof = 1 - nprof + nprof = species_params["baseline"] + species_params["bump_height"] * nprof + + elif species_params["basis"] == "sine": + baseline = species_params["baseline"] + amp = species_params["amplitude"] + kk = species_params["wavenumber"] + nprof = baseline * (1.0 + amp * jnp.sin(kk * cfg_grid["x"])) + else: + raise NotImplementedError + + return nprof def post_process(result: Solution, cfg: dict, td: str, args: dict): diff --git a/adept/_vlasov1d/modules.py b/adept/_vlasov1d/modules.py index adb305d..cbbff16 100644 --- a/adept/_vlasov1d/modules.py +++ b/adept/_vlasov1d/modules.py @@ -82,7 +82,25 @@ def get_derived_quantities(self): cfg_grid = self.cfg["grid"] cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"] - cfg_grid["dv"] = 2.0 * cfg_grid["vmax"] / cfg_grid["nv"] + + # Normalize species config: if not provided, generate a default electron species + if self.cfg["terms"].get("species", None) is None: + # Collect all density components (keys starting with "species-") + density_components = [name for name in self.cfg["density"].keys() if name.startswith("species-")] + if not density_components: + raise ValueError("No density components found (expected keys starting with 'species-')") + + # Generate default electron species config + self.cfg["terms"]["species"] = [ + { + "name": "electron", + "charge": -1.0, + "mass": 1.0, + "vmax": cfg_grid["vmax"], + "nv": cfg_grid["nv"], + "density_components": density_components, + } + ] if len(self.cfg["drivers"]["ey"].keys()) > 0: print("overriding dt to ensure wave solver stability") @@ -117,13 +135,8 @@ def get_solver_quantities(self) -> dict: cfg_grid["xmin"] + cfg_grid["dx"] / 2, cfg_grid["xmax"] - cfg_grid["dx"] / 2, cfg_grid["nx"] ), "t": jnp.linspace(0, cfg_grid["tmax"], cfg_grid["nt"]), - "v": jnp.linspace( - -cfg_grid["vmax"] + cfg_grid["dv"] / 2, cfg_grid["vmax"] - cfg_grid["dv"] / 2, cfg_grid["nv"] - ), "kx": jnp.fft.fftfreq(cfg_grid["nx"], d=cfg_grid["dx"]) * 2.0 * np.pi, "kxr": jnp.fft.rfftfreq(cfg_grid["nx"], d=cfg_grid["dx"]) * 2.0 * np.pi, - "kv": jnp.fft.fftfreq(cfg_grid["nv"], d=cfg_grid["dv"]) * 2.0 * np.pi, - "kvr": jnp.fft.rfftfreq(cfg_grid["nv"], d=cfg_grid["dv"]) * 2.0 * np.pi, }, } @@ -136,26 +149,82 @@ def get_solver_quantities(self) -> dict: one_over_kxr[1:] = 1.0 / cfg_grid["kxr"][1:] cfg_grid["one_over_kxr"] = jnp.array(one_over_kxr) - # velocity axes - one_over_kv = np.zeros_like(cfg_grid["kv"]) - one_over_kv[1:] = 1.0 / cfg_grid["kv"][1:] - cfg_grid["one_over_kv"] = jnp.array(one_over_kv) - - one_over_kvr = np.zeros_like(cfg_grid["kvr"]) - one_over_kvr[1:] = 1.0 / cfg_grid["kvr"][1:] - cfg_grid["one_over_kvr"] = jnp.array(one_over_kvr) - cfg_grid["nuprof"] = 1.0 # get_profile_with_mask(cfg["nu"]["time-profile"], t, cfg["nu"]["time-profile"]["bump_or_trough"]) cfg_grid["ktprof"] = 1.0 # get_profile_with_mask(cfg["krook"]["time-profile"], t, cfg["krook"]["time-profile"]["bump_or_trough"]) - cfg_grid["n_prof_total"], cfg_grid["starting_f"] = _initialize_total_distribution_(self.cfg, cfg_grid) + + # Initialize distributions (always returns dict format) + dist_result = _initialize_total_distribution_(self.cfg, cfg_grid) + cfg_grid["species_distributions"] = dist_result + + # Build species_grids and species_params + cfg_grid["species_grids"] = {} + cfg_grid["species_params"] = {} + n_prof_total = np.zeros([cfg_grid["nx"]]) + + for species_name, (n_prof, f_s, v_ax) in dist_result.items(): + n_prof_total += n_prof + + # Find the species config (always exists due to normalization in get_derived_quantities) + species_cfg = next((s for s in self.cfg["terms"]["species"] if s["name"] == species_name), None) + if species_cfg is None: + raise ValueError(f"Species '{species_name}' not found in config['terms']['species']") + + nv = species_cfg["nv"] + vmax = species_cfg["vmax"] + + dv = 2.0 * vmax / nv + + # Build velocity grid parameters for this species + cfg_grid["species_grids"][species_name] = { + "v": jnp.array(v_ax), + "dv": dv, + "nv": nv, + "vmax": vmax, + "kv": jnp.fft.fftfreq(nv, d=dv) * 2.0 * np.pi, + "kvr": jnp.fft.rfftfreq(nv, d=dv) * 2.0 * np.pi, + } + + # one_over_kv for this species (size is length of kvr for real FFT) + kvr_len = len(cfg_grid["species_grids"][species_name]["kvr"]) + one_over_kv = np.zeros(nv) + one_over_kv[1:] = 1.0 / cfg_grid["species_grids"][species_name]["kv"][1:] + cfg_grid["species_grids"][species_name]["one_over_kv"] = jnp.array(one_over_kv) + + one_over_kvr = np.zeros(kvr_len) + one_over_kvr[1:] = 1.0 / cfg_grid["species_grids"][species_name]["kvr"][1:] + cfg_grid["species_grids"][species_name]["one_over_kvr"] = jnp.array(one_over_kvr) + + # Build species parameters (charge, mass, charge-to-mass ratio) + cfg_grid["species_params"][species_name] = { + "charge": species_cfg["charge"], + "mass": species_cfg["mass"], + "charge_to_mass": species_cfg["charge"] / species_cfg["mass"], + } + + cfg_grid["n_prof_total"] = n_prof_total + + # Quasineutrality handling + # For single-species electron-only sims, assume static ion background + # For multi-species, quasineutrality is handled by the species themselves + has_multiple_species = len(self.cfg["terms"]["species"]) > 1 + if has_multiple_species: + cfg_grid["ion_charge"] = np.zeros_like(n_prof_total) + else: + cfg_grid["ion_charge"] = n_prof_total.copy() + + # For single-species configs, also store velocity grid at grid level for backward compatibility + if not has_multiple_species and "electron" in cfg_grid["species_grids"]: + cfg_grid["v"] = jnp.array(dist_result["electron"][2]) + cfg_grid["kv"] = cfg_grid["species_grids"]["electron"]["kv"] + cfg_grid["kvr"] = cfg_grid["species_grids"]["electron"]["kvr"] + cfg_grid["one_over_kv"] = cfg_grid["species_grids"]["electron"]["one_over_kv"] + cfg_grid["one_over_kvr"] = cfg_grid["species_grids"]["electron"]["one_over_kvr"] cfg_grid["kprof"] = np.ones_like(cfg_grid["n_prof_total"]) # get_profile_with_mask(cfg["krook"]["space-profile"], xs, cfg["krook"]["space-profile"]["bump_or_trough"]) - cfg_grid["ion_charge"] = np.zeros_like(cfg_grid["n_prof_total"]) + cfg_grid["n_prof_total"] - cfg_grid["x_a"] = np.concatenate( [ [cfg_grid["x"][0] - cfg_grid["dx"]], @@ -173,21 +242,31 @@ def init_state_and_args(self) -> dict: :param cfg: :return: """ - n_prof_total, f = _initialize_total_distribution_(self.cfg, self.cfg["grid"]) + # Initialize distributions (always returns dict format) + dist_result = _initialize_total_distribution_(self.cfg, self.cfg["grid"]) state = {} - for species in ["electron"]: - state[species] = f + # Build state dict with all species distributions + for species_name, (n_prof, f_s, v_ax) in dist_result.items(): + state[species_name] = jnp.array(f_s) + + # Reference distribution for diagnostics (use first species) + # TODO(gh-174): Store species distributions separately for multi-species diagnostics + first_species_name = next(iter(dist_result.keys())) + f_ref = dist_result[first_species_name][1] + + # Field quantities (same for all modes) for field in ["e", "de"]: state[field] = jnp.zeros(self.cfg["grid"]["nx"]) for field in ["a", "da", "prev_a"]: state[field] = jnp.zeros(self.cfg["grid"]["nx"] + 2) # need boundary cells + # Diagnostics (use reference distribution shape) for k in ["diag-vlasov-dfdt", "diag-fp-dfdt"]: if self.cfg["diagnostics"][k]: - state[k] = jnp.zeros_like(f) + state[k] = jnp.zeros_like(f_ref) self.state = state self.args = {"drivers": self.cfg["drivers"], "terms": self.cfg["terms"]} diff --git a/adept/_vlasov1d/solvers/pushers/field.py b/adept/_vlasov1d/solvers/pushers/field.py index 459202e..a50bdfd 100644 --- a/adept/_vlasov1d/solvers/pushers/field.py +++ b/adept/_vlasov1d/solvers/pushers/field.py @@ -102,67 +102,223 @@ def __call__(self, a: jnp.ndarray, aold: jnp.ndarray, djy_array: jnp.ndarray, el class SpectralPoissonSolver: - def __init__(self, ion_charge, one_over_kx, dv): + """Spectral Poisson solver for electrostatic field. + + Solves Poisson equation: div^2(phi) = -rho, E = -grad(phi) + where rho = sum_s q_s * integral(f_s dv_s) is the total charge density from all species. + + For quasineutral plasmas, the total charge should sum to zero at initialization. + """ + + def __init__(self, one_over_kx, species_grids, species_params): + """Initialize the Poisson solver. + + Args: + one_over_kx: 1/kx array for spectral solve (with 0 for k=0 mode) + species_grids: dict mapping species_name -> {"dv": dv, "v": v, ...} + species_params: dict mapping species_name -> {"charge": q, "mass": m, ...} + """ super().__init__() - self.ion_charge = ion_charge self.one_over_kx = one_over_kx - self.dv = dv + self.species_grids = species_grids + self.species_params = species_params + + def compute_charge_density(self, f_dict): + """Compute total charge density from all species. + + rho = sum_s q_s * integral(f_s dv_s) = sum_s q_s * n_s + + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + + Returns: + Total charge density array[nx] + """ + rho = jnp.zeros_like(next(iter(f_dict.values()))[:, 0]) + + for species_name, f_s in f_dict.items(): + q_s = self.species_params[species_name]["charge"] + dv_s = self.species_grids[species_name]["dv"] + n_s = jnp.sum(f_s, axis=1) * dv_s + rho = rho + q_s * n_s + + return rho - def compute_charges(self, f): - return jnp.sum(f, axis=1) * self.dv + def __call__(self, f_dict: dict, prev_ex: jnp.ndarray, dt: jnp.float64): + """Solve Poisson equation for electric field. - def __call__(self, f: jnp.ndarray, prev_ex: jnp.ndarray, dt: jnp.float64): - return jnp.real(jnp.fft.ifft(1j * self.one_over_kx * jnp.fft.fft(self.ion_charge - self.compute_charges(f)))) + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + prev_ex: previous electric field (unused for Poisson, kept for interface) + dt: time step (unused for Poisson, kept for interface) + + Returns: + Electric field E[nx] from Poisson solve + """ + rho = self.compute_charge_density(f_dict) + # Poisson equation: div^2(phi) = -rho => -k^2 phi_k = -rho_k => phi_k = rho_k / k^2 + # E = -grad(phi) => E_k = -ik phi_k = -ik rho_k / k^2 = -i rho_k / k + return jnp.real(jnp.fft.ifft(-1j * self.one_over_kx * jnp.fft.fft(rho))) class AmpereSolver: - def __init__(self, cfg): + """Ampere solver using current density to evolve electric field. + + Solves: a single Euler step of ∂E/∂t = -j, where j = Σ_s (q_s/m_s) ∫v f_s dv + is the total current density from all species. + """ + + def __init__(self, species_grids, species_params): + """Initialize the Ampere solver. + + Args: + species_grids: dict mapping species_name -> {"dv": dv, "v": v[nv], ...} + species_params: dict mapping species_name -> {"charge": q, "mass": m, ...} + """ super().__init__() - self.vx = cfg["grid"]["v"] - self.dv = cfg["grid"]["dv"] + self.species_grids = species_grids + self.species_params = species_params + + def compute_current_density(self, f_dict): + """Compute total current density from all species. + + j = Σ_s q_s ∫v f_s dv_s + + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + + Returns: + Total current density array[nx] + """ + j = jnp.zeros_like(next(iter(f_dict.values()))[:, 0]) - def vx_moment(self, f): - return jnp.sum(f, axis=1) * self.dv + for species_name, f_s in f_dict.items(): + q_s = self.species_params[species_name]["charge"] + v_s = self.species_grids[species_name]["v"] + dv_s = self.species_grids[species_name]["dv"] + # Current from this species: q_s * ∫v f_s dv + j_s = jnp.sum(v_s[None, :] * f_s, axis=1) * dv_s + j = j + q_s * j_s - def __call__(self, f: jnp.ndarray, prev_ex: jnp.ndarray, dt: jnp.float64): - return prev_ex - dt * self.vx_moment(self.vx[None, :] * f) + return j + + def __call__(self, f_dict: dict, prev_ex: jnp.ndarray, dt: jnp.float64): + """Evolve electric field using Ampere's law. + + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + prev_ex: previous electric field E[nx] + dt: time step + + Returns: + Updated electric field E[nx] + """ + j = self.compute_current_density(f_dict) + return prev_ex - dt * j class HampereSolver: - def __init__(self, cfg): - self.vx = cfg["grid"]["v"][None, :] - self.dv = cfg["grid"]["dv"] - self.kx = cfg["grid"]["kx"][:, None] - self.one_over_ikx = cfg["grid"]["one_over_kx"] / 1j + """Hamiltonian Ampere solver using spectral integration in time. + + This solver uses the characteristic method to integrate Ampere's law + exactly along particle trajectories in Fourier space. + + Note: Currently only supports single-species due to the complexity of + the spectral integration with different velocity grids. For multi-species, + use the standard AmpereSolver instead. + """ + + def __init__(self, kx, one_over_kx, species_grids, species_params): + """Initialize the Hamiltonian Ampere solver. + + Args: + kx: wavenumber array kx[nx] + one_over_kx: 1/kx array (with 0 for k=0 mode) + species_grids: dict mapping species_name -> {"dv": dv, "v": v[nv], ...} + species_params: dict mapping species_name -> {"charge": q, "mass": m, ...} + """ + self.kx = kx[:, None] + self.one_over_ikx = one_over_kx / 1j + self.species_grids = species_grids + self.species_params = species_params + + # For now, validate that we have exactly one species (limitation documented above) + if len(species_grids) > 1: + raise NotImplementedError( + "HampereSolver currently only supports single-species simulations. " + "For multi-species, use 'ampere' or 'poisson' field solver instead." + ) + + # Cache the single species' grid parameters + species_name = next(iter(species_grids.keys())) + self.vx = species_grids[species_name]["v"][None, :] + self.dv = species_grids[species_name]["dv"] + self.charge = species_params[species_name]["charge"] + + def __call__(self, f_dict: dict, prev_ex: jnp.ndarray, dt: jnp.float64): + """Evolve electric field using Hamiltonian Ampere method. + + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + prev_ex: previous electric field E[nx] + dt: time step + + Returns: + Updated electric field E[nx] + """ + # Extract the single species distribution + f = next(iter(f_dict.values())) - def __call__(self, f: jnp.ndarray, prev_ex: jnp.ndarray, dt: jnp.float64): prev_ek = jnp.fft.fft(prev_ex, axis=0) fk = jnp.fft.fft(f, axis=0) new_ek = ( - prev_ek + self.one_over_ikx * jnp.sum(fk * (jnp.exp(-1j * self.kx * dt * self.vx) - 1), axis=1) * self.dv + prev_ek + + self.charge + * self.one_over_ikx + * jnp.sum(fk * (jnp.exp(-1j * self.kx * dt * self.vx) - 1), axis=1) + * self.dv ) return jnp.real(jnp.fft.ifft(new_ek)) class ElectricFieldSolver: + """Wrapper for electrostatic field solvers. + + Combines the self-consistent electrostatic field from Poisson/Ampere solve + with the ponderomotive force from laser fields. + """ + def __init__(self, cfg): super().__init__() + species_grids = cfg["grid"]["species_grids"] + species_params = cfg["grid"]["species_params"] + if cfg["terms"]["field"] == "poisson": self.es_field_solver = SpectralPoissonSolver( - ion_charge=cfg["grid"]["ion_charge"], one_over_kx=cfg["grid"]["one_over_kx"], dv=cfg["grid"]["dv"] + one_over_kx=cfg["grid"]["one_over_kx"], + species_grids=species_grids, + species_params=species_params, ) self.hampere = False elif cfg["terms"]["field"] == "ampere": if cfg["terms"]["time"] == "leapfrog": - self.es_field_solver = AmpereSolver(cfg) + self.es_field_solver = AmpereSolver( + species_grids=species_grids, + species_params=species_params, + ) self.hampere = False else: raise NotImplementedError(f"ampere + {cfg['terms']['time']} has not yet been implemented") elif cfg["terms"]["field"] == "hampere": if cfg["terms"]["time"] == "leapfrog": - self.es_field_solver = HampereSolver(cfg) + self.es_field_solver = HampereSolver( + kx=cfg["grid"]["kx"], + one_over_kx=cfg["grid"]["one_over_kx"], + species_grids=species_grids, + species_params=species_params, + ) self.hampere = True else: raise NotImplementedError(f"ampere + {cfg['terms']['time']} has not yet been implemented") @@ -170,16 +326,22 @@ def __init__(self, cfg): raise NotImplementedError("Field Solver: <" + cfg["solver"]["field"] + "> has not yet been implemented") self.dx = cfg["grid"]["dx"] - def __call__(self, f: jnp.ndarray, a: jnp.ndarray, prev_ex: jnp.ndarray, dt: jnp.float64): - """ - This returns the total electrostatic field that is used in the Vlasov equation - The total field is a sum of the ponderomotive force from `E_y`, the driver field, and the - self-consistent electrostatic field from a Poisson or Ampere solve + def __call__(self, f_dict: dict, a: jnp.ndarray, prev_ex: jnp.ndarray, dt: jnp.float64): + """Compute total electrostatic field for the Vlasov equation. - :param f: distribution function - :param a: - :return: + The total field is a sum of: + - Ponderomotive force from E_y (laser field) + - Self-consistent electrostatic field from Poisson or Ampere solve + + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + a: vector potential a[nx+2] (with boundary cells) + prev_ex: previous electric field E[nx] + dt: time step + + Returns: + Tuple of (ponderomotive_force[nx], self_consistent_ex[nx]) """ ponderomotive_force = -0.5 * jnp.gradient(a**2.0, self.dx)[1:-1] - self_consistent_ex = self.es_field_solver(f, prev_ex, dt) + self_consistent_ex = self.es_field_solver(f_dict, prev_ex, dt) return ponderomotive_force, self_consistent_ex diff --git a/adept/_vlasov1d/solvers/pushers/fokker_planck.py b/adept/_vlasov1d/solvers/pushers/fokker_planck.py index 8b48293..b354a75 100644 --- a/adept/_vlasov1d/solvers/pushers/fokker_planck.py +++ b/adept/_vlasov1d/solvers/pushers/fokker_planck.py @@ -42,16 +42,32 @@ def __init_fp_operator__(self) -> "LenardBernstein | ChangCooperLenardBernstein else: raise NotImplementedError - def __call__(self, nu_fp: jnp.ndarray, nu_K: jnp.ndarray, f: jnp.ndarray, dt: jnp.float64) -> jnp.ndarray: + def __call__(self, nu_fp: jnp.ndarray, nu_K: jnp.ndarray, f, dt: jnp.float64): """ Apply configured collision operators to the distribution function. :param nu_fp: Collision frequencies for the Fokker-Planck operator (shape: nx). :param nu_K: Krook collision frequencies (shape: nx). - :param f: Distribution function f(x, v) (shape: nx x nv). + :param f: Distribution function (dict or array). :param dt: Time step size. :return: Updated distribution function after collisions. """ + # TODO(gh-173): Properly handle multi-species collisions + # For now, only apply to electron distribution for backward compatibility + if isinstance(f, dict): + result = {} + for species_name, f_species in f.items(): + if species_name == "electron": + result[species_name] = self._apply_collisions(nu_fp, nu_K, f_species, dt) + else: + # For non-electron species, just pass through unchanged for now + result[species_name] = f_species + return result + else: + return self._apply_collisions(nu_fp, nu_K, f, dt) + + def _apply_collisions(self, nu_fp: jnp.ndarray, nu_K: jnp.ndarray, f: jnp.ndarray, dt: jnp.float64) -> jnp.ndarray: + """Apply collision operators to a single species distribution.""" if self.cfg["terms"]["fokker_planck"]["is_on"]: # The three diagonals representing collision operator for all x cee_a, cee_b, cee_c = self.fp(nu=nu_fp, f_xv=f, dt=dt) @@ -74,9 +90,11 @@ def __init__(self, cfg: Mapping[str, Any]): :param cfg: Simulation configuration containing grid spacing and velocity grid. """ self.cfg = cfg - f_mx = np.exp(-(self.cfg["grid"]["v"][None, :] ** 2.0) / 2.0) - self.f_mx = f_mx / np.trapz(f_mx, dx=self.cfg["grid"]["dv"], axis=1)[:, None] - self.dv = self.cfg["grid"]["dv"] + v = cfg["grid"]["species_grids"]["electron"]["v"] + dv = cfg["grid"]["species_grids"]["electron"]["dv"] + f_mx = np.exp(-(v[None, :] ** 2.0) / 2.0) + self.f_mx = f_mx / np.trapz(f_mx, dx=dv, axis=1)[:, None] + self.dv = dv def vx_moment(self, f_xv: jnp.ndarray) -> jnp.ndarray: """Compute density n(x) by integrating over velocity.""" @@ -103,9 +121,11 @@ class _DriftDiffusionBase: def __init__(self, cfg: Mapping[str, Any]): self.cfg = cfg - self.v = self.cfg["grid"]["v"] - self.dv = self.cfg["grid"]["dv"] - self.ones = jnp.ones((self.cfg["grid"]["nx"], self.cfg["grid"]["nv"])) + # TODO(gh-173): For multi-species, use electron grid for FP for now + self.v = cfg["grid"]["species_grids"]["electron"]["v"] + self.dv = cfg["grid"]["species_grids"]["electron"]["dv"] + nv = cfg["grid"]["species_grids"]["electron"]["nv"] + self.ones = jnp.ones((self.cfg["grid"]["nx"], nv)) def vx_moment(self, f_xv: jnp.ndarray) -> jnp.ndarray: """Compute density n(x) by integrating over velocity.""" @@ -169,10 +189,11 @@ def __init__(self, cfg: Mapping[str, Any]): :param cfg: Simulation configuration providing grid metadata. """ self.cfg = cfg - self.v = self.cfg["grid"]["v"] - self.dv = self.cfg["grid"]["dv"] + self.v = cfg["grid"]["species_grids"]["electron"]["v"] + self.dv = cfg["grid"]["species_grids"]["electron"]["dv"] + nv = cfg["grid"]["species_grids"]["electron"]["nv"] self.v_edge = 0.5 * (self.v[1:] + self.v[:-1]) - self.ones = jnp.ones((self.cfg["grid"]["nx"], self.cfg["grid"]["nv"])) + self.ones = jnp.ones((self.cfg["grid"]["nx"], nv)) def vx_moment(self, f_xv: jnp.ndarray) -> jnp.ndarray: """Compute density n(x) by integrating over velocity.""" @@ -236,10 +257,11 @@ def __init__(self, cfg: Mapping[str, Any]): :param cfg: Simulation configuration providing grid metadata. """ self.cfg = cfg - self.v = self.cfg["grid"]["v"] - self.dv = self.cfg["grid"]["dv"] + self.v = cfg["grid"]["species_grids"]["electron"]["v"] + self.dv = cfg["grid"]["species_grids"]["electron"]["dv"] + nv = cfg["grid"]["species_grids"]["electron"]["nv"] self.v_edge = 0.5 * (self.v[1:] + self.v[:-1]) - self.ones = jnp.ones((self.cfg["grid"]["nx"], self.cfg["grid"]["nv"])) + self.ones = jnp.ones((self.cfg["grid"]["nx"], nv)) def vx_moment(self, f_xv: jnp.ndarray) -> jnp.ndarray: """Compute density n(x) by integrating over velocity.""" diff --git a/adept/_vlasov1d/solvers/pushers/vlasov.py b/adept/_vlasov1d/solvers/pushers/vlasov.py index a6608b4..19a1d07 100644 --- a/adept/_vlasov1d/solvers/pushers/vlasov.py +++ b/adept/_vlasov1d/solvers/pushers/vlasov.py @@ -49,31 +49,49 @@ def __call__(self, t, y, args): class VelocityExponential: - def __init__(self, cfg): - self.kv_real = cfg["grid"]["kvr"] + def __init__(self, species_grids, species_params): + self.species_grids = species_grids + self.species_params = species_params - def __call__(self, f, e, dt): - return jnp.real( - jnp.fft.irfft(jnp.exp(-1j * self.kv_real[None, :] * dt * e[:, None]) * jnp.fft.rfft(f, axis=1), axis=1) - ) + def __call__(self, f_dict, e, dt): + result = {} + for species_name, f in f_dict.items(): + kv_real = self.species_grids[species_name]["kvr"] + qm = self.species_params[species_name]["charge_to_mass"] + result[species_name] = jnp.real( + jnp.fft.irfft(jnp.exp(-1j * kv_real[None, :] * dt * qm * e[:, None]) * jnp.fft.rfft(f, axis=1), axis=1) + ) + return result class VelocityCubicSpline: - def __init__(self, cfg): - self.v = jnp.repeat(cfg["grid"]["v"][None, :], repeats=cfg["grid"]["nx"], axis=0) - self.interp = vmap(partial(interp1d, extrap=True), in_axes=0) # {"xq": 0, "f": 0, "x": None}) - - def __call__(self, f, e, dt): - vq = self.v - e[:, None] * dt - return self.interp(xq=vq, x=self.v, f=f) + def __init__(self, species_grids, species_params): + self.species_grids = species_grids + self.species_params = species_params + self.interp = vmap(partial(interp1d, extrap=True), in_axes=0) + + def __call__(self, f_dict, e, dt): + result = {} + for species_name, f in f_dict.items(): + v = self.species_grids[species_name]["v"] + qm = self.species_params[species_name]["charge_to_mass"] + nx = f.shape[0] + v_repeated = jnp.repeat(v[None, :], repeats=nx, axis=0) + vq = v_repeated - qm * e[:, None] * dt + result[species_name] = self.interp(xq=vq, x=v_repeated, f=f) + return result class SpaceExponential: - def __init__(self, cfg): - self.kx_real = cfg["grid"]["kxr"] - self.v = cfg["grid"]["v"] - - def __call__(self, f, dt): - return jnp.real( - jnp.fft.irfft(jnp.exp(-1j * self.kx_real[:, None] * dt * self.v[None, :]) * jnp.fft.rfft(f, axis=0), axis=0) - ) + def __init__(self, x, species_grids): + self.kx_real = jnp.fft.rfftfreq(len(x), d=x[1] - x[0]) * 2 * jnp.pi + self.species_grids = species_grids + + def __call__(self, f_dict, dt): + result = {} + for species_name, f in f_dict.items(): + v = self.species_grids[species_name]["v"] + result[species_name] = jnp.real( + jnp.fft.irfft(jnp.exp(-1j * self.kx_real[:, None] * dt * v[None, :]) * jnp.fft.rfft(f, axis=0), axis=0) + ) + return result diff --git a/adept/_vlasov1d/solvers/vector_field.py b/adept/_vlasov1d/solvers/vector_field.py index 1087df7..b06d897 100644 --- a/adept/_vlasov1d/solvers/vector_field.py +++ b/adept/_vlasov1d/solvers/vector_field.py @@ -19,23 +19,31 @@ class TimeIntegrator: def __init__(self, cfg: dict): self.field_solve = field.ElectricFieldSolver(cfg) + self.species_grids = cfg["grid"]["species_grids"] + self.species_params = cfg["grid"]["species_params"] self.edfdv = self.get_edfdv(cfg) - self.vdfdx = vlasov.SpaceExponential(cfg) + self.vdfdx = vlasov.SpaceExponential(cfg["grid"]["x"], self.species_grids) def get_edfdv(self, cfg: dict): if cfg["terms"]["edfdv"] == "exponential": - return vlasov.VelocityExponential(cfg) + return vlasov.VelocityExponential(self.species_grids, self.species_params) elif cfg["terms"]["edfdv"] == "cubic-spline": - return vlasov.VelocityCubicSpline(cfg) + return vlasov.VelocityCubicSpline(self.species_grids, self.species_params) else: raise NotImplementedError(f"{cfg['terms']['edfdv']} has not been implemented") class LeapfrogIntegrator(TimeIntegrator): - """ - This is a leapfrog integrator + """Leapfrog integrator for Vlasov-Poisson system. - :param cfg: + Implements a standard leapfrog scheme where all species are pushed + synchronously under the same self-consistent electric field. + + The field solve uses the total charge density from all species: + rho = sum_s q_s * integral(f_s dv_s) + + Args: + cfg: Configuration dictionary """ def __init__(self, cfg: dict): @@ -43,23 +51,41 @@ def __init__(self, cfg: dict): self.dt = cfg["grid"]["dt"] self.dt_array = self.dt * jnp.array([0.0, 1.0]) - def __call__(self, f: Array, a: Array, dex_array: Array, prev_ex: Array) -> tuple[Array, Array]: - f_after_v = self.vdfdx(f=f, dt=self.dt) + def __call__(self, f_dict: dict, a: Array, dex_array: Array, prev_ex: Array) -> tuple[Array, dict]: + """Perform one leapfrog timestep for all species. + + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + a: vector potential a[nx+2] + dex_array: external driver field at substep times + prev_ex: previous electric field E[nx] + + Returns: + Tuple of (electric_field[nx], updated_f_dict) + """ + f_after_v = self.vdfdx(f_dict, dt=self.dt) if self.field_solve.hampere: - f_for_field = f + f_for_field = f_dict else: f_for_field = f_after_v - pond, e = self.field_solve(f=f_for_field, a=a, prev_ex=prev_ex, dt=self.dt) - f = self.edfdv(f=f_after_v, e=pond + e + dex_array[0], dt=self.dt) + pond, e = self.field_solve(f_dict=f_for_field, a=a, prev_ex=prev_ex, dt=self.dt) + f_dict = self.edfdv(f_after_v, e=pond + e + dex_array[0], dt=self.dt) - return e, f + return e, f_dict class SixthOrderHamIntegrator(TimeIntegrator): - """ - This class contains the 6th order Hamiltonian integrator from Crousseilles + """6th-order Hamiltonian integrator for Vlasov-Poisson system. - :param cfg: + Implements the 6th-order symplectic integrator from Crouseilles et al. + All species are pushed synchronously under the same self-consistent + electric field at each substep. + + The field solve uses the total charge density from all species: + rho = sum_s q_s * integral(f_s dv_s) + + Args: + cfg: Configuration dictionary """ def __init__(self, cfg): @@ -94,56 +120,72 @@ def __init__(self, cfg): ] ) - def __call__(self, f: Array, a: Array, dex_array: Array, prev_ex: Array) -> tuple[Array, Array]: - ponderomotive_force, self_consistent_ex = self.field_solve(f=f, a=a, prev_ex=None, dt=None) + def __call__(self, f_dict: dict, a: Array, dex_array: Array, prev_ex: Array) -> tuple[Array, dict]: + """Perform one 6th-order timestep for all species. + + Args: + f_dict: dict mapping species_name -> f[nx, nv] distribution + a: vector potential a[nx+2] + dex_array: external driver field at substep times + prev_ex: previous electric field (unused, kept for interface consistency) + + Returns: + Tuple of (electric_field[nx], updated_f_dict) + """ + ponderomotive_force, self_consistent_ex = self.field_solve(f_dict=f_dict, a=a, prev_ex=None, dt=None) force = ponderomotive_force + dex_array[0] + self_consistent_ex - f = self.edfdv(f=f, e=force, dt=self.D1 * self.dt) + f_dict = self.edfdv(f_dict, e=force, dt=self.D1 * self.dt) - f = self.vdfdx(f=f, dt=self.a1 * self.dt) - ponderomotive_force, self_consistent_ex = self.field_solve(f=f, a=a, prev_ex=None, dt=None) + f_dict = self.vdfdx(f_dict, dt=self.a1 * self.dt) + ponderomotive_force, self_consistent_ex = self.field_solve(f_dict=f_dict, a=a, prev_ex=None, dt=None) force = ponderomotive_force + dex_array[1] + self_consistent_ex - f = self.edfdv(f=f, e=force, dt=self.D2 * self.dt) + f_dict = self.edfdv(f_dict, e=force, dt=self.D2 * self.dt) - f = self.vdfdx(f=f, dt=self.a2 * self.dt) - ponderomotive_force, self_consistent_ex = self.field_solve(f=f, a=a, prev_ex=None, dt=None) + f_dict = self.vdfdx(f_dict, dt=self.a2 * self.dt) + ponderomotive_force, self_consistent_ex = self.field_solve(f_dict=f_dict, a=a, prev_ex=None, dt=None) force = ponderomotive_force + dex_array[2] + self_consistent_ex - f = self.edfdv(f=f, e=force, dt=self.D3 * self.dt) + f_dict = self.edfdv(f_dict, e=force, dt=self.D3 * self.dt) - f = self.vdfdx(f=f, dt=self.a3 * self.dt) - ponderomotive_force, self_consistent_ex = self.field_solve(f=f, a=a, prev_ex=None, dt=None) + f_dict = self.vdfdx(f_dict, dt=self.a3 * self.dt) + ponderomotive_force, self_consistent_ex = self.field_solve(f_dict=f_dict, a=a, prev_ex=None, dt=None) force = ponderomotive_force + dex_array[3] + self_consistent_ex - f = self.edfdv(f=f, e=force, dt=self.D3 * self.dt) + f_dict = self.edfdv(f_dict, e=force, dt=self.D3 * self.dt) - f = self.vdfdx(f=f, dt=self.a2 * self.dt) - ponderomotive_force, self_consistent_ex = self.field_solve(f=f, a=a, prev_ex=None, dt=None) + f_dict = self.vdfdx(f_dict, dt=self.a2 * self.dt) + ponderomotive_force, self_consistent_ex = self.field_solve(f_dict=f_dict, a=a, prev_ex=None, dt=None) force = ponderomotive_force + dex_array[4] + self_consistent_ex - f = self.edfdv(f=f, e=force, dt=self.D2 * self.dt) + f_dict = self.edfdv(f_dict, e=force, dt=self.D2 * self.dt) - f = self.vdfdx(f=f, dt=self.a1 * self.dt) - ponderomotive_force, self_consistent_ex = self.field_solve(f=f, a=a, prev_ex=None, dt=None) + f_dict = self.vdfdx(f_dict, dt=self.a1 * self.dt) + ponderomotive_force, self_consistent_ex = self.field_solve(f_dict=f_dict, a=a, prev_ex=None, dt=None) force = ponderomotive_force + dex_array[5] + self_consistent_ex - f = self.edfdv(f=f, e=force, dt=self.D1 * self.dt) + f_dict = self.edfdv(f_dict, e=force, dt=self.D1 * self.dt) - return self_consistent_ex, f + return self_consistent_ex, f_dict class VlasovPoissonFokkerPlanck: - """ - This class contains the Vlasov-Poisson + Fokker-Planck timestep + """Vlasov-Poisson + Fokker-Planck timestep for multi-species simulations. - :param cfg: Configuration dictionary + Combines a Vlasov-Poisson integrator (leapfrog or 6th-order Hamiltonian) + with optional Fokker-Planck collisions. Handles dict-based multi-species + distributions where each species evolves under the same self-consistent + electric field. - :return: Tuple of the electric field and the distribution function + Args: + cfg: Configuration dictionary + + Returns: + Tuple of (electric_field, f_dict, diagnostics_dict) """ def __init__(self, cfg: dict): self.dt = cfg["grid"]["dt"] - self.v = cfg["grid"]["v"] if cfg["terms"]["time"] == "sixth": self.vlasov_poisson = SixthOrderHamIntegrator(cfg) self.dex_save = 3 @@ -157,17 +199,21 @@ def __init__(self, cfg: dict): self.fp_dfdt = cfg["diagnostics"]["diag-fp-dfdt"] def __call__( - self, f: Array, a: Array, prev_ex: Array, dex_array: Array, nu_fp: Array, nu_K: Array - ) -> tuple[Array, Array]: - e, f_vlasov = self.vlasov_poisson(f, a, dex_array, prev_ex) + self, f_dict: dict, a: Array, prev_ex: Array, dex_array: Array, nu_fp: Array, nu_K: Array + ) -> tuple[Array, dict, dict]: + e, f_vlasov = self.vlasov_poisson(f_dict, a, dex_array, prev_ex) f_fp = self.fp(nu_fp, nu_K, f_vlasov, dt=self.dt) diags = {} if self.vlasov_dfdt: - diags["diag-vlasov-dfdt"] = (f_vlasov - f) / self.dt + # Compute diagnostics for each species + for species_name in f_dict.keys(): + diags[f"diag-vlasov-dfdt-{species_name}"] = (f_vlasov[species_name] - f_dict[species_name]) / self.dt if self.fp_dfdt: - diags["diag-fp-dfdt"] = (f_fp - f_vlasov) / self.dt + # Compute diagnostics for each species + for species_name in f_dict.keys(): + diags[f"diag-fp-dfdt-{species_name}"] = (f_fp[species_name] - f_vlasov[species_name]) / self.dt return e, f_fp, diags @@ -188,8 +234,18 @@ def __init__(self, cfg: dict): self.ey_driver = field.Driver(cfg["grid"]["x_a"], driver_key="ey") self.ex_driver = field.Driver(cfg["grid"]["x"], driver_key="ex") - def compute_charges(self, f): - return jnp.sum(f, axis=1) * self.cfg["grid"]["dv"] + def compute_charges(self, f_dict): + """Compute charge density from distribution functions. + + For a dict of species distributions, sum over all species with their charge-to-mass ratios. + """ + charge_density = jnp.zeros_like(self.cfg["grid"]["x"]) + for species_name, f in f_dict.items(): + dv = self.cfg["grid"]["species_grids"][species_name]["dv"] + charge = self.cfg["grid"]["species_params"][species_name]["charge"] + # Sum over velocity axis (axis=1) to get spatial density, then multiply by charge + charge_density += charge * jnp.sum(f, axis=1) * dv + return charge_density def nu_prof(self, t, nu_args): t_L = nu_args["time"]["center"] - nu_args["time"]["width"] * 0.5 @@ -237,21 +293,29 @@ def __call__(self, t, y, args): else: nu_K_prof = None - electron_density_n = self.compute_charges(y["electron"]) - e, f, diags = self.vpfp( - f=y["electron"], a=y["a"], prev_ex=y["e"], dex_array=dex, nu_fp=nu_fp_prof, nu_K=nu_K_prof + # Extract all species distributions from state + f_dict = {k: v for k, v in y.items() if k in self.cfg["grid"]["species_grids"]} + + electron_density_n = self.compute_charges(f_dict) + e, f_dict_new, diags = self.vpfp( + f_dict=f_dict, a=y["a"], prev_ex=y["e"], dex_array=dex, nu_fp=nu_fp_prof, nu_K=nu_K_prof ) - electron_density_np1 = self.compute_charges(f) + electron_density_np1 = self.compute_charges(f_dict_new) a = self.wave_solver( a=y["a"], aold=y["prev_a"], djy_array=djy, electron_charge=0.5 * (electron_density_n + electron_density_np1) ) - return { - "electron": f, + # Build result dict with all species distributions + result = { "a": a["a"], "prev_a": a["prev_a"], "da": djy, "de": dex[self.vpfp.dex_save], "e": e, - } | diags + } + # Add all species distributions + result.update(f_dict_new) + result.update(diags) + + return result diff --git a/adept/_vlasov1d/storage.py b/adept/_vlasov1d/storage.py index f698b2c..dd30a81 100644 --- a/adept/_vlasov1d/storage.py +++ b/adept/_vlasov1d/storage.py @@ -63,12 +63,16 @@ def store_f(cfg: dict, this_t: dict, td: str, ys: dict) -> xr.Dataset: :param ys: :return: """ - f_store = xr.Dataset( - { - spc: xr.DataArray(ys[spc], coords=(("t", this_t[spc]), ("x", cfg["grid"]["x"]), ("v", cfg["grid"]["v"]))) - for spc in ["electron"] - } - ) + # Find which species distributions were saved + species_names = list(cfg["grid"]["species_grids"].keys()) + species_to_save = [spc for spc in species_names if spc in ys] + + data_vars = {} + for spc in species_to_save: + v = cfg["grid"]["species_grids"][spc]["v"] + data_vars[spc] = xr.DataArray(ys[spc], coords=(("t", this_t[spc]), ("x", cfg["grid"]["x"]), ("v", v))) + + f_store = xr.Dataset(data_vars) f_store.to_netcdf(os.path.join(td, "binary", "dist.nc")) return f_store @@ -107,13 +111,15 @@ def store_diags(cfg: dict, this_t: dict, td: str, ys: dict) -> xr.Dataset: def get_field_save_func(cfg): if {"t"} == set(cfg["save"]["fields"].keys()): + v = cfg["grid"]["species_grids"]["electron"]["v"] + dv = cfg["grid"]["species_grids"]["electron"]["dv"] def _calc_moment_(inp): - return jnp.sum(inp, axis=1) * cfg["grid"]["dv"] + return jnp.sum(inp, axis=1) * dv def fields_save_func(t, y, args): - temp = {"n": _calc_moment_(y["electron"]), "v": _calc_moment_(y["electron"] * cfg["grid"]["v"][None, :])} - v_m_vbar = cfg["grid"]["v"][None, :] - temp["v"][:, None] + temp = {"n": _calc_moment_(y["electron"]), "v": _calc_moment_(y["electron"] * v[None, :])} + v_m_vbar = v[None, :] - temp["v"][:, None] temp["p"] = _calc_moment_(y["electron"] * v_m_vbar**2.0) temp["q"] = _calc_moment_(y["electron"] * v_m_vbar**3.0) temp["-flogf"] = _calc_moment_(y["electron"] * jnp.log(jnp.abs(y["electron"]))) @@ -213,8 +219,8 @@ def get_save_quantities(cfg: dict) -> dict: def get_default_save_func(cfg): - v = cfg["grid"]["v"][None, :] - dv = cfg["grid"]["dv"] + v = cfg["grid"]["species_grids"]["electron"]["v"][None, :] + dv = cfg["grid"]["species_grids"]["electron"]["dv"] def _calc_mean_moment_(inp): return jnp.mean(jnp.sum(inp, axis=1) * dv) diff --git a/tests/test_vlasov1d/configs/multispecies_ion_acoustic.yaml b/tests/test_vlasov1d/configs/multispecies_ion_acoustic.yaml new file mode 100644 index 0000000..581f4a0 --- /dev/null +++ b/tests/test_vlasov1d/configs/multispecies_ion_acoustic.yaml @@ -0,0 +1,118 @@ +units: + laser_wavelength: 351nm + normalizing_temperature: 2000eV + normalizing_density: 1.5e21/cc + Z: 10 + Zp: 10 + + +density: + quasineutrality: true + species-electron-background: + noise_seed: 420 + noise_type: gaussian + noise_val: 0.0 + v0: 0.0 + T0: 1.0 + m: 2.0 + basis: sine + baseline: 1.0 + amplitude: 1.0e-3 + wavenumber: 0.1 + species-ion-background: + noise_seed: 421 + noise_type: gaussian + noise_val: 0.0 + v0: 0.0 + T0: 0.01 + m: 2.0 + basis: sine + baseline: 1.0 + amplitude: 1.0e-3 + wavenumber: 0.1 + +grid: + dt: 0.5 + nx: 32 + tmin: 0. + tmax: 5000.0 + xmax: 314.159 + xmin: 0.0 + +save: + fields: + t: + tmin: 0.0 + tmax: 5000.0 + nt: 1001 + +solver: vlasov-1d + +mlflow: + experiment: multispecies-test + run: ion-acoustic-wave + +drivers: + ex: {} + ey: {} + +diagnostics: + diag-vlasov-dfdt: False + diag-fp-dfdt: False + +terms: + field: poisson + edfdv: exponential + time: sixth + species: + - name: electron + charge: -1.0 + mass: 1.0 + vmax: 6.4 + nv: 512 + density_components: + - species-electron-background + - name: ion + charge: 10.0 + mass: 18360.0 + vmax: 0.15 + nv: 256 + density_components: + - species-ion-background + fokker_planck: + is_on: False + type: Dougherty + time: + baseline: 1.0 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 + space: + baseline: 1.0 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 + krook: + is_on: False + time: + baseline: 1.0 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 + space: + baseline: 1.0 + bump_or_trough: bump + center: 0.0 + rise: 25.0 + slope: 0.0 + bump_height: 0.0 + width: 100000.0 diff --git a/tests/test_vlasov1d/test_ion_acoustic_wave.py b/tests/test_vlasov1d/test_ion_acoustic_wave.py new file mode 100644 index 0000000..5394d5e --- /dev/null +++ b/tests/test_vlasov1d/test_ion_acoustic_wave.py @@ -0,0 +1,213 @@ +# Copyright (c) Ergodic LLC 2023 +# research@ergodic.io +""" +Integration test for ion acoustic waves in multi-species Vlasov-Poisson. + +Ion acoustic waves are a fundamental multi-species phenomenon where: +- Electrons provide the pressure (restoring force) +- Ions provide the inertia + +Dispersion relation (long wavelength limit, T_e >> T_i): + ω = k * c_s where c_s = sqrt(Z * T_e / m_i) + +Full dispersion relation: + ω² = k² * c_s² / (1 + k²λ_D²) + +Reference: https://farside.ph.utexas.edu/teaching/plasma/Plasma/node112.html +""" + +import numpy as np +import pytest +import yaml + +from adept import ergoExo + + +def ion_acoustic_frequency(k, Z, T_e, m_i, lambda_D=1.0): + """ + Compute ion acoustic wave frequency from dispersion relation. + + ω² = k² * c_s² / (1 + k²λ_D²) + where c_s² = Z * T_e / m_i + + Args: + k: wavenumber (normalized to 1/λ_D) + Z: ion charge number + T_e: electron temperature (normalized) + m_i: ion mass (normalized to electron mass) + lambda_D: Debye length (normalized, typically 1) + + Returns: + Ion acoustic wave frequency ω + """ + c_s_squared = Z * T_e / m_i + omega_squared = k**2 * c_s_squared / (1 + (k * lambda_D) ** 2) + return np.sqrt(omega_squared) + + +def measure_frequency_from_density(density, time_axis, expected_omega=None): + """ + Measure oscillation frequency from density time series. + + Uses FFT of the k=1 Fourier mode to find the dominant frequency. + + Args: + density: Density array [nt, nx] (ion or electron density) + time_axis: Time points [nt] + expected_omega: If provided, search near this frequency + + Returns: + Measured frequency + """ + # Get the k=1 Fourier mode (the driven mode) + nx = density.shape[1] + nk1 = 2.0 / nx * np.fft.fft(density, axis=1)[:, 1] + + # Use latter half of simulation (after transients settle) + nt = len(time_axis) + start_idx = nt // 2 + + # Compute frequency from FFT of the time series + dt = time_axis[1] - time_axis[0] + nk1_late = nk1[start_idx:] + freq_axis = np.fft.fftfreq(len(nk1_late), dt) + spectrum = np.abs(np.fft.fft(nk1_late)) + + # Convert to angular frequency + omega_axis = 2 * np.pi * freq_axis + positive_mask = omega_axis > 0 + + # If expected frequency is provided, search in a window around it + if expected_omega is not None: + # Search within factor of 5 of expected + search_mask = positive_mask & (omega_axis < 5 * expected_omega) & (omega_axis > expected_omega / 5) + if np.any(search_mask): + peak_idx = np.argmax(spectrum[search_mask]) + return omega_axis[search_mask][peak_idx] + + # Otherwise find global peak + peak_idx = np.argmax(spectrum[positive_mask]) + return omega_axis[positive_mask][peak_idx] + + +def compute_ion_density(f_ion, dv): + """Compute ion density by integrating distribution over velocity.""" + return np.sum(f_ion, axis=-1) * dv + + +@pytest.mark.parametrize("time_integrator", ["sixth"]) +def test_ion_acoustic_simulation_runs(time_integrator): + """ + Smoke test that multi-species ion acoustic simulation runs end-to-end. + + This test verifies: + 1. Multi-species config is parsed correctly + 2. Electrons and ions are initialized with different velocity grids + 3. The Vlasov-Poisson solver handles the multi-species dict structure + 4. Field solves compute total charge density from all species + 5. Simulation completes without errors + + Note: Quantitative frequency verification requires careful physics calibration + and is tracked separately. This test ensures the infrastructure works. + """ + with open("tests/test_vlasov1d/configs/multispecies_ion_acoustic.yaml") as file: + config = yaml.safe_load(file) + + # Use shorter simulation for smoke test + config["grid"]["nx"] = 64 # Override for better resolution + config["grid"]["tmax"] = 100.0 + config["grid"]["dt"] = 0.1 + config["save"]["fields"]["t"]["tmax"] = 100.0 + config["save"]["fields"]["t"]["nt"] = 101 + + # Modify config for this test + config["terms"]["time"] = time_integrator + config["mlflow"]["experiment"] = "vlasov1d-test-ion-acoustic" + config["mlflow"]["run"] = f"ion-acoustic-smoke-{time_integrator}" + + # Run simulation + exo = ergoExo() + exo.setup(config) + result, datasets, run_id = exo(None) + solver_result = result["solver result"] + + # Verify we got output + e_field = solver_result.ys["fields"]["e"] + time_axis = solver_result.ts["fields"] + + assert e_field.shape[0] == len(time_axis), "Time axis mismatch" + assert e_field.shape[1] == config["grid"]["nx"], "Spatial grid mismatch" + + # Verify the field has non-trivial dynamics (not just zeros) + assert np.std(e_field) > 0, "Electric field should have dynamics" + + print(f"\nIon Acoustic Smoke Test ({time_integrator}):") + print(" Simulation completed successfully") + print(f" Time steps: {len(time_axis)}") + print(f" E-field std: {np.std(e_field):.2e}") + + +@pytest.mark.parametrize("time_integrator", ["sixth"]) +def test_ion_acoustic_dispersion(time_integrator): + """ + Test that ion acoustic wave frequency matches theoretical prediction. + + Uses ion density oscillations to measure the frequency, which directly + reflects the ion acoustic mode without contamination from fast electron + plasma oscillations. + + With k=0.1 and λ_D=1, we have kλ_D=0.1 << 1, so we're in the + long-wavelength regime where ω ≈ k * c_s. + """ + with open("tests/test_vlasov1d/configs/multispecies_ion_acoustic.yaml") as file: + config = yaml.safe_load(file) + + # Modify config for this test + config["grid"]["nx"] = 64 # Override for better resolution + config["terms"]["time"] = time_integrator + config["mlflow"]["experiment"] = "vlasov1d-test-ion-acoustic" + config["mlflow"]["run"] = f"ion-acoustic-dispersion-{time_integrator}" + + # Extract physical parameters + k = config["density"]["species-electron-background"]["wavenumber"] + Z = config["terms"]["species"][1]["charge"] # ion charge + T_e = config["density"]["species-electron-background"]["T0"] + m_i = config["terms"]["species"][1]["mass"] + + # Compute expected frequency + expected_omega = ion_acoustic_frequency(k, Z, T_e, m_i) + + # Run simulation + exo = ergoExo() + exo.setup(config) + result, datasets, run_id = exo(None) + solver_result = result["solver result"] + + # Get ion density from the "n" field (computed from electron, but we want ion) + # The fields save includes "n" which is electron density + # We need to compute ion density from the default scalars or use a different approach + # + # For now, use the electron density "n" as a proxy since both oscillate together + # in an ion acoustic wave (quasineutral oscillation) + n_field = solver_result.ys["fields"]["n"] + time_axis = solver_result.ts["fields"] + + # Measure frequency from density oscillations + measured_omega = measure_frequency_from_density(n_field, time_axis, expected_omega) + + print(f"\nIon Acoustic Wave Test ({time_integrator}):") + print(f" Wavenumber k = {k}") + print(f" Ion charge Z = {Z}") + print(f" Ion mass m_i = {m_i}") + print(f" Sound speed c_s = {np.sqrt(Z * T_e / m_i):.6f}") + print(f" Expected ω = {expected_omega:.6f}") + print(f" Measured ω = {measured_omega:.6f}") + print(f" Relative error = {abs(measured_omega - expected_omega) / expected_omega * 100:.2f}%") + + # Assert frequency matches within 15% + # (Some error expected from numerical dispersion and finite grid effects) + np.testing.assert_allclose(measured_omega, expected_omega, rtol=0.15) + + +if __name__ == "__main__": + test_ion_acoustic_dispersion("sixth") diff --git a/tests/test_vlasov1d/test_multispecies_config.py b/tests/test_vlasov1d/test_multispecies_config.py new file mode 100644 index 0000000..d131fb1 --- /dev/null +++ b/tests/test_vlasov1d/test_multispecies_config.py @@ -0,0 +1,102 @@ +"""Test multi-species configuration parsing and validation.""" + +from pathlib import Path + +import pytest +import yaml +from pydantic import ValidationError + +from adept._vlasov1d.datamodel import ConfigModel, SpeciesConfig + + +def test_multispecies_config_parsing(): + """Test that multi-species config files parse correctly.""" + config_path = Path(__file__).parent / "configs" / "multispecies_ion_acoustic.yaml" + + with open(config_path) as f: + config_dict = yaml.safe_load(f) + + # Validate with Pydantic model + config = ConfigModel(**config_dict) + + # Check that species are defined + assert config.terms.species is not None + assert len(config.terms.species) == 2 + + # Check electron species + electron = config.terms.species[0] + assert electron.name == "electron" + assert electron.charge == -1.0 + assert electron.mass == 1.0 + assert electron.vmax == 6.4 + assert electron.nv == 512 + assert "species-electron-background" in electron.density_components + + # Check ion species + ion = config.terms.species[1] + assert ion.name == "ion" + assert ion.charge == 10.0 + assert ion.mass == 18360.0 + assert ion.vmax == 0.15 + assert ion.nv == 256 + assert "species-ion-background" in ion.density_components + + +def test_backward_compatible_config(): + """Test that old configs without terms.species still work.""" + config_path = Path(__file__).parent / "configs" / "resonance.yaml" + + with open(config_path) as f: + config_dict = yaml.safe_load(f) + + # Validate with Pydantic model + config = ConfigModel(**config_dict) + + # Check that species is None (not provided in old configs) + assert config.terms.species is None + + +def test_species_config_validation(): + """Test SpeciesConfig validation.""" + # Valid species config + species = SpeciesConfig( + name="electron", + charge=-1.0, + mass=1.0, + vmax=6.4, + nv=512, + density_components=["species-background"], + ) + assert species.name == "electron" + + # Test that all required fields must be present + with pytest.raises(ValidationError): + SpeciesConfig( + name="electron", + charge=-1.0, + # Missing mass, vmax, nv, density_components + ) + + +def test_multiple_density_components(): + """Test species with multiple density components.""" + species = SpeciesConfig( + name="electron", + charge=-1.0, + mass=1.0, + vmax=6.4, + nv=512, + density_components=["species-background", "species-beam"], + ) + assert len(species.density_components) == 2 + assert "species-background" in species.density_components + assert "species-beam" in species.density_components + + +if __name__ == "__main__": + # Run tests + test_multispecies_config_parsing() + test_backward_compatible_config() + test_species_config_validation() + test_multiple_density_components() + print("All tests passed!") diff --git a/tests/test_vlasov1d/test_multispecies_init.py b/tests/test_vlasov1d/test_multispecies_init.py new file mode 100644 index 0000000..274a719 --- /dev/null +++ b/tests/test_vlasov1d/test_multispecies_init.py @@ -0,0 +1,116 @@ +"""Test multi-species initialization and state structure.""" + +from pathlib import Path + +import numpy as np +import yaml + +from adept._vlasov1d.modules import BaseVlasov1D + + +def test_multispecies_state_initialization(): + """Test that multi-species state initialization creates correct structures.""" + config_path = Path(__file__).parent / "configs" / "multispecies_ion_acoustic.yaml" + + with open(config_path) as f: + config_dict = yaml.safe_load(f) + + # Create module + module = BaseVlasov1D(config_dict) + + # Initialize (following the pattern from the module) + module.write_units() + module.get_derived_quantities() + module.get_solver_quantities() + module.init_state_and_args() + + # Check state structure + assert "electron" in module.state + assert "ion" in module.state + assert "e" in module.state + assert "de" in module.state + + # Check electron distribution shape (nx=32, nv=512 from config) + assert module.state["electron"].shape == (32, 512) + + # Check ion distribution shape (nx=32, nv=256 from config) + assert module.state["ion"].shape == (32, 256) + + # Check species_grids + assert "species_grids" in module.cfg["grid"] + assert "electron" in module.cfg["grid"]["species_grids"] + assert "ion" in module.cfg["grid"]["species_grids"] + + # Check electron velocity grid + electron_grid = module.cfg["grid"]["species_grids"]["electron"] + assert electron_grid["nv"] == 512 + assert electron_grid["vmax"] == 6.4 + assert len(electron_grid["v"]) == 512 + + # Check ion velocity grid + ion_grid = module.cfg["grid"]["species_grids"]["ion"] + assert ion_grid["nv"] == 256 + assert ion_grid["vmax"] == 0.15 + assert len(ion_grid["v"]) == 256 + + # Check species_params + assert "species_params" in module.cfg["grid"] + assert "electron" in module.cfg["grid"]["species_params"] + assert "ion" in module.cfg["grid"]["species_params"] + + # Check electron params + electron_params = module.cfg["grid"]["species_params"]["electron"] + assert electron_params["charge"] == -1.0 + assert electron_params["mass"] == 1.0 + assert electron_params["charge_to_mass"] == -1.0 + + # Check ion params + ion_params = module.cfg["grid"]["species_params"]["ion"] + assert ion_params["charge"] == 10.0 + assert ion_params["mass"] == 18360.0 + assert np.isclose(ion_params["charge_to_mass"], 10.0 / 18360.0) + + # Check quasineutrality handling for multi-species + # For multi-species, ion_charge should be zeros + assert np.allclose(module.cfg["grid"]["ion_charge"], 0.0) + + +def test_backward_compatible_state_initialization(): + """Test that single-species configs still work correctly.""" + config_path = Path(__file__).parent / "configs" / "resonance.yaml" + + with open(config_path) as f: + config_dict = yaml.safe_load(f) + + # Create module + module = BaseVlasov1D(config_dict) + + # Initialize + module.write_units() + module.get_derived_quantities() + module.get_solver_quantities() + module.init_state_and_args() + + # Check state structure (should have electron only) + assert "electron" in module.state + assert "e" in module.state + assert "de" in module.state + + # For backward compatibility, electron should have shape (nx, nv) with grid-level nv + nx = module.cfg["grid"]["nx"] + nv = module.cfg["grid"]["nv"] + assert module.state["electron"].shape == (nx, nv) + + # Check that single-species still has grid-level velocity grid + assert "v" in module.cfg["grid"] + assert len(module.cfg["grid"]["v"]) == nv + + # Check quasineutrality handling for single-species + # For single-species, ion_charge should equal n_prof_total + assert np.allclose(module.cfg["grid"]["ion_charge"], module.cfg["grid"]["n_prof_total"]) + + +if __name__ == "__main__": + test_multispecies_state_initialization() + test_backward_compatible_state_initialization() + print("All initialization tests passed!") diff --git a/tests/test_vlasov1d/test_multispecies_pushers.py b/tests/test_vlasov1d/test_multispecies_pushers.py new file mode 100644 index 0000000..d970a61 --- /dev/null +++ b/tests/test_vlasov1d/test_multispecies_pushers.py @@ -0,0 +1,146 @@ +"""Test multi-species pushers using method of manufactured solutions. + +These tests verify convergence to exact solutions obtained by characteristic tracing, +since the pushers implement linear advection operators in space. +""" + +from pathlib import Path + +import jax.numpy as jnp +import numpy as np +import yaml + +from adept._vlasov1d.solvers.pushers.vlasov import SpaceExponential, VelocityCubicSpline, VelocityExponential + + +def test_space_exponential_convergence(): + """Test SpaceExponential achieves machine precision using method of manufactured solutions. + + The space pusher solves: ∂f/∂t + v ∂f/∂x = 0 + Exact characteristic solution: f(x, v, t+dt) = f(x - v*dt, v, t) + + The spectral method achieves machine precision even at low resolution. + """ + Lx = 2 * np.pi + nv = 1 # Only testing spatial dimension + v = jnp.array([0.5]) # Single velocity value + + # Test at two resolutions to verify machine precision + nx_values = [16, 32] + errors = [] + + for nx in nx_values: + dx = Lx / nx + x = jnp.linspace(0, Lx - dx, nx) + + species_grids = { + "electron": {"v": v, "nv": nv}, + } + + pusher = SpaceExponential(x, species_grids) + + # Manufactured solution: sinusoidal in space with single velocity + k = 2 # wavenumber + f_init = jnp.sin(k * x)[:, None] # Shape (nx, 1) + + dt = 0.01 + + # Apply pusher + f_dict = {"electron": f_init} + result = pusher(f_dict, dt) + f_numerical = result["electron"] + + # Exact solution: shift in x by -v*dt + # For sinusoidal initial condition: sin(k*x) -> sin(k*(x - v*dt)) = sin(k*x - k*v*dt) + f_exact = jnp.sin(k * x - k * v[0] * dt)[:, None] # Shape (nx, 1) + + # Compute error (L2 norm) + error = jnp.sqrt(jnp.mean((f_numerical - f_exact) ** 2)) + errors.append(float(error)) + + # Spectral method should achieve machine precision + assert all(err < 1e-12 for err in errors), ( + f"SpaceExponential should achieve machine precision. Errors: {[f'{e:.2e}' for e in errors]}" + ) + + +def test_velocity_exponential_convergence(): + """Test VelocityExponential achieves machine precision for multiple species. + + The velocity pusher solves: ∂f/∂t + (q/m)E ∂f/∂v = 0 + Exact characteristic solution: f(x, v, t+dt) = f(x, v - (q/m)*E*dt, t) + + Tests that each species is pushed by its respective q/m factor. + The spectral method achieves machine precision even at low resolution. + """ + nx = 1 # Only testing velocity dimension + vmax = 2 * np.pi # Periodic domain in velocity + qm_electron = -1.0 # electron charge-to-mass ratio + qm_ion = 1.0 / 1836.0 # ion charge-to-mass ratio (proton) + e = jnp.array([0.5]) # Constant electric field + dt = 0.01 + + # Test at two resolutions to verify machine precision + nv_values = [16, 32] + errors_electron = [] + errors_ion = [] + + for nv in nv_values: + dv = 2.0 * vmax / nv + v = jnp.linspace(-vmax + dv / 2, vmax - dv / 2, nv) + kvr = jnp.fft.rfftfreq(nv, d=dv) * 2.0 * np.pi + + species_grids = { + "electron": {"kvr": kvr, "nv": nv, "v": v}, + "ion": {"kvr": kvr, "nv": nv, "v": v}, + } + + species_params = { + "electron": {"charge": -1.0, "mass": 1.0, "charge_to_mass": qm_electron}, + "ion": {"charge": 1.0, "mass": 1836.0, "charge_to_mass": qm_ion}, + } + + pusher = VelocityExponential(species_grids, species_params) + + # Manufactured solution: sinusoidal in velocity + # k must be chosen so function is periodic over [-vmax, vmax] + # Period = 2*vmax, so k = 2*pi*n / (2*vmax) = pi*n/vmax + n_mode = 1 # Choose mode number + k = np.pi * n_mode / vmax # This ensures periodicity + f_init = jnp.sin(k * v)[None, :] # Shape (1, nv) + + # Apply pusher to both species with same initial condition + f_dict = {"electron": f_init, "ion": f_init} + result = pusher(f_dict, e, dt) + f_electron_numerical = result["electron"] + f_ion_numerical = result["ion"] + + # Exact solution using characteristic solution: f(v, t+dt) = f(v - (q/m)*E*dt, t) + # Each species should be shifted by its own q/m factor + v_shift_electron = qm_electron * e[0] * dt + v_shift_ion = qm_ion * e[0] * dt + + f_electron_exact = jnp.sin(k * (v - v_shift_electron))[None, :] + f_ion_exact = jnp.sin(k * (v - v_shift_ion))[None, :] + + # Compute errors (L2 norm) + error_electron = jnp.sqrt(jnp.mean((f_electron_numerical - f_electron_exact) ** 2)) + error_ion = jnp.sqrt(jnp.mean((f_ion_numerical - f_ion_exact) ** 2)) + + errors_electron.append(float(error_electron)) + errors_ion.append(float(error_ion)) + + # Spectral method should achieve machine precision for both species + assert all(err < 1e-12 for err in errors_electron), ( + f"VelocityExponential should achieve machine precision for electrons. " + f"Errors: {[f'{e:.2e}' for e in errors_electron]}" + ) + assert all(err < 1e-12 for err in errors_ion), ( + f"VelocityExponential should achieve machine precision for ions. Errors: {[f'{e:.2e}' for e in errors_ion]}" + ) + + +if __name__ == "__main__": + test_space_exponential_convergence() + test_velocity_exponential_convergence() + print("All pusher tests passed!")