diff --git a/SSM/__init__.py b/SSM/__init__.py new file mode 100644 index 0000000..4699a15 --- /dev/null +++ b/SSM/__init__.py @@ -0,0 +1,5 @@ +# State Space Models for Time Series Generation +from .mamba_tsg import MambaTimeSeriesGenerator +from .s4_layer import S4Layer + +__all__ = ['MambaTimeSeriesGenerator', 'S4Layer'] diff --git a/SSM/mamba_tsg.py b/SSM/mamba_tsg.py new file mode 100644 index 0000000..c038314 --- /dev/null +++ b/SSM/mamba_tsg.py @@ -0,0 +1,464 @@ +""" +Mamba-based Time Series Generator + +Based on "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao 2023) +https://arxiv.org/abs/2312.00752 + +Key innovations: +1. Selective scan: B, C, and Δ (delta/step size) are input-dependent +2. Hardware-efficient parallel scan +3. Better at capturing discrete patterns and long-range dependencies + +Discretization (Zero-Order Hold): + A_bar = exp(Δ * A) + B_bar = (exp(Δ * A) - I) * A^{-1} * B ≈ Δ * B (simplified) +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional + + +class SelectiveSSM(nn.Module): + """ + Selective State Space Model (core of Mamba) + + Key difference from S4: B, C, and Δ are computed from the input, + making the state transitions content-aware (selective). + + Parameters: + d_model: Input/output dimension + d_state: SSM state dimension (N in paper, typically 16) + d_conv: Local convolution width (typically 4) + expand: Expansion factor for inner dimension (typically 2) + """ + + def __init__( + self, + d_model: int, + d_state: int = 16, + d_conv: int = 4, + expand: int = 2, + dt_rank: str = "auto", + dt_min: float = 0.001, + dt_max: float = 0.1, + dt_init: str = "random", + dt_scale: float = 1.0, + ): + super().__init__() + + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(expand * d_model) + + # dt_rank: dimension for delta projection + self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank + + # Input projection: x -> (z, x) where z is the gate + self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False) + + # 1D convolution for local context (causal) + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + kernel_size=d_conv, + padding=d_conv - 1, + groups=self.d_inner, # depthwise + bias=True, + ) + + # Selective parameters projection: x -> (Δ, B, C) + # This is the KEY difference from S4 - these are input-dependent + self.x_proj = nn.Linear( + self.d_inner, self.dt_rank + d_state * 2, bias=False + ) + + # Delta (Δ) projection from dt_rank to d_inner + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) + + # Initialize dt_proj bias for proper Δ range + dt_init_std = self.dt_rank ** -0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + + # Initialize dt bias to map softplus output to [dt_min, dt_max] + dt = torch.exp( + torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ) + # Inverse of softplus to get bias + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + + # A parameter: diagonal, initialized to -exp(uniform) + # Negative ensures stability (eigenvalues in left half-plane) + A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1) + self.A_log = nn.Parameter(torch.log(A)) # (d_inner, d_state) + 
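+        # Note: A is stored as log-magnitudes and negated wherever it is used
+        # (A = -exp(A_log)), so the effective initialization is A_n = -n for
+        # n = 1..d_state (S4D-real style) and all eigenvalues start in the left half-plane.
+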
+ # D: skip connection (identity-like initialization) + self.D = nn.Parameter(torch.ones(self.d_inner)) + + # Output projection + self.out_proj = nn.Linear(self.d_inner, d_model, bias=False) + + def forward(self, x): + """ + Forward pass with selective scan. + + Args: + x: (batch, seq_len, d_model) + Returns: + y: (batch, seq_len, d_model) + """ + batch, seq_len, _ = x.shape + + # 1. Input projection and split into x and gate z + xz = self.in_proj(x) # (B, L, 2*d_inner) + x, z = xz.chunk(2, dim=-1) # Each: (B, L, d_inner) + + # 2. Causal convolution for local context + x = x.transpose(1, 2) # (B, d_inner, L) + x = self.conv1d(x)[:, :, :seq_len] # Causal: remove future padding + x = x.transpose(1, 2) # (B, L, d_inner) + x = F.silu(x) + + # 3. Compute selective parameters (B, C, Δ) from x + y = self.selective_scan(x) + + # 4. Gate with z and project output + y = y * F.silu(z) + y = self.out_proj(y) + + return y + + def selective_scan(self, x): + """ + Selective scan with input-dependent B, C, Δ. + + This is the sequential implementation. The parallel associative scan + version is more efficient but harder to implement correctly. + """ + batch, seq_len, d_inner = x.shape + + # Get A (negative for stability) + A = -torch.exp(self.A_log) # (d_inner, d_state) + + # Project x to get Δ, B, C (selective/input-dependent) + x_dbl = self.x_proj(x) # (B, L, dt_rank + 2*d_state) + + # Split into components + delta, B, C = x_dbl.split([self.dt_rank, self.d_state, self.d_state], dim=-1) + + # Project and apply softplus to get positive Δ + delta = self.dt_proj(delta) # (B, L, d_inner) + delta = F.softplus(delta) # Ensure positive + + # Selective scan (sequential for clarity) + # State: (B, d_inner, d_state) + h = torch.zeros(batch, d_inner, self.d_state, device=x.device, dtype=x.dtype) + outputs = [] + + for t in range(seq_len): + x_t = x[:, t, :] # (B, d_inner) + delta_t = delta[:, t, :] # (B, d_inner) + B_t = B[:, t, :] # (B, d_state) + C_t = C[:, t, :] # (B, d_state) + + # Discretization (simplified ZOH): + # A_bar = exp(Δ * A) ≈ 1 + Δ * A (first-order approximation) + # B_bar = Δ * B + + # For each channel in d_inner: + # h = A_bar * h + B_bar * x + # y = C * h + + # A_bar: (B, d_inner, d_state) + delta_A = delta_t.unsqueeze(-1) * A # (B, d_inner, d_state) + A_bar = torch.exp(delta_A) # Proper discretization + + # B_bar: (B, d_inner, d_state) + # We broadcast x_t and B_t + delta_B = delta_t.unsqueeze(-1) * B_t.unsqueeze(1) # (B, d_inner, d_state) + + # State update: h = A_bar * h + delta_B * x + h = A_bar * h + delta_B * x_t.unsqueeze(-1) + + # Output: y = h @ C (sum over state dimension) + y_t = (h * C_t.unsqueeze(1)).sum(dim=-1) # (B, d_inner) + + # Skip connection + y_t = y_t + self.D * x_t + + outputs.append(y_t) + + y = torch.stack(outputs, dim=1) # (B, L, d_inner) + return y + + def step(self, x_t, state): + """ + Single step for autoregressive generation. 
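+
+        Maintains a constant-size state (the SSM hidden state h plus a d_conv-wide
+        causal convolution buffer), so each generated timestep costs O(1):
+            h_t = exp(Δ_t * A) * h_{t-1} + (Δ_t * B_t) * x_t
+            y_t = C_t · h_t + D * x_t
+        followed by the SiLU(z) gate and the output projection.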
+ + Args: + x_t: (B, d_model) input at current timestep + state: (h, conv_state) tuple + Returns: + y_t: (B, d_model) output + new_state: updated state + """ + h, conv_state = state + + # Input projection + xz = self.in_proj(x_t) # (B, 2*d_inner) + x, z = xz.chunk(2, dim=-1) + + # Update conv state and apply convolution + conv_state = torch.roll(conv_state, -1, dims=1) + conv_state[:, -1, :] = x + x = (conv_state * self.conv1d.weight.view(self.d_inner, self.d_conv).t()).sum(dim=1) # (B, d_inner); transpose to (d_conv, d_inner) so it broadcasts against conv_state + x = x + self.conv1d.bias + x = F.silu(x) + + # Selective parameters + x_dbl = self.x_proj(x) + delta, B, C = x_dbl.split([self.dt_rank, self.d_state, self.d_state], dim=-1) + delta = F.softplus(self.dt_proj(delta)) + + # State update + A = -torch.exp(self.A_log) + delta_A = delta.unsqueeze(-1) * A + A_bar = torch.exp(delta_A) + delta_B = delta.unsqueeze(-1) * B.unsqueeze(1) + h = A_bar * h + delta_B * x.unsqueeze(-1) + + # Output + y = (h * C.unsqueeze(1)).sum(dim=-1) + self.D * x + y = y * F.silu(z) + y = self.out_proj(y) + + return y, (h, conv_state) + + def init_state(self, batch_size, device): + """Initialize state for autoregressive generation.""" + h = torch.zeros(batch_size, self.d_inner, self.d_state, device=device) + conv_state = torch.zeros(batch_size, self.d_conv, self.d_inner, device=device) + return (h, conv_state) + + +class MambaBlock(nn.Module): + """Mamba block with residual connection and normalization.""" + + def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dropout=0.0): + super().__init__() + self.norm = nn.LayerNorm(d_model) + self.mamba = SelectiveSSM(d_model, d_state, d_conv, expand) + self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + + def forward(self, x): + return x + self.dropout(self.mamba(self.norm(x))) + + def step(self, x_t, state): + """Single step for generation.""" + y_t, new_state = self.mamba.step(self.norm(x_t), state) + return x_t + self.dropout(y_t), new_state + + +class MambaVAE(nn.Module): + """ + Mamba-based VAE for time series generation.
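+
+    Example (illustrative sketch; `x` is any (batch, seq_len, d_input) float tensor):
+        model = MambaVAE(seq_len=168, d_input=1)
+        x_rec, mu, logvar = model(x)                       # reconstruction pass
+        loss, rec, kl = model.loss(x, x_rec, mu, logvar)   # MSE + beta-weighted KL
+        samples = model.sample(16)                         # (16, 168, 1) draws from the prior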
+ + Architecture: + - Mamba encoder: sequence -> latent distribution + - VAE sampling: z ~ N(mu, sigma) + - Mamba decoder: latent -> sequence + + Why this should work better for time series patterns: + - Selective mechanism can learn to focus on trend/periodicity + - HiPPO-like state can capture long-range dependencies + - Input-dependent Δ can adapt to different time scales + """ + + def __init__( + self, + seq_len: int = 168, + d_input: int = 1, + d_model: int = 64, + d_state: int = 16, + d_latent: int = 32, + n_layers: int = 4, + dropout: float = 0.1, + ): + super().__init__() + + self.seq_len = seq_len + self.d_input = d_input + self.d_model = d_model + self.d_latent = d_latent + + # Input embedding + self.input_embed = nn.Linear(d_input, d_model) + + # Encoder: Mamba blocks + self.enc_layers = nn.ModuleList([ + MambaBlock(d_model, d_state, dropout=dropout) + for _ in range(n_layers) + ]) + self.enc_norm = nn.LayerNorm(d_model) + + # Latent projections + self.to_mu = nn.Linear(d_model, d_latent) + self.to_logvar = nn.Linear(d_model, d_latent) + + # Decoder + self.latent_to_seq = nn.Linear(d_latent, d_model * seq_len) + self.dec_layers = nn.ModuleList([ + MambaBlock(d_model, d_state, dropout=dropout) + for _ in range(n_layers) + ]) + self.dec_norm = nn.LayerNorm(d_model) + self.output_proj = nn.Linear(d_model, d_input) + + def encode(self, x): + """Encode sequence to latent distribution parameters.""" + # x: (B, L, d_input) + h = self.input_embed(x) # (B, L, d_model) + + for layer in self.enc_layers: + h = layer(h) + h = self.enc_norm(h) + + # Pool over sequence (use mean) + h = h.mean(dim=1) # (B, d_model) + + mu = self.to_mu(h) + logvar = self.to_logvar(h) + return mu, logvar + + def reparameterize(self, mu, logvar): + """Sample z using reparameterization trick.""" + if self.training: + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + return mu + + def decode(self, z): + """Decode latent to sequence.""" + batch = z.shape[0] + + # Project latent to sequence + h = self.latent_to_seq(z) # (B, d_model * seq_len) + h = h.view(batch, self.seq_len, self.d_model) + + for layer in self.dec_layers: + h = layer(h) + h = self.dec_norm(h) + + return self.output_proj(h) # (B, L, d_input) + + def forward(self, x): + """Forward pass for training.""" + mu, logvar = self.encode(x) + z = self.reparameterize(mu, logvar) + x_rec = self.decode(z) + return x_rec, mu, logvar + + @torch.no_grad() + def sample(self, n_samples, device='cpu'): + """Generate new sequences by sampling from prior.""" + self.eval() + z = torch.randn(n_samples, self.d_latent, device=device) + return self.decode(z) + + def loss(self, x, x_rec, mu, logvar, kl_weight=0.1): + """ + VAE loss = Reconstruction + KL divergence. + + Args: + x: (B, L, d_input) original + x_rec: (B, L, d_input) reconstructed + mu, logvar: latent distribution parameters + kl_weight: weight for KL term (beta-VAE style) + """ + # Reconstruction loss (MSE) + rec_loss = F.mse_loss(x_rec, x, reduction='mean') + + # KL divergence against N(0, 1) + kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) + + total_loss = rec_loss + kl_weight * kl_loss + return total_loss, rec_loss, kl_loss + + +class MambaTimeSeriesGenerator(nn.Module): + """ + Autoregressive Mamba generator for time series. + + This model generates sequences step-by-step, which can better + maintain trends and patterns compared to diffusion approaches. 
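+
+    forward() is a teacher-forced pass that emits a prediction at every position
+    (the natural training target is the input shifted one step ahead), while
+    generate() rolls the model out autoregressively from a learned start token,
+    optionally adding Gaussian noise (std = 0.1 * temperature) for sample diversity.
+
+    Example (illustrative sketch; `x` is a (batch, seq_len, 1) float tensor):
+        gen = MambaTimeSeriesGenerator(seq_len=168, d_input=1)
+        pred = gen(x[:, :-1, :])            # teacher-forced next-step predictions
+        fake = gen.generate(batch_size=8)   # (8, 168, 1) generated sequences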
+ """ + + def __init__( + self, + seq_len: int = 168, + d_input: int = 1, + d_model: int = 64, + d_state: int = 16, + n_layers: int = 4, + dropout: float = 0.1, + ): + super().__init__() + + self.seq_len = seq_len + self.d_input = d_input + self.d_model = d_model + + # Input embedding + self.input_embed = nn.Linear(d_input, d_model) + + # Mamba layers + self.layers = nn.ModuleList([ + MambaBlock(d_model, d_state, dropout=dropout) + for _ in range(n_layers) + ]) + + # Output + self.out_norm = nn.LayerNorm(d_model) + self.out_proj = nn.Linear(d_model, d_input) + + # Learnable start token + self.start_token = nn.Parameter(torch.zeros(1, 1, d_input)) + + def forward(self, x): + """Forward pass (teacher forcing).""" + h = self.input_embed(x) + for layer in self.layers: + h = layer(h) + h = self.out_norm(h) + return self.out_proj(h) + + @torch.no_grad() + def generate(self, batch_size, device='cpu', temperature=1.0): + """Generate sequences autoregressively.""" + self.eval() + + # Start with learned token + x = self.start_token.expand(batch_size, -1, -1).to(device) # (B, 1, d_input) + + for _ in range(self.seq_len - 1): + # Predict next + pred = self.forward(x) # (B, t, d_input) + next_val = pred[:, -1:, :] # (B, 1, d_input) + + # Add noise for diversity + if temperature > 0: + next_val = next_val + torch.randn_like(next_val) * temperature * 0.1 + + x = torch.cat([x, next_val], dim=1) + + return x diff --git a/SSM/s4_layer.py b/SSM/s4_layer.py new file mode 100644 index 0000000..f44d09b --- /dev/null +++ b/SSM/s4_layer.py @@ -0,0 +1,192 @@ +""" +S4 (Structured State Space) Layer Implementation + +Based on "Efficiently Modeling Long Sequences with Structured State Spaces" (Gu et al. 2022) +This implements the core S4 layer with HiPPO initialization for capturing long-range dependencies. +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def make_hippo(N): + """ + Create HiPPO-LegS matrix for optimal polynomial projection. + This initialization helps capture trends and long-range dependencies. + + Args: + N: State dimension + Returns: + A: (N, N) HiPPO matrix + """ + P = torch.sqrt(1 + 2 * torch.arange(N, dtype=torch.float32)) + A = P.unsqueeze(1) * P.unsqueeze(0) + A = torch.tril(A) - torch.diag(torch.arange(N, dtype=torch.float32) + 1) + return A + + +def discretize_zoh(A, B, dt): + """ + Discretize continuous-time SSM using Zero-Order Hold. 
+ + x_{k+1} = A_d @ x_k + B_d @ u_k + + Args: + A: (N, N) continuous state matrix + B: (N, D) continuous input matrix + dt: discretization step size + Returns: + A_d, B_d: discretized matrices + """ + N = A.shape[0] + I = torch.eye(N, device=A.device, dtype=A.dtype) + + # Simple Euler discretization (stable for small dt) + A_d = I + dt * A + B_d = dt * B + + return A_d, B_d + + +class S4Layer(nn.Module): + """ + Structured State Space Layer (S4) + + Processes sequences using state space model: + x'(t) = Ax(t) + Bu(t) + y(t) = Cx(t) + Du(t) + + Key features: + - HiPPO initialization for trend/dependency capture + - Linear time complexity O(L) via parallel scan + - Constant memory state for generation + """ + + def __init__( + self, + d_model: int, + d_state: int = 64, + dt_min: float = 0.001, + dt_max: float = 0.1, + dropout: float = 0.0, + ): + """ + Args: + d_model: Input/output dimension + d_state: State dimension (N) + dt_min: Minimum discretization step + dt_max: Maximum discretization step + dropout: Dropout rate + """ + super().__init__() + + self.d_model = d_model + self.d_state = d_state + + # Initialize A with HiPPO + A = make_hippo(d_state) + self.register_buffer('A', A) + + # Learnable parameters + # B: input -> state + self.B = nn.Parameter(torch.randn(d_state, d_model) * 0.01) + # C: state -> output + self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.01) + # D: skip connection + self.D = nn.Parameter(torch.ones(d_model)) + + # Learnable discretization step (log scale for stability) + log_dt = torch.rand(d_model) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + self.log_dt = nn.Parameter(log_dt) + + self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + + def forward(self, u): + """ + Forward pass through S4 layer. + + Args: + u: (batch, seq_len, d_model) input sequence + Returns: + y: (batch, seq_len, d_model) output sequence + """ + batch, seq_len, _ = u.shape + + # Get discretization step + dt = torch.exp(self.log_dt) # (d_model,) + + # Discretize per channel (simplified - using mean dt) + dt_mean = dt.mean() + A_d, B_d = discretize_zoh(self.A, self.B, dt_mean) + + # Run state space model (sequential scan) + # For efficiency, this could use parallel scan, but sequential is clearer + x = torch.zeros(batch, self.d_state, device=u.device, dtype=u.dtype) + outputs = [] + + for t in range(seq_len): + u_t = u[:, t, :] # (batch, d_model) + + # State update: x = A_d @ x + B_d @ u + x = torch.einsum('ns,bs->bn', A_d, x) + torch.einsum('nd,bd->bn', B_d, u_t) + + # Output: y = C @ x + D * u + y_t = torch.einsum('dn,bn->bd', self.C, x) + self.D * u_t + outputs.append(y_t) + + y = torch.stack(outputs, dim=1) # (batch, seq_len, d_model) + return self.dropout(y) + + def step(self, u_t, state): + """ + Single step for autoregressive generation. 
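+
+        Applies one step of the discretized recurrence
+            x_k = A_d x_{k-1} + B_d u_k,    y_k = C x_k + D * u_k
+        Illustrative usage (with `layer` an S4Layer and `u` a (batch, seq_len, d_model) tensor):
+            state = layer.init_state(batch_size, device=u.device)
+            for t in range(seq_len):
+                y_t, state = layer.step(u[:, t, :], state)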
+ + Args: + u_t: (batch, d_model) single timestep input + state: (batch, d_state) current state + Returns: + y_t: (batch, d_model) output + new_state: (batch, d_state) updated state + """ + dt_mean = torch.exp(self.log_dt).mean() + A_d, B_d = discretize_zoh(self.A, self.B, dt_mean) + + # State update + new_state = torch.einsum('ns,bs->bn', A_d, state) + torch.einsum('nd,bd->bn', B_d, u_t) + + # Output + y_t = torch.einsum('dn,bn->bd', self.C, new_state) + self.D * u_t + + return y_t, new_state + + def init_state(self, batch_size, device): + """Initialize state for generation.""" + return torch.zeros(batch_size, self.d_state, device=device) + + +class S4Block(nn.Module): + """ + S4 Block with normalization and residual connection. + """ + + def __init__(self, d_model, d_state=64, dropout=0.1): + super().__init__() + self.norm = nn.LayerNorm(d_model) + self.s4 = S4Layer(d_model, d_state, dropout=dropout) + self.ffn = nn.Sequential( + nn.Linear(d_model, d_model * 2), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(d_model * 2, d_model), + nn.Dropout(dropout), + ) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x): + # S4 with residual + x = x + self.s4(self.norm(x)) + # FFN with residual + x = x + self.ffn(self.norm2(x)) + return x diff --git a/SSM/train_mamba_oilfield.py b/SSM/train_mamba_oilfield.py new file mode 100644 index 0000000..2ac19ec --- /dev/null +++ b/SSM/train_mamba_oilfield.py @@ -0,0 +1,198 @@ +""" +Train Mamba-based SSM on oil field time series data. + +This script trains the MambaVAE model on the oilfield dataset +and evaluates its ability to capture pattern-specific characteristics. +""" + +import os +import sys +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.utils.data as Data + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from SSM.mamba_tsg import MambaVAE + + +def train_mamba_oilfield(): + parser = argparse.ArgumentParser() + parser.add_argument('--epochs', type=int, default=300) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--seq_len', type=int, default=168) + parser.add_argument('--d_input', type=int, default=1) + parser.add_argument('--d_model', type=int, default=64) + parser.add_argument('--d_state', type=int, default=16) + parser.add_argument('--d_latent', type=int, default=32) + parser.add_argument('--n_layers', type=int, default=4) + parser.add_argument('--kl_weight', type=float, default=0.1) + parser.add_argument('--dropout', type=float, default=0.1) + args = parser.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') + print(f"Using device: {device}") + + # Load oilfield training data + data_path = '/Users/jameslepage/Projects/TimeCraft/data/oilfield/oilfield_168_train.npy' + data = np.load(data_path) + print(f"Loaded data shape: {data.shape}") + + # Reshape to (N, seq_len, channels) + if data.ndim == 2: + data = data[:, :, np.newaxis] + elif data.ndim == 3 and data.shape[1] == 1: + data = data.transpose(0, 2, 1) + + # Normalize + data_mean = data.mean() + data_std = data.std() + data_norm = (data - data_mean) / (data_std + 1e-8) + + print(f"Normalized data shape: {data_norm.shape}") + print(f"Data mean: {data_mean:.4f}, std: {data_std:.4f}") + + # Create dataset + tensor_data = torch.from_numpy(data_norm).float() + dataset = Data.TensorDataset(tensor_data) + loader = 
Data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + # Create model + model = MambaVAE( + seq_len=args.seq_len, + d_input=args.d_input, + d_model=args.d_model, + d_state=args.d_state, + d_latent=args.d_latent, + n_layers=args.n_layers, + dropout=args.dropout, + ).to(device) + + optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) + + n_params = sum(p.numel() for p in model.parameters()) + print(f"\nMambaVAE parameters: {n_params:,}") + print(f"Training for {args.epochs} epochs...") + print("=" * 50) + + # Training loop + best_loss = float('inf') + for epoch in range(args.epochs): + model.train() + total_loss = 0 + total_rec = 0 + total_kl = 0 + + for batch in loader: + x = batch[0].to(device) + + optimizer.zero_grad() + x_rec, mu, logvar = model(x) + + loss, rec_loss, kl_loss = model.loss(x, x_rec, mu, logvar, args.kl_weight) + loss.backward() + + # Gradient clipping for stability + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + optimizer.step() + + total_loss += loss.item() + total_rec += rec_loss.item() + total_kl += kl_loss.item() + + scheduler.step() + n_batches = len(loader) + avg_loss = total_loss / n_batches + + if avg_loss < best_loss: + best_loss = avg_loss + + if (epoch + 1) % 25 == 0 or epoch == 0: + print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | Rec: {total_rec/n_batches:.4f} | KL: {total_kl/n_batches:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}") + + print("=" * 50) + print(f"Training complete! Best loss: {best_loss:.4f}") + + # Save model + output_dir = '/Users/jameslepage/Projects/TimeCraft/oilfield/mamba_output' + os.makedirs(output_dir, exist_ok=True) + torch.save(model.state_dict(), f'{output_dir}/mamba_oilfield.pth') + print(f"Model saved to {output_dir}/mamba_oilfield.pth") + + # Generate samples + print("\nGenerating samples...") + model.eval() + + patterns = ['normal', 'degradation', 'anomaly', 'vibration', 'pressure'] + + with torch.no_grad(): + for pattern in patterns: + samples = model.sample(20, device=device).cpu().numpy() + # Denormalize + samples = samples * data_std + data_mean + + out_path = f'{output_dir}/mamba_{pattern}_generated.npy' + np.save(out_path, samples) + print(f" {pattern}: saved {samples.shape}") + + print(f"\nAll outputs saved to: {output_dir}") + + # Quality analysis + print("\n" + "=" * 50) + print("QUALITY ANALYSIS") + print("=" * 50) + + def calc_slopes(data): + data = data.squeeze() + return [np.polyfit(np.arange(len(s)), s, 1)[0] for s in data] + + def calc_zero_crossings(data): + data = data.squeeze() + return [np.sum(np.diff(np.sign(s - s.mean())) != 0) for s in data] + + def calc_periodicity(data): + data = data.squeeze() + periods = [] + for s in data: + s_centered = s - s.mean() + corr = np.correlate(s_centered, s_centered, mode='full') + corr = corr[len(corr)//2:] + peaks = np.where((corr[1:-1] > corr[:-2]) & (corr[1:-1] > corr[2:]))[0] + 1 + periods.append(peaks[0] if len(peaks) > 0 else 0) + return periods + + # Compare against original patterns + for pattern in ['degradation', 'vibration', 'pressure']: + orig_path = f'/Users/jameslepage/Projects/TimeCraft/data/oilfield/oilfield_{pattern}_168_samples.npy' + gen_path = f'{output_dir}/mamba_{pattern}_generated.npy' + + orig = np.load(orig_path) + gen = np.load(gen_path) + + print(f"\n{pattern.upper()}:") + + if pattern == 'degradation': + orig_slopes = calc_slopes(orig) + gen_slopes = calc_slopes(gen) + print(f" Trend slope - Orig: 
{np.mean(orig_slopes):.4f}, Gen: {np.mean(gen_slopes):.4f}") + print(f" % negative - Orig: {np.mean([s<0 for s in orig_slopes])*100:.0f}%, Gen: {np.mean([s<0 for s in gen_slopes])*100:.0f}%") + + elif pattern == 'vibration': + orig_zc = calc_zero_crossings(orig) + gen_zc = calc_zero_crossings(gen) + print(f" Zero crossings - Orig: {np.mean(orig_zc):.1f}, Gen: {np.mean(gen_zc):.1f}") + + elif pattern == 'pressure': + orig_per = calc_periodicity(orig) + gen_per = calc_periodicity(gen) + print(f" Period length - Orig: {np.mean(orig_per):.1f}, Gen: {np.mean(gen_per):.1f}") + + +if __name__ == '__main__': + train_mamba_oilfield() diff --git a/TimeDP/ldm/models/diffusion/ddpm_time.py b/TimeDP/ldm/models/diffusion/ddpm_time.py index 922ab3d..24895a3 100644 --- a/TimeDP/ldm/models/diffusion/ddpm_time.py +++ b/TimeDP/ldm/models/diffusion/ddpm_time.py @@ -438,8 +438,7 @@ def make_cond_schedule(self, ): @rank_zero_only @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - # def on_train_batch_start(self, batch, batch_idx): + def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): # only for very first batch if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' diff --git a/TimeDP/utils/callback_utils.py b/TimeDP/utils/callback_utils.py index f881f10..a6eaabe 100644 --- a/TimeDP/utils/callback_utils.py +++ b/TimeDP/utils/callback_utils.py @@ -185,13 +185,11 @@ def check_frequency(self, check_idx): return True return False - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): self.log_img(pl_module, batch, batch_idx, split="train") - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val") if hasattr(pl_module, 'calibrate_grad_norm'): @@ -202,22 +200,34 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, class CUDACallback(Callback): # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py def on_train_epoch_start(self, trainer, pl_module): - # Reset the memory use counter - torch.cuda.reset_peak_memory_stats(trainer.root_gpu) - torch.cuda.synchronize(trainer.root_gpu) self.start_time = time.time() + # Reset the memory use counter (only for CUDA) + if torch.cuda.is_available(): + try: + device_idx = trainer.strategy.root_device.index + if device_idx is not None: + torch.cuda.reset_peak_memory_stats(device_idx) + torch.cuda.synchronize(device_idx) + except (AttributeError, RuntimeError): + pass def on_train_epoch_end(self, trainer, pl_module): - torch.cuda.synchronize(trainer.root_gpu) - max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 epoch_time = time.time() - self.start_time + max_memory = 0 - try: - max_memory = trainer.training_type_plugin.reduce(max_memory) - epoch_time = 
trainer.training_type_plugin.reduce(epoch_time) + if torch.cuda.is_available(): + try: + device_idx = trainer.strategy.root_device.index + if device_idx is not None: + torch.cuda.synchronize(device_idx) + max_memory = torch.cuda.max_memory_allocated(device_idx) / 2 ** 20 + except (AttributeError, RuntimeError): + pass - rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") - rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + try: + rank_zero_info(f"Epoch time: {epoch_time:.2f} seconds") + if max_memory > 0: + rank_zero_info(f"Peak memory {max_memory:.2f}MiB") except AttributeError: pass diff --git a/TimeDP/utils/init_utils.py b/TimeDP/utils/init_utils.py index d5af3c1..80ed1bf 100644 --- a/TimeDP/utils/init_utils.py +++ b/TimeDP/utils/init_utils.py @@ -17,19 +17,21 @@ data_root = os.environ['DATA_ROOT'] def init_model_data_trainer(parser): - + opt, unknown = parser.parse_known_args() - + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - + if opt.name: name = opt.name + cfg_name = opt.name elif opt.base: cfg_fname = os.path.split(opt.base[0])[-1] cfg_name = os.path.splitext(cfg_fname)[0] name = cfg_name else: name = "" + cfg_name = "default" seed_everything(opt.seed) @@ -122,7 +124,7 @@ def init_model_data_trainer(parser): # calling these ourselves should not be necessary but it is. # lightning still takes care of proper multiprocessing though data.prepare_data() - data.setup() + data.setup(stage='fit') assert config.data.params.input_channels == 1, \ "Assertion failed: Only univariate input is supported. Please ensure input_channels == 1." print("#### Data Preparation Finished #####") @@ -245,7 +247,7 @@ def load_model_data(parser): # calling these ourselves should not be necessary but it is. # lightning still takes care of proper multiprocessing though data.prepare_data() - data.setup() + data.setup(stage='fit') print("#### Data Preparation Finished #####") return model, data, opt, logdir