config/config_jepa.yml (35 changes: 32 additions & 3 deletions)
@@ -130,10 +130,39 @@ data_loading :

# config for training
training_config:

  # training_mode: "masking", "student_teacher", "latent_loss"
  training_mode: ["student_teacher"]

  # Collapse monitoring for SSL training (JEPA/DINO/iBOT)
  # Detects representation collapse via various metrics
  # For forecasting, supports sequences of latents with per-step and aggregate metrics
  collapse_monitoring:
    enabled: true
    compute_frequency: 100  # batches between metric computations
    log_frequency: 100      # batches between metric logging
    metrics:
      effective_rank:
        enabled: true
        tensor_source: "both"  # "student", "teacher", or "both"
        sample_size: 2048      # max samples for SVD (0 = no sampling)
        # For forecasting sequences: "all" (per-step + aggregates),
        # "aggregate_only" (mean/min/max/degradation), "per_step_only"
        forecast_aggregation: "all"
      singular_values:
        enabled: true
        tensor_source: "both"
        sample_size: 2048
        forecast_aggregation: "all"
      dimension_variance:
        enabled: true
        tensor_source: "both"  # cheap to compute, good early indicator
        forecast_aggregation: "all"
      prototype_entropy:
        enabled: true  # only applies to DINO
      ema_beta:
        enabled: true

  num_mini_epochs: 32
  samples_per_mini_epoch: 4096
  shuffle: True
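As context for the `effective_rank` metric configured above, here is a minimal sketch of how an effective rank can be computed from a batch of latent tokens via the entropy of the normalized singular-value spectrum. The function name, the subsampling strategy for `sample_size`, and the exact numerics are illustrative assumptions, not the code added in this PR.

```python
import torch


def effective_rank(latents: torch.Tensor, sample_size: int = 2048) -> float:
    """Effective rank of a (..., dim) latent tensor, computed as the exponential
    of the entropy of its normalized singular values. Values near 1 suggest
    representation collapse; values near min(n_samples, dim) suggest the full
    embedding space is being used."""
    x = latents.reshape(-1, latents.shape[-1]).float()
    if sample_size and x.shape[0] > sample_size:
        # subsample rows to keep the SVD cheap (mirrors the sample_size option)
        x = x[torch.randperm(x.shape[0])[:sample_size]]
    x = x - x.mean(dim=0, keepdim=True)   # center before the SVD
    s = torch.linalg.svdvals(x)           # singular values
    p = s / s.sum().clamp_min(1e-12)      # normalize to a probability distribution
    entropy = -(p * p.clamp_min(1e-12).log()).sum()
    return entropy.exp().item()
```

With `tensor_source: "both"`, a metric like this would presumably be evaluated on both the student and the teacher latents so that collapse on either side of the EMA pair is visible.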
@@ -182,8 +211,8 @@ training_config:
        },
      },
  target_and_aux_calc: { "EMATeacher" :
-     { ema_ramp_up_ratio : 0.09,
-       ema_halflife_in_thousands: 1e-3,
+     { ema_ramp_up_ratio : null,
+       ema_halflife_in_thousands: 1e-1,
        model_param_overrides : {
          training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}}
        },
config/default_config.yml (31 changes: 30 additions & 1 deletion)
@@ -134,10 +134,39 @@ data_loading :

# config for training
training_config:

  # training_mode: "masking", "student_teacher", "latent_loss"
  training_mode: ["masking"]

  # Collapse monitoring for detecting representation collapse
Collaborator:
The default config is on forecasting so we shouldn't need this in there.

Contributor Author:
I think it would be interesting to measure anyway, but if you feel strongly, I can remove it.

  # Works with SSL training (JEPA/DINO) and forecasting modes
  # For forecasting, monitors latent_state.patch_tokens at each forecast step
  collapse_monitoring:
    enabled: true
    compute_frequency: 100  # batches between metric computations
    log_frequency: 100      # batches between metric logging
    metrics:
      effective_rank:
        enabled: true
        tensor_source: "student"  # "student", "teacher", or "both"
        sample_size: 2048         # max samples for SVD (0 = no sampling)
        # For forecasting sequences: "all" (per-step + aggregates),
        # "aggregate_only" (mean/min/max/degradation), "per_step_only"
        forecast_aggregation: "all"
      singular_values:
        enabled: true
        tensor_source: "student"
        sample_size: 2048
        forecast_aggregation: "all"
      dimension_variance:
        enabled: true
        tensor_source: "student"
        forecast_aggregation: "all"
      prototype_entropy:
        enabled: false  # only relevant for DINO
      ema_beta:
        enabled: false  # only relevant for SSL with EMA teacher

  num_mini_epochs: 32
  samples_per_mini_epoch: 4096
  shuffle: True
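The `forecast_aggregation` option above distinguishes per-step values from aggregate summaries over a forecast rollout. Below is a minimal sketch of one plausible aggregation; the `degradation` definition (first step minus last step) and the function name are assumptions for illustration, not necessarily what the PR implements.

```python
import torch


def aggregate_forecast_metric(per_step_values: list[float]) -> dict[str, float]:
    """Summarize a per-forecast-step collapse metric (e.g. effective rank per step)
    into the aggregate statistics named in the config comment."""
    vals = torch.tensor(per_step_values)
    return {
        "mean": vals.mean().item(),
        "min": vals.min().item(),
        "max": vals.max().item(),
        # one plausible definition: how much the metric drops from the first
        # to the last forecast step
        "degradation": (vals[0] - vals[-1]).item(),
    }


# Example: effective rank per forecast step [120.0, 95.0, 60.0]
# -> mean ~91.67, min 60.0, max 120.0, degradation 60.0
```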
src/weathergen/model/ema.py (26 changes: 22 additions & 4 deletions)
@@ -30,6 +30,7 @@ def __init__(
        self.rampup_ratio = rampup_ratio
        self.ema_model = empty_model
        self.is_model_sharded = is_model_sharded
        self.batch_size = 1
        # Build a name → param map once
        self.src_params = dict(self.original_model.named_parameters())

@@ -55,16 +56,33 @@ def requires_grad_(self, flag: bool):
        for p in self.ema_model.parameters():
            p.requires_grad = flag

    def get_current_beta(self, cur_step: int) -> float:
        """
        Get current EMA beta value for monitoring.

        The beta value determines how much the teacher model is updated towards
        the student model at each step. Higher beta means slower teacher updates.

        Args:
            cur_step: Current training step (typically istep * batch_size).

        Returns:
            Current EMA beta value.
        """
        halflife_steps = self.halflife_steps
        if self.rampup_ratio is not None:
            halflife_steps = min(halflife_steps, cur_step / self.rampup_ratio)
        beta = 0.5 ** (self.batch_size / max(halflife_steps, 1e-6))
        return beta
Collaborator:
Shouldn't we continuously update beta and store it as a class member variable, and just return it in the function?

Contributor Author:
This is what is happening.


    @torch.no_grad()
    def update(self, cur_step, batch_size):
        # ensure model remains sharded
        if self.is_model_sharded:
            self.ema_model.reshard()
        # determine correct interpolation params
-       halflife_steps = self.halflife_steps
-       if self.rampup_ratio is not None:
-           halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio)
-       beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6))
+       self.batch_size = batch_size
+       beta = self.get_current_beta(cur_step)

        for name, p_ema in self.ema_model.named_parameters():
            p_src = self.src_params.get(name, None)
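To make the halflife and ramp-up behaviour of `get_current_beta` concrete, here is a standalone re-implementation of just the formula (not the class method itself), under two assumptions: `halflife_steps` is already expressed in optimizer steps, and the teacher is subsequently blended as `p_ema = beta * p_ema + (1 - beta) * p_src`, the usual EMA convention implied by the docstring's "higher beta means slower teacher updates".

```python
def beta_for_step(cur_step: int, halflife_steps: float,
                  batch_size: int = 1, rampup_ratio: float | None = None) -> float:
    """Standalone copy of the beta formula used in get_current_beta."""
    if rampup_ratio is not None:
        # during ramp-up the effective halflife is capped and grows with cur_step
        halflife_steps = min(halflife_steps, cur_step / rampup_ratio)
    return 0.5 ** (batch_size / max(halflife_steps, 1e-6))


# Early in training the capped halflife keeps beta small, so the teacher
# tracks the student closely; later beta approaches its asymptotic value.
print(beta_for_step(cur_step=10, halflife_steps=100, rampup_ratio=10))      # 0.5
print(beta_for_step(cur_step=10_000, halflife_steps=100, rampup_ratio=10))  # ~0.993
```

With beta ≈ 0.993 the teacher closes roughly half of its remaining gap to the student every ~100 updates, which is what "halflife" means here.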