diff --git a/src/deepgraphpose/dataset.py b/src/deepgraphpose/dataset.py index eea8c1c..d039a3c 100644 --- a/src/deepgraphpose/dataset.py +++ b/src/deepgraphpose/dataset.py @@ -596,7 +596,7 @@ def _compute_targets(self): dataset = create_dataset(dlc_config) nt = len(self.idxs['vis']['train']) # number of training frames # assert nt >= 1 - nj = max([dat_.joints[0].shape[0] for dat_ in dataset.data]) + nj = self.nj #Taiga edit 8/31/21: untested. max([dat_.joints[0].shape[0] for dat_ in dataset.data]) stride = dlc_config['stride'] def extract_frame_num(img_path):