Open
Description
Running the examples/MNIST_fishleg_CNN.py script on a "cuda" device (by commenting out line 31) raises a data-type mismatch error in the fishleg update_aux method:
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
I presume this is because aux_loader is initialised to load data onto the CPU, so when the model is created on a GPU device, the input batches and the model weights end up on different devices.
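One way to check this diagnosis is to compare the device of a batch drawn from the loader with the device of the model's parameters. The sketch below is not taken from the repository: it uses hypothetical stand-ins for the script's CNN and dataset (a single nn.Conv2d and a random TensorDataset) and falls back to the CPU when no GPU is available, so it runs anywhere.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Fall back to the CPU so the sketch also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for the script's CNN and MNIST training set.
model = nn.Conv2d(1, 4, kernel_size=3).to(device)
dataset = TensorDataset(torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))

# A loader built without a custom collate_fn, as in the original script.
loader = DataLoader(dataset, shuffle=True, batch_size=4)
x, y = next(iter(loader))

param_device = next(model.parameters()).device
print(f"model weights on {param_device}, batch on {x.device}")
# On a GPU machine the two devices differ (e.g. cuda:0 vs cpu), which is
# the combination that triggers the RuntimeError above.
```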
This could perhaps be fixed by passing a collate_fn that moves each batch to the specified device.
Original code:
aux_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size
)
New code:
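from torch.utils.data.dataloader import default_collate

aux_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size,
    # collate the samples as usual, then move every tensor in the batch to `device`
    collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x))
)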
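The same fix can also be written with a named, module-level collate function bound to the device via functools.partial. This is only a sketch: collate_to_device is a hypothetical helper, and the random TensorDataset stands in for the script's MNIST train_dataset.

```python
import functools

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate


def collate_to_device(samples, device):
    """Collate samples with PyTorch's default rule, then move every tensor to `device`."""
    batch = default_collate(samples)
    return tuple(t.to(device) for t in batch)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the MNIST training set used by the script.
train_dataset = TensorDataset(torch.randn(10, 1, 28, 28), torch.randint(0, 10, (10,)))
batch_size = 4

aux_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=functools.partial(collate_to_device, device=device),
)

x, y = next(iter(aux_loader))
print(x.shape, x.device, y.device)
```

Unlike a lambda, a functools.partial over a module-level function can be pickled, which matters if the loader is ever used with num_workers > 0 under the "spawn" start method. Note, however, that the PyTorch DataLoader documentation advises against returning CUDA tensors from worker processes. With multiple workers, it is usually safer to keep the default collate_fn (optionally with pin_memory=True) and move each batch to the device where it is consumed.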