Dataset Module

Core Dataset Module (dataset.py)

Module location: mmai25_hackathon/dataset.py

This module is the crux of the hackathon: it defines simple, extensible base classes that make all your modality loaders play nicely together. Use these to build consistent, composable datasets, plug in flexible sampling, and batch graph and non‑graph data with a single DataLoader.

What it provides

  • BaseDataset: Minimal abstract dataset with clear extension points
  • BaseDataLoader: Drop‑in DataLoader for tensors and PyTorch Geometric graphs
  • BaseSampler: Template for custom (multi)modal sampling strategies

Together, these are the core interfaces teams can target so modules remain interoperable.

Class overview

BaseDataset

  • Required: implement __len__ and __getitem__.
  • Optional: override extra_repr() for nicer repr.
  • Optional: implement __add__(self, other) to compose modalities (e.g., align on shared IDs and return merged samples).
  • Optional: @classmethod prepare_data(...) for downloading, preprocessing, and split creation.

Design tip: Return a single sample as a plain Python dict keyed by modality/component, e.g., {"image": ..., "text": ..., "label": ...}. For graphs, return a PyG Data or a dict that includes one.
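
For orientation, here is a minimal, illustrative skeleton touching those extension points. The prepare_data signature and the CSV/column names are assumptions rather than part of the module's API; adapt them to whatever your preprocessing needs.

from typing import Any, Dict, Tuple

import pandas as pd

from mmai25_hackathon.dataset import BaseDataset

class SkeletonDataset(BaseDataset):
    """Sketch of the extension points listed above (not a reference implementation)."""

    def __init__(self, frame: pd.DataFrame):
        self.frame = frame.reset_index(drop=True)

    def __len__(self) -> int:
        return len(self.frame)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Dict-shaped sample keyed by modality/component.
        row = self.frame.iloc[idx]
        return {
            "features": row.drop(labels=["label"], errors="ignore").to_dict(),
            "label": row.get("label"),
        }

    def extra_repr(self) -> str:
        # Optional: extra detail shown in repr(self).
        return f"num_samples={len(self)}"

    @classmethod
    def prepare_data(cls, csv_path: str) -> Tuple["SkeletonDataset", "SkeletonDataset"]:
        # Optional hook for downloading/preprocessing/split creation.
        # Signature is illustrative; pass whatever arguments your pipeline needs.
        frame = pd.read_csv(csv_path)
        cut = int(0.8 * len(frame))
        return cls(frame.iloc[:cut]), cls(frame.iloc[cut:])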

BaseDataLoader

Subclass of torch_geometric.data.DataLoader, so it handles:

  • Pure tensor batches (acts like torch.utils.data.DataLoader).
  • PyG graphs (torch_geometric.data.Data), including mini‑batches.

You can use it directly for both graph and non‑graph datasets.
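
For example, a toy dataset whose __getitem__ returns a torch_geometric.data.Data object batches with the same loader. This is a hedged sketch assuming torch and torch_geometric are installed; the graph contents are made up.

import torch
from torch_geometric.data import Data

from mmai25_hackathon.dataset import BaseDataset, BaseDataLoader

class ToyGraphDataset(BaseDataset):
    """Ten tiny random graphs, each returned as a PyG Data object."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, idx: int) -> Data:
        x = torch.randn(4, 3)                      # 4 nodes, 3 features each
        edge_index = torch.tensor([[0, 1, 2, 3],   # simple ring graph
                                   [1, 2, 3, 0]])
        return Data(x=x, edge_index=edge_index, y=torch.tensor([idx % 2]))

# Usage (illustrative): PyG-style collation yields a Batch with a `batch` vector
# loader = BaseDataLoader(ToyGraphDataset(), batch_size=4, shuffle=True)
# for batch in loader:
#     print(batch.num_graphs, batch.x.shape)  # e.g., 4 graphs, (16, 3) node features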

BaseSampler

Subclass of torch.utils.data.Sampler. Extend it to implement custom sampling strategies, for example:

  • Balanced sampling across labels or sites
  • Coordinated sampling across modalities before collation

Quick start examples

Note: These are illustrative patterns — adapt or replace freely.

1) A simple unimodal dataset (tabular)

from typing import Dict, Any, Optional, Sequence
import pandas as pd
from mmai25_hackathon.dataset import BaseDataset, BaseDataLoader

class TabularDataset(BaseDataset):
    """
    Generic tabular dataset returning dict-shaped samples. This keeps interfaces
    consistent across teams and works with default PyTorch/PyG collation.
    """

    def __init__(
        self,
        frame: pd.DataFrame,
        feature_cols: Optional[Sequence[str]] = None,
        label_col: Optional[str] = None,
        id_col: Optional[str] = None,
    ):
        # Store frame and optional column config
        self.frame = frame.reset_index(drop=True)
        self.label_col = label_col
        self.id_col = id_col
        # If no feature selection is given, use all columns except label/id
        if feature_cols is None:
            exclude = {c for c in (label_col, id_col) if c is not None}
            self.feature_cols = [c for c in self.frame.columns if c not in exclude]
        else:
            self.feature_cols = list(feature_cols)

    def __len__(self) -> int:
        # Number of rows = number of samples
        return len(self.frame)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Return a dict with features (+ optional label/id). Default collation
        # will batch dicts into dicts of tensors/lists when possible.
        row = self.frame.iloc[idx]
        sample: Dict[str, Any] = {"features": {c: row[c] for c in self.feature_cols}}
        if self.label_col is not None:
            sample["label"] = row[self.label_col]
        if self.id_col is not None:
            sample["id"] = row[self.id_col]
        return sample

# Usage (illustrative): build dataset and iterate in mini-batches
# ds = TabularDataset(df, feature_cols=["age", "hr", "bmi"], label_col="Y", id_col="subject_id")
# loader = BaseDataLoader(ds, batch_size=16, shuffle=True)
# for batch in loader:  # batch is a dict with keys: features, (label), (id)
#     ...

2) Custom sampler (optional)

import numpy as np
from mmai25_hackathon.dataset import BaseSampler, BaseDataLoader

class StratifiedSampler(BaseSampler):
    """
    Tiny example of label-balanced sampling by index.

    - Yields indices in an order that forms roughly 50/50 class-balanced
      contiguous groups of size `batch_size`.
    - Create your DataLoader with the same `batch_size` and pass this sampler.
    - Do not pass `shuffle=True` when a `sampler` is given; PyTorch treats the
      two options as mutually exclusive and raises an error.
    - This is illustrative, not production-ready (no handling of >2 classes,
      class imbalance beyond clipping, or epoch-to-epoch reshuffling strategy).
    """

    def __init__(self, labels, batch_size):
        self.labels = np.asarray(labels)
        self.batch_size = int(batch_size)

    def __iter__(self):
        # Split indices by label (binary example)
        idx_pos = np.where(self.labels == 1)[0]
        idx_neg = np.where(self.labels == 0)[0]

        # Use the smaller class to form balanced chunks
        half = max(1, self.batch_size // 2)
        m = min(len(idx_pos), len(idx_neg))
        for i in range(0, m, half):
            # Build a balanced group of indices, then yield them sequentially.
            # DataLoader will take the next `batch_size` yielded indices as a batch,
            # making each batch approximately balanced.
            batch = np.concatenate([idx_pos[i:i+half], idx_neg[i:i+half]])
            np.random.shuffle(batch)  # shuffle within the group
            yield from batch.tolist()

        # Optionally, you could yield remaining majority-class indices here if desired.

    def __len__(self):
        # DataLoader uses this for epoch sizing, so report (approximately) how
        # many indices __iter__ actually yields rather than the dataset length.
        m = min(int((self.labels == 1).sum()), int((self.labels == 0).sum()))
        return 2 * m

# Usage (illustrative): create a DataLoader with this sampler
# sampler = StratifiedSampler(labels=labels_array, batch_size=16)
# loader = BaseDataLoader(ds, batch_size=16, sampler=sampler)
# for batch in loader:  # batches are roughly class-balanced
#     ...

How this connects to loaders in load_data/

Modules under mmai25_hackathon/load_data/ read raw files and return tidy tables or objects (e.g., DataFrames, paths, graphs). The Dataset Module sits on top: use those outputs to implement datasets that expose a clean, unified interface to your training code. This keeps teams free to innovate on modality specifics while staying compatible at the boundaries.
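
A sketch of that flow, reusing the TabularDataset from the quick start above. The read_csv call, file path, and column names are placeholders; swap in whichever helper under mmai25_hackathon/load_data/ matches your modality.

import pandas as pd

from mmai25_hackathon.dataset import BaseDataLoader
# from mmai25_hackathon.load_data import ...  # hypothetical: your modality's loader

# 1) Produce a tidy table (placeholder path; a load_data/ helper would normally do this).
frame = pd.read_csv("data/clinical.csv")

# 2) Wrap it in a dataset exposing the unified dict-shaped interface.
ds = TabularDataset(frame, label_col="label", id_col="subject_id")

# 3) Batch it like any other dataset, graph or not.
loader = BaseDataLoader(ds, batch_size=32, shuffle=True)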

Tips and patterns

  • Prefer dict‑shaped samples for multimodal data; keep keys stable ("image", "echo", "ecg", "text", "label").
  • Align modalities by a common ID and document how missing entries are handled (drop vs. impute vs. sparse batches).
  • Keep __getitem__ fast; push heavy work to a preprocessing step or prepare_data.
  • Use BaseDataLoader even for non‑graph data; it’s compatible and future‑proof if you add graphs later.
  • Optional: if you choose to implement __add__, keeping it small and merging on shared keys works well.
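
Following the last tip, here is one small way such a composition could look: a sketch that inner-joins two dict-sample datasets on a shared "id" key (assumed present in both) and drops non-overlapping entries. It is one option among many, not a required pattern.

from typing import Any, Dict

from mmai25_hackathon.dataset import BaseDataset

class PairedDataset(BaseDataset):
    """Illustrative inner join of two dict-sample datasets on their "id" key."""

    def __init__(self, left: BaseDataset, right: BaseDataset):
        # Precompute ID -> index maps once; IDs missing from either side are
        # dropped (the "drop" strategy from the alignment tip above).
        left_ids = {left[i]["id"]: i for i in range(len(left))}
        right_ids = {right[i]["id"]: i for i in range(len(right))}
        shared = sorted(set(left_ids) & set(right_ids))
        self.left, self.right = left, right
        self.pairs = [(left_ids[k], right_ids[k]) for k in shared]

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        i, j = self.pairs[idx]
        merged = dict(self.left[i])
        merged.update(self.right[j])  # right-hand keys win on collision
        return merged

# An __add__ on your own dataset could then be as small as:
# def __add__(self, other):
#     return PairedDataset(self, other)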

Baseline and Creative Freedom

  • Must‑haves: subclass BaseDataset and implement __len__ and __getitem__.
  • Creative integration: how you combine modalities is up to you. The examples (__add__, combiners, samplers) are optional patterns, not requirements.
  • Show scalability/flexibility your way: e.g., composable datasets, simple APIs, or concise docs/examples — pick what best showcases your idea.
