Hi,
This is about the shared_dict piece of code, reproduced here with minor changes to track memory consumption:
import resource
from multiprocessing import Manager
import torch
from torch.utils.data import Dataset, DataLoader
def print_mem():
    # ru_maxrss is reported in kilobytes on Linux, so this prints megabytes
    m = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
    print("Peak memory (MB):", m)

class MyDataset(Dataset):
    def __init__(self, shared_dict, length):
        self.shared_dict = shared_dict
        self.length = length

    def __getitem__(self, index):
        if index not in self.shared_dict:
            print('Adding data')
            self.shared_dict[index] = torch.tensor([float(index)] * 100000)
        return self.shared_dict[index]

    def __len__(self):
        return self.length
# Init
manager = Manager()
shared_dict = manager.dict()
dataset = MyDataset(shared_dict, length=100)
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=0,
    shuffle=False,
    pin_memory=False
)
# First loop will add data to the shared_dict
print_mem()
for x in loader:
    pass  # print(x)

# The second loop will just get the data
print_mem()
for x in loader:
    pass  # print(x)

# The third loop will just get the data too
print_mem()
for x in loader:
    pass  # print(x)

print_mem()
To keep things easier to understand, the number of workers is set to 0.
If I run this code, peak memory keeps increasing after each loop.
If I use the standard dict() instead of manager.dict(), memory does not increase after the first loop, as expected (a minimal sketch of that comparison is shown below).
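For completeness, this is roughly the plain-dict variant I compared against; it reuses MyDataset, DataLoader, and print_mem from the code above, and only the dictionary creation changes:

# Same setup as above, but with a regular in-process dict
shared_dict = {}  # instead of manager.dict()
dataset = MyDataset(shared_dict, length=100)
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=0,
    shuffle=False,
    pin_memory=False
)

print_mem()
for x in loader:   # first loop fills the dict
    pass
print_mem()
for x in loader:   # second loop only reads the cached tensors
    pass
print_mem()        # with the plain dict, peak memory does not grow past the first loop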
Would you know why manager.dict() leads to a memory increase?
Thanks for the help.