4 changes: 2 additions & 2 deletions README.md
@@ -13,11 +13,11 @@ See [INSTALL.md](INSTALL.md) for instructions on how to install required package
### Train
To train a WyckoffDiff model on WBM, a minimal example is
```
python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb]
python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb] --num_workers [NUM_WORKERS]
Suggested change
python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb] --num_workers [NUM_WORKERS]
python main.py --mode train_d3pm --d3pm_transition [uniform/marginal/zeros_init] --logger [none/model_only/local_only/tensorboard/wandb] --num_workers [NUM_WORKERS] --persistent_workers [True/False] --pin_memory [True/False]

```
Warning: using logger `none` will not save any checkpoints (or anything else), but it can be useful for, e.g., debugging.

This command will use the default values for all other parameters, which are the ones used in the paper.
This command will use the default values for all other parameters, which are the ones used in the paper. **Note: it is not strictly necessary to set `num_workers`; if unset, it defaults to 0. However, in our experience, increasing it can substantially speed up training.**
Suggested change
This command will use the default values for all other parameters, which are the ones used in the paper. **Note: It is not strictly necessary to set ```num_workers```, and if not it will default to 0. However, in our experience, increasing it can substantially speed up training**
This command will use the default values for all other parameters, which are the ones used in the paper. **Note: it is not strictly necessary to set `num_workers`, `persistent_workers`, or `pin_memory`. However, in our experience, increasing `num_workers` and setting `persistent_workers=True` and `pin_memory=True` can substantially speed up training.** The optimal `num_workers` value depends on your system; we have used the maximum value suggested by the PyTorch warning.


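The performance note above maps directly onto `torch.utils.data.DataLoader` keyword arguments. As a plain-Python sketch (no PyTorch required; the `dataloader_kwargs` helper name is an illustration, not part of the repository), the effective kwargs can be assembled with the same guard the repository's `get_dataloaders` applies, since `persistent_workers` is only valid when worker processes exist:

```python
def dataloader_kwargs(num_workers=0, pin_memory=False, persistent_workers=False):
    """Assemble DataLoader keyword arguments.

    persistent_workers only makes sense when worker processes exist,
    so it is forced to False when num_workers == 0.
    """
    return {
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "persistent_workers": persistent_workers and num_workers > 0,
    }

# with the defaults, all three fall back to PyTorch's own defaults
default_kwargs = dataloader_kwargs()
# a typical speed-up configuration on a machine with spare CPU cores
fast_kwargs = dataloader_kwargs(num_workers=4, pin_memory=True, persistent_workers=True)
```

The guard mirrors the `persistent_workers and num_workers > 0` expression in `get_dataloaders`, which avoids the `ValueError` PyTorch raises when `persistent_workers=True` is combined with `num_workers=0`.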
### Generate
To generate new data, a minimal example is
24 changes: 24 additions & 0 deletions wyckoff_generation/common/args_and_config.py
@@ -17,6 +17,9 @@
"l2_reg": 0.0,
"epochs": 1000,
"val_interval": 1,
"num_workers": 0,
"pin_memory": True,
Suggested change
"pin_memory": True,
"pin_memory": False,

"persistent_workers": True,
Suggested change
"persistent_workers": True,
"persistent_workers": False,

# D3PM
"t_max": 1000,
"num_elements": 100,
@@ -111,6 +114,27 @@ def get_parser():
help="Evaluation interval, number of gradient steps",
)

parser.add_argument(
"--num_workers",
type=int,
default=default_args_dict["num_workers"],
help="num_workers for dataloader",
)

parser.add_argument(
"--pin_memory",
# argparse's type=bool treats any non-empty string (even "False") as True,
# so parse the text explicitly
type=lambda s: s.lower() in ("true", "1", "yes"),
default=default_args_dict["pin_memory"],
help="pin_memory passed to dataloader (True/False)",
)

parser.add_argument(
"--persistent_workers",
type=lambda s: s.lower() in ("true", "1", "yes"),
default=default_args_dict["persistent_workers"],
help="persistent_workers passed to dataloader in case num_workers>0 (True/False)",
)
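A note on boolean flags: `bool("False")` is `True` in Python, so `argparse` boolean options are usually parsed with an explicit converter rather than `type=bool`. A minimal stdlib-only sketch (the `str2bool` helper name is an illustration, not part of the repository):

```python
import argparse

def str2bool(value):
    # bool("False") is True, so argparse needs an explicit converter
    v = value.lower()
    if v in ("true", "1", "yes"):
        return True
    if v in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--pin_memory", type=str2bool, default=False)
parser.add_argument("--num_workers", type=int, default=0)

args = parser.parse_args(["--pin_memory", "False", "--num_workers", "4"])
# args.pin_memory is genuinely False here, whereas type=bool would give True
```

On Python 3.9+, `action=argparse.BooleanOptionalAction` is a stdlib alternative that accepts `--pin_memory` / `--no-pin_memory` instead of a `True`/`False` argument.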

parser.add_argument(
"--backbone", type=str.lower, help="What type of backbone model"
)
27 changes: 23 additions & 4 deletions wyckoff_generation/datasets/dataset.py
@@ -358,9 +358,21 @@ def get(self, idx: int):
data.x = torch.chunk(data.x, len(data.wyckoff_set), 0)[
torch.randint(len(data.wyckoff_set), (1,))
]
data.x_0_dof = data.x[data.zero_dof, 0]
data.x_inf_dof = data.x[~data.zero_dof, 1 : (self.num_elements + 1)]
return data

return Data(
x_0_dof=data.x[data.zero_dof, 0],
x_inf_dof=data.x[~data.zero_dof, 1 : (self.num_elements + 1)],
edge_index=data.edge_index,
space_group=data.space_group,
multiplicities=data.multiplicities,
num_pos=data.num_pos,
num_nodes=data.num_pos,
zero_dof=data.zero_dof,
num_0_dof=data.num_0_dof,
num_inf_dof=data.num_inf_dof,
wyckoff_pos_idx=data.wyckoff_pos_idx,
degrees_of_freedom=data.degrees_of_freedom,
)
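The change above returns a freshly constructed `Data` object instead of assigning new attributes on the one fetched from storage. As a general motivation for this pattern (a plain-Python sketch, not specific to this repository's internals): mutating a record that is shared through a cache leaks the mutation into later reads, while building a new object leaves the cached copy intact.

```python
class Record:
    def __init__(self, x):
        self.x = x

cache = {0: Record(list(range(6)))}

def get_mutating(idx):
    # overwriting attributes on the cached record corrupts later reads
    rec = cache[idx]
    rec.x = rec.x[:3]
    return rec

def get_fresh(idx):
    # building a new record leaves the cache intact
    return Record(cache[idx].x[:3])

get_mutating(0)
after_mutation = len(cache[0].x)   # the cached record shrank to 3

cache[0] = Record(list(range(6)))  # reset the cache
get_fresh(0)
after_fresh = len(cache[0].x)      # the cached record still has 6 entries
```

Returning a new `Data` also makes the sample's contents explicit (only the fields the model needs, with `num_nodes` set), which keeps batching predictable when the dataloader collates samples.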

@classmethod
def get_dataloaders(cls, config):
@@ -371,7 +383,14 @@ def get_dataloaders(cls, config):
)
loaders.append(
DataLoader(
dataset, batch_size=config["batch_size"], shuffle=split == "train"
dataset,
batch_size=config["batch_size"],
shuffle=split == "train",
num_workers=config["num_workers"],
pin_memory=config["pin_memory"],
persistent_workers=(
config["persistent_workers"] and config["num_workers"] > 0
),
)
)
return loaders