diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index a535ddf7..60362d63 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -520,6 +520,7 @@ def train( callbacks: Sequence[BaseCallback] = [], monitor_metrics: Sequence[str] = [], out_of_core_dataloading: bool = False, + seed: int = 0, ) -> None: """Train the model. @@ -547,6 +548,8 @@ def train( out_of_core_dataloading If :obj:`True`, use out-of-core dataloading. Uses the :class:`cellflow.data._dataloader.OOCTrainSampler` to load data that does not fit into GPU memory. + seed + Random seed, only used if `out_of_core_dataloading` is `True`. Returns ------- @@ -562,7 +565,7 @@ def train( raise ValueError("Model not initialized. Please call `prepare_model` first.") if out_of_core_dataloading: - self._dataloader = OOCTrainSampler(data=self.train_data, batch_size=batch_size) + self._dataloader = OOCTrainSampler(data=self.train_data, seed=seed, batch_size=batch_size) else: self._dataloader = TrainSampler(data=self.train_data, batch_size=batch_size) validation_loaders = {k: ValidationSampler(v) for k, v in self.validation_data.items() if k != "predict_kwargs"}