From 8c4210cb2f3ce7ddab0b0a4d01a1e6ee3e2ee17c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Ekstr=C3=B6m=20Kelvinius?= <25795024+filipekstrm@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:20:39 +0200 Subject: [PATCH 1/4] Attempting improved training speed with more efficient dataloading --- wyckoff_generation/common/args_and_config.py | 18 +++++++++++++ wyckoff_generation/datasets/dataset.py | 28 +++++++++++++++++--- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/wyckoff_generation/common/args_and_config.py b/wyckoff_generation/common/args_and_config.py index 6210ecc..4ed1739 100644 --- a/wyckoff_generation/common/args_and_config.py +++ b/wyckoff_generation/common/args_and_config.py @@ -17,6 +17,7 @@ "l2_reg": 0.0, "epochs": 1000, "val_interval": 1, + "num_workers": 0, # D3PM "t_max": 1000, "num_elements": 100, @@ -111,6 +112,23 @@ def get_parser(): help="Evaluation interval, number of gradient steps", ) + parser.add_argument( + "--num_workers", + type=int, + default=default_args_dict["num_workers"], + help="num_workers for dataloader", + ) + + parser.add_argument( + "--pin_memory_false", action="store_true", help="Use pin_memory=False" + ) + + parser.add_argument( + "--persistent_workers_false", + action="store_true", + help="Use persistent_workers=False", + ) + parser.add_argument( "--backbone", type=str.lower, help="What type of backbone model" ) diff --git a/wyckoff_generation/datasets/dataset.py b/wyckoff_generation/datasets/dataset.py index 05063dc..9a3d843 100644 --- a/wyckoff_generation/datasets/dataset.py +++ b/wyckoff_generation/datasets/dataset.py @@ -358,9 +358,21 @@ def get(self, idx: int): data.x = torch.chunk(data.x, len(data.wyckoff_set), 0)[ torch.randint(len(data.wyckoff_set), (1,)) ] - data.x_0_dof = data.x[data.zero_dof, 0] - data.x_inf_dof = data.x[~data.zero_dof, 1 : (self.num_elements + 1)] - return data + + return Data( + x_0_dof=data.x[data.zero_dof, 0], + x_inf_dof=data.x[~data.zero_dof, 1 : (self.num_elements + 1)], 
+ edge_index=data.edge_index, + space_group=data.space_group, + multiplicities=data.multiplicities, + num_pos=data.num_pos, + num_nodes=data.num_pos, + zero_dof=data.zero_dof, + num_0_dof=data.num_0_dof, + num_inf_dof=data.num_inf_dof, + wyckoff_pos_idx=data.wyckoff_pos_idx, + degrees_of_freedom=data.degrees_of_freedom, + ) @classmethod def get_dataloaders(cls, config): @@ -371,7 +383,15 @@ def get_dataloaders(cls, config): ) loaders.append( DataLoader( - dataset, batch_size=config["batch_size"], shuffle=split == "train" + dataset, + batch_size=config["batch_size"], + shuffle=split == "train", + num_workers=config["num_workers"], + pin_memory=~config["pin_memory_false"], + persistent_workers=( + ~config["persistent_workers_false"] + and config["num_workers"] > 0 + ), ) ) return loaders From 73e5537fd9646e34c6b0c0596022ac5a384cb8c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Ekstr=C3=B6m=20Kelvinius?= <25795024+filipekstrm@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:38:30 +0200 Subject: [PATCH 2/4] Changed how arguments for pin_memory and persistent_workers work. 
Now they are called exactly that --- wyckoff_generation/common/args_and_config.py | 9 +++++---- wyckoff_generation/datasets/dataset.py | 5 ++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/wyckoff_generation/common/args_and_config.py b/wyckoff_generation/common/args_and_config.py index 4ed1739..391755e 100644 --- a/wyckoff_generation/common/args_and_config.py +++ b/wyckoff_generation/common/args_and_config.py @@ -120,13 +120,14 @@ def get_parser(): ) parser.add_argument( - "--pin_memory_false", action="store_true", help="Use pin_memory=False" + "--pin_memory", type=bool, default=True, help="pin_memory passed to dataloader" ) parser.add_argument( - "--persistent_workers_false", - action="store_true", - help="Use persistent_workers=False", + "--persistent_workers", + type=bool, + default=True, + help="persistent_workers passed to dataloader in case num_workers>0", ) parser.add_argument( diff --git a/wyckoff_generation/datasets/dataset.py b/wyckoff_generation/datasets/dataset.py index 9a3d843..a715630 100644 --- a/wyckoff_generation/datasets/dataset.py +++ b/wyckoff_generation/datasets/dataset.py @@ -387,10 +387,9 @@ def get_dataloaders(cls, config): batch_size=config["batch_size"], shuffle=split == "train", num_workers=config["num_workers"], - pin_memory=~config["pin_memory_false"], + pin_memory=config["pin_memory"], persistent_workers=( - ~config["persistent_workers_false"] - and config["num_workers"] > 0 + config["persistent_workers"] and config["num_workers"] > 0 ), ) ) From 2b1008998506bd55578c44e165974d6ffd3f765b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Ekstr=C3=B6m=20Kelvinius?= <25795024+filipekstrm@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:45:02 +0200 Subject: [PATCH 3/4] Add num_workers to example in README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7e27e0f..6625407 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,11 @@ See [INSTALL.md](INSTALL.md) for 
instructions on how to install required package ### Train To train a WyckoffDiff model on WBM, a minimal example is ``` -python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb] +python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb] --num_workers [NUM_WORKERS] ``` Warning: using logger ```none``` will not save any checkpoints (or anything else), but can be used for, e.g., debugging. -This command will use the default values for all other parameters, which are the ones used in the paper. +This command will use the default values for all other parameters, which are the ones used in the paper. **Note: Setting ```num_workers``` is not strictly necessary; if it is omitted, it defaults to 0. However, in our experience, increasing it can substantially speed up training.** ### Generate To generate new data, a minimal example is From 81ce66628aeb7e5770727f4aee5e863f6694263e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20Ekstr=C3=B6m=20Kelvinius?= <25795024+filipekstrm@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:04:51 +0200 Subject: [PATCH 4/4] pin_memory and persistent_workers defaults from default_args_dict --- wyckoff_generation/common/args_and_config.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/wyckoff_generation/common/args_and_config.py b/wyckoff_generation/common/args_and_config.py index 391755e..0ba5284 100644 --- a/wyckoff_generation/common/args_and_config.py +++ b/wyckoff_generation/common/args_and_config.py @@ -18,6 +18,8 @@ "epochs": 1000, "val_interval": 1, "num_workers": 0, + "pin_memory": True, + "persistent_workers": True, # D3PM "t_max": 1000, "num_elements": 100, @@ -120,13 +122,16 @@ def get_parser(): ) parser.add_argument( - "--pin_memory", type=bool, default=True, help="pin_memory passed to dataloader" + "--pin_memory", + type=bool, + 
default=default_args_dict["pin_memory"], + help="pin_memory passed to dataloader", ) parser.add_argument( "--persistent_workers", type=bool, - default=True, + default=default_args_dict["persistent_workers"], help="persistent_workers passed to dataloader in case num_workers>0", )