diff --git a/SimpleMLBenchmark/main.py b/SimpleMLBenchmark/main.py index 4b038bc..dd377c5 100644 --- a/SimpleMLBenchmark/main.py +++ b/SimpleMLBenchmark/main.py @@ -1,4 +1,5 @@ import torch +import torch_directml # directml support! import random from tqdm import tqdm import numpy as np @@ -25,7 +26,7 @@ total_epochs = 128 batch_size = 64 -device = grab_torch_device(args) +device = torch_directml.device() # this is a variable now, not a string track = Tracker(device) print("Batch size :", batch_size) @@ -96,4 +97,4 @@ print("Logging Complete!, Compiling Results...") print() -track.simple_print() \ No newline at end of file +track.simple_print()