Dataset Module
Module location: mmai25_hackathon/dataset.py
This module is the crux of the hackathon: it defines simple, extensible base classes that make all your modality loaders play nicely together. Use these to build consistent, composable datasets, plug in flexible sampling, and batch graph and non‑graph data with a single DataLoader.
- BaseDataset: Minimal abstract dataset with clear extension points
- BaseDataLoader: Drop‑in DataLoader for tensors and PyTorch Geometric graphs
- BaseSampler: Template for custom (multi)modal sampling strategies
Together, these are the core interfaces teams can target so modules remain interoperable.
BaseDataset
- Required: implement `__len__` and `__getitem__`.
- Optional: override `extra_repr()` for a nicer `repr`.
- Optional: implement `__add__(self, other)` to compose modalities (e.g., align on shared IDs and return merged samples).
- Optional: `@classmethod prepare_data(...)` for downloading, preprocessing, and split creation.
Design tip: Return a single sample as a plain Python dict keyed by modality/component, e.g., {"image": ..., "text": ..., "label": ...}. For graphs, return a PyG Data or a dict that includes one.
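As a minimal sketch of these hooks (the class name, CSV layout, and `prepare_data` signature below are illustrative assumptions, not part of the module):

```python
from typing import Any, Dict

import pandas as pd

from mmai25_hackathon.dataset import BaseDataset


class ToyTextDataset(BaseDataset):
    """Minimal sketch of the required and optional hooks."""

    def __init__(self, frame: pd.DataFrame):
        self.frame = frame.reset_index(drop=True)

    def __len__(self) -> int:  # required
        return len(self.frame)

    def __getitem__(self, idx: int) -> Dict[str, Any]:  # required: dict-shaped sample
        row = self.frame.iloc[idx]
        return {"text": row["text"], "label": row["label"]}

    def extra_repr(self) -> str:  # optional: extra info shown in repr(dataset)
        return f"num_samples={len(self)}"

    @classmethod
    def prepare_data(cls, csv_path: str) -> "ToyTextDataset":  # optional; signature illustrative
        # One-off heavy work (download, preprocessing, split creation) belongs here;
        # this sketch just reads a CSV with "text" and "label" columns.
        return cls(pd.read_csv(csv_path))
```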
BaseDataLoader
Subclass of `torch_geometric.data.DataLoader`, so it handles:
- Pure tensor batches (acts like `torch.utils.data.DataLoader`).
- PyG graphs (`torch_geometric.data.Data`), including mini-batches.
You can use it directly for both graph and non‑graph datasets.
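A small usage sketch, assuming `BaseDataLoader` forwards the usual `DataLoader` arguments (`batch_size`, `shuffle`, `sampler`, ...); the toy samples and shapes are made up:

```python
import torch
from torch_geometric.data import Data

from mmai25_hackathon.dataset import BaseDataLoader

# Non-graph samples (dicts of tensors) batch like a regular torch DataLoader.
tensor_samples = [{"features": torch.randn(4), "label": torch.tensor(0)} for _ in range(8)]
tensor_loader = BaseDataLoader(tensor_samples, batch_size=4)
for batch in tensor_loader:
    print(batch["features"].shape)  # torch.Size([4, 4])

# Graph samples (PyG Data) are collated into one batched graph per mini-batch.
graph_samples = [
    Data(x=torch.randn(3, 2), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
    for _ in range(8)
]
graph_loader = BaseDataLoader(graph_samples, batch_size=4)
for batch in graph_loader:
    print(batch.num_graphs)  # 4
```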
BaseSampler
Subclass of `torch.utils.data.Sampler`. Extend it to implement custom sampling, e.g.:
- Balanced sampling across labels or sites
- Coordinated sampling across modalities before collation
Note: These are illustrative patterns — adapt or replace freely.
```python
from typing import Dict, Any, Optional, Sequence

import pandas as pd

from mmai25_hackathon.dataset import BaseDataset, BaseDataLoader


class TabularDataset(BaseDataset):
    """
    Generic tabular dataset returning dict-shaped samples. This keeps interfaces
    consistent across teams and works with default PyTorch/PyG collation.
    """

    def __init__(
        self,
        frame: pd.DataFrame,
        feature_cols: Optional[Sequence[str]] = None,
        label_col: Optional[str] = None,
        id_col: Optional[str] = None,
    ):
        # Store frame and optional column config
        self.frame = frame.reset_index(drop=True)
        self.label_col = label_col
        self.id_col = id_col
        # If no feature selection is given, use all columns except the label
        if feature_cols is None:
            exclude = {label_col} if label_col else set()
            self.feature_cols = [c for c in self.frame.columns if c not in exclude]
        else:
            self.feature_cols = list(feature_cols)

    def __len__(self) -> int:
        # Number of rows = number of samples
        return len(self.frame)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Return a dict with features (+ optional label/id). Default collation
        # will batch dicts into dicts of tensors/lists when possible.
        row = self.frame.iloc[idx]
        sample: Dict[str, Any] = {"features": {c: row[c] for c in self.feature_cols}}
        if self.label_col is not None:
            sample["label"] = row[self.label_col]
        if self.id_col is not None:
            sample["id"] = row[self.id_col]
        return sample


# Usage (illustrative): build dataset and iterate in mini-batches
# ds = TabularDataset(df, feature_cols=["age", "hr", "bmi"], label_col="Y", id_col="subject_id")
# loader = BaseDataLoader(ds, batch_size=16, shuffle=True)
# for batch in loader:  # batch is a dict with keys: features, (label), (id)
#     ...
```

```python
import numpy as np
from mmai25_hackathon.dataset import BaseSampler, BaseDataLoader


class StratifiedSampler(BaseSampler):
    """
    Tiny example of label-balanced sampling by index.

    - Yields indices in an order that forms roughly 50/50 class-balanced
      contiguous groups of size `batch_size`.
    - Create your DataLoader with the same `batch_size` and pass this sampler.
    - Do not pass `shuffle=True` together with `sampler`; PyTorch treats them
      as mutually exclusive.
    - This is illustrative, not production-ready (no handling of >2 classes,
      class imbalance beyond clipping, or epoch-to-epoch reshuffling strategy).
    """

    def __init__(self, labels, batch_size):
        self.labels = np.asarray(labels)
        self.batch_size = int(batch_size)

    def __iter__(self):
        # Split indices by label (binary example)
        idx_pos = np.where(self.labels == 1)[0]
        idx_neg = np.where(self.labels == 0)[0]
        # Use the smaller class to form balanced chunks
        half = max(1, self.batch_size // 2)
        m = min(len(idx_pos), len(idx_neg))
        for i in range(0, m, half):
            # Build a balanced group of indices, then yield them sequentially.
            # DataLoader will take the next `batch_size` yielded indices as a batch,
            # making each batch approximately balanced.
            batch = np.concatenate([idx_pos[i:i + half], idx_neg[i:i + half]])
            np.random.shuffle(batch)  # shuffle within the group
            yield from batch.tolist()
        # Optionally, you could yield remaining majority-class indices here if desired.

    def __len__(self):
        # Return the dataset length; DataLoader uses this for epoch sizing.
        return len(self.labels)


# Usage (illustrative): create a DataLoader with this sampler
# sampler = StratifiedSampler(labels=labels_array, batch_size=16)
# loader = BaseDataLoader(ds, batch_size=16, sampler=sampler)
# for batch in loader:  # batches are roughly class-balanced
#     ...
```

Modules under `mmai25_hackathon/load_data/` read raw files and return tidy tables or objects (e.g., DataFrames, paths, graphs). The Dataset Module sits on top: use those outputs to implement datasets that expose a clean, unified interface to your training code. This keeps teams free to innovate on modality specifics while staying compatible at the boundaries.
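For instance, once a loading helper has produced a tidy DataFrame, the dataset and loader above wrap it behind the unified interface; `load_table` and the CSV path below are hypothetical stand-ins for whatever the relevant `load_data` module actually provides:

```python
import pandas as pd


def load_table(path: str) -> pd.DataFrame:
    # Hypothetical stand-in for a mmai25_hackathon.load_data helper that returns
    # a tidy table; real helpers and their signatures vary by modality.
    return pd.read_csv(path)


df = load_table("data/ehr.csv")  # illustrative path
ds = TabularDataset(df, label_col="label", id_col="subject_id")  # class defined above
loader = BaseDataLoader(ds, batch_size=32, shuffle=True)
```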
- Prefer dict-shaped samples for multimodal data; keep keys stable ("image", "echo", "ecg", "text", "label").
- Align modalities by a common ID and document how missing entries are handled (drop vs. impute vs. sparse batches).
- Keep `__getitem__` fast; push heavy work to a preprocessing step or `prepare_data`.
- Use `BaseDataLoader` even for non-graph data; it's compatible and future-proof if you add graphs later.
- Optional: if you choose to implement `__add__`, keeping it small and merging on shared keys works well (see the sketch after this list).
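As one possible realization of the ID-alignment tips above, the combiner below pairs two `TabularDataset` instances (defined earlier) on their shared ID column with an inner join, i.e., the "drop missing entries" policy; the class name, sample keys, and join policy are illustrative, and the same logic could equally live in `__add__`:

```python
from typing import Any, Dict

from mmai25_hackathon.dataset import BaseDataset


class PairedTabularDataset(BaseDataset):
    """Illustrative combiner: align two TabularDatasets on their shared ID column."""

    def __init__(self, left: "TabularDataset", right: "TabularDataset"):
        assert left.id_col and right.id_col, "both datasets need an id_col to align on"
        self.left, self.right = left, right
        left_idx = {left.frame.at[i, left.id_col]: i for i in range(len(left))}
        right_idx = {right.frame.at[i, right.id_col]: i for i in range(len(right))}
        # Inner join on IDs: entries missing from either modality are dropped.
        self.pairs = [(left_idx[k], right_idx[k]) for k in left_idx if k in right_idx]

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        li, ri = self.pairs[idx]
        a, b = self.left[li], self.right[ri]
        # Namespace each modality's features under a stable key.
        sample: Dict[str, Any] = {"tab_a": a["features"], "tab_b": b["features"]}
        if "id" in a:
            sample["id"] = a["id"]
        if "label" in a:
            sample["label"] = a["label"]
        return sample
```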
- Must-haves: subclass `BaseDataset` and implement `__len__` and `__getitem__`.
- Creative integration: how you combine modalities is up to you. The examples (`__add__`, combiners, samplers) are optional patterns, not requirements.
- Show scalability/flexibility your way: e.g., composable datasets, simple APIs, or concise docs/examples; pick what best showcases your idea.