
Error when running example MNIST_fishleg_CNN script on GPU #33

@WeiShengL


Using a "cuda" device with the examples/MNIST_fishleg_CNN.py script (by commenting out line 31) raises the following data-type mismatch error in the fishleg update_aux method:

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

I presume this is because aux_loader loads its batches onto the CPU, so when the model has been initialized on a GPU device, the input tensors and the model weights end up on different devices.
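One way to check this presumption is to compare the device of the model's parameters with the device of a batch drawn from the loader. The sketch below is hedged: the small nn.Conv2d model and the random TensorDataset are hypothetical stand-ins for the script's actual CNN and MNIST data, and devices_match is a hypothetical helper, not part of fishleg. On a CUDA machine it should report a mismatch, because the DataLoader's default collation produces CPU tensors.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def devices_match(model: nn.Module, batch) -> bool:
    """Return True if every tensor in `batch` is on the same device as the model's parameters."""
    model_device = next(model.parameters()).device
    return all(t.device == model_device for t in batch)


# Hypothetical stand-ins for the CNN and the MNIST dataset used in the example script.
model = nn.Conv2d(1, 4, kernel_size=3)
dataset = TensorDataset(torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))
loader = DataLoader(dataset, batch_size=4)  # default collation keeps tensors on the CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # moves the parameters in place

batch = next(iter(loader))
print("model on:", next(model.parameters()).device, "| batch on:", batch[0].device)
print("devices match:", devices_match(model, batch))
```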

This could perhaps be fixed by passing a collate_fn to the auxiliary DataLoader so that each batch is moved to the specified device.
Original code:

aux_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size
)

New code:

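from torch.utils.data.dataloader import default_collate

aux_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size,
    # Collate the samples on the CPU as usual, then move each batched tensor to `device`
    collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x))
)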
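The same fix can also be written with a named collate function bound to the device through functools.partial, which makes the target device explicit and the function easy to test on its own. This is a sketch under assumptions: collate_to_device is a hypothetical helper, and the random TensorDataset stands in for the script's MNIST train_dataset. It falls back to the CPU when CUDA is unavailable.

```python
import functools

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate


def collate_to_device(batch, device):
    """Collate samples with PyTorch's default_collate, then move each batched field to `device`."""
    return tuple(field.to(device) for field in default_collate(batch))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the MNIST training set used in the example script.
train_dataset = TensorDataset(torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,)))

aux_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=4,
    collate_fn=functools.partial(collate_to_device, device=device),
)

images, labels = next(iter(aux_loader))
print(images.shape, images.device, labels.device)
```

One caveat on this design: the collate function runs wherever the loader builds its batches, and PyTorch's DataLoader documentation generally advises against returning CUDA tensors from worker processes. The approach therefore fits best with the default num_workers=0 used in the snippets above.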
