diff --git a/README.md b/README.md
index fbfadf5..744bf3a 100644
--- a/README.md
+++ b/README.md
@@ -13,11 +13,11 @@ See [INSTALL.md](INSTALL.md) for instructions on how to install required package
 ### Train
 To train a WyckoffDiff model on WBM, a minimal example is
 ```
-python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb]
+python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb] --num_workers [NUM_WORKERS]
 ```
 Warning: using logger ```none``` will not save any checkpoints (or anything else), but can be used for, e.g., debugging.
 
-This command will use the default values for all other parameters, which are the ones used in the paper.
+This command will use the default values for all other parameters, which are the ones used in the paper. **Note: setting ```num_workers``` is not strictly necessary; if unset, it defaults to 0. In our experience, however, increasing it can substantially speed up training.**
 
 ### Generate
 To generate new data, a minimal example is
diff --git a/wyckoff_generation/common/args_and_config.py b/wyckoff_generation/common/args_and_config.py
index 6210ecc..0ba5284 100644
--- a/wyckoff_generation/common/args_and_config.py
+++ b/wyckoff_generation/common/args_and_config.py
@@ -17,6 +17,9 @@
     "l2_reg": 0.0,
     "epochs": 1000,
     "val_interval": 1,
+    "num_workers": 0,
+    "pin_memory": True,
+    "persistent_workers": True,
     # D3PM
     "t_max": 1000,
     "num_elements": 100,
@@ -111,6 +114,27 @@ def get_parser():
         help="Evaluation interval, number of gradient steps",
     )
 
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=default_args_dict["num_workers"],
+        help="num_workers for dataloader",
+    )
+
+    parser.add_argument(
+        "--pin_memory",
+        action=argparse.BooleanOptionalAction,
+        default=default_args_dict["pin_memory"],
+        help="pin_memory passed to dataloader",
+    )
+
+    parser.add_argument(
+        "--persistent_workers",
+        action=argparse.BooleanOptionalAction,
+        default=default_args_dict["persistent_workers"],
+        help="persistent_workers passed to dataloader; only used when num_workers > 0",
+    )
+
     parser.add_argument(
         "--backbone", type=str.lower, help="What type of backbone model"
     )
diff --git a/wyckoff_generation/datasets/dataset.py b/wyckoff_generation/datasets/dataset.py
index 05063dc..a715630 100644
--- a/wyckoff_generation/datasets/dataset.py
+++ b/wyckoff_generation/datasets/dataset.py
@@ -358,9 +358,21 @@ def get(self, idx: int):
             data.x = torch.chunk(data.x, len(data.wyckoff_set), 0)[
                 torch.randint(len(data.wyckoff_set), (1,))
             ]
-        data.x_0_dof = data.x[data.zero_dof, 0]
-        data.x_inf_dof = data.x[~data.zero_dof, 1 : (self.num_elements + 1)]
-        return data
+
+        return Data(
+            x_0_dof=data.x[data.zero_dof, 0],
+            x_inf_dof=data.x[~data.zero_dof, 1 : (self.num_elements + 1)],
+            edge_index=data.edge_index,
+            space_group=data.space_group,
+            multiplicities=data.multiplicities,
+            num_pos=data.num_pos,
+            num_nodes=data.num_pos,
+            zero_dof=data.zero_dof,
+            num_0_dof=data.num_0_dof,
+            num_inf_dof=data.num_inf_dof,
+            wyckoff_pos_idx=data.wyckoff_pos_idx,
+            degrees_of_freedom=data.degrees_of_freedom,
+        )
 
     @classmethod
     def get_dataloaders(cls, config):
@@ -371,7 +383,14 @@ def get_dataloaders(cls, config):
         )
         loaders.append(
             DataLoader(
-                dataset, batch_size=config["batch_size"], shuffle=split == "train"
+                dataset,
+                batch_size=config["batch_size"],
+                shuffle=split == "train",
+                num_workers=config["num_workers"],
+                pin_memory=config["pin_memory"],
+                persistent_workers=(
+                    config["persistent_workers"] and config["num_workers"] > 0
+                ),
             )
         )
         return loaders
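
For context, here is a minimal self-contained sketch of how the new dataloader options fit together. It uses a plain `torch.utils.data.DataLoader` as a stand-in for the PyG loader in the repo, and the `config` values and toy `TensorDataset` are illustrative, not taken from the codebase:
```
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative values; in the repo these come from args_and_config.py.
config = {
    "batch_size": 32,
    "num_workers": 4,
    "pin_memory": True,
    "persistent_workers": True,
}

# Toy dataset standing in for the Wyckoff dataset.
dataset = TensorDataset(torch.randn(256, 8))

loader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=config["num_workers"],
    pin_memory=config["pin_memory"],
    # PyTorch raises a ValueError if persistent_workers=True while
    # num_workers == 0, hence the same guard as in the patch above.
    persistent_workers=config["persistent_workers"] and config["num_workers"] > 0,
)

if __name__ == "__main__":  # guard needed for worker processes on spawn platforms
    for (batch,) in loader:
        pass  # iterate once to confirm the loader works
```
With the new flag, the README training example can be run as, e.g., `python main.py --mode train_d3pm --d3pm_transition marginal --logger wandb --num_workers 4`.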