From 94145977ce79e32e41c631d7d3d5e90d571ad388 Mon Sep 17 00:00:00 2001 From: delta-cat <120516302+delta-cat@users.noreply.github.com> Date: Tue, 13 Dec 2022 15:29:57 -0500 Subject: [PATCH] add pytorch_directml and change device context --- SimpleMLBenchmark/main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/SimpleMLBenchmark/main.py b/SimpleMLBenchmark/main.py index 4b038bc..dd377c5 100644 --- a/SimpleMLBenchmark/main.py +++ b/SimpleMLBenchmark/main.py @@ -1,4 +1,5 @@ import torch +import torch_directml # directml support! import random from tqdm import tqdm import numpy as np @@ -25,7 +26,7 @@ total_epochs = 128 batch_size = 64 -device = grab_torch_device(args) +device = torch_directml.device() # this is a variable now, not a string track = Tracker(device) print("Batch size :", batch_size) @@ -96,4 +97,4 @@ print("Logging Complete!, Compiling Results...") print() -track.simple_print() \ No newline at end of file +track.simple_print()