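The `indicator` is always 1.00:

```python
import torch
# Look up each sample's class threshold via its pseudo-label, then flag confident samples
batch_threshold = torch.index_select(cls_thresholds, 0, hard_label)
indicator = max_prob > batch_threshold
```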
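A minimal, self-contained sketch that may help narrow this down. It assumes `max_prob` and `hard_label` come from `probs.max(dim=-1)` over softmax outputs and that `cls_thresholds` holds one threshold per class. The `confidence_mask` helper and the threshold values are hypothetical. With C classes, the largest softmax probability is never below 1/C, so any per-class threshold at or below 1/C marks every (non-uniform) sample as confident. That would show up as a mean indicator of 1.00.

```python
import torch


def confidence_mask(probs: torch.Tensor, cls_thresholds: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True where a sample's top probability exceeds
    the threshold of its predicted (hard) class."""
    max_prob, hard_label = probs.max(dim=-1)  # confidence and pseudo-label per sample
    batch_threshold = torch.index_select(cls_thresholds, 0, hard_label)
    return max_prob > batch_threshold


torch.manual_seed(0)
num_classes = 10
probs = torch.softmax(torch.randn(8, num_classes), dim=-1)

# Thresholds at 1/C: the top softmax probability is never below 1/C,
# so every non-uniform row passes.
low = torch.full((num_classes,), 1.0 / num_classes)
mask_low = confidence_mask(probs, low)

# Thresholds at 1.0: no probability can exceed 1, so nothing passes.
high = torch.ones(num_classes)
mask_high = confidence_mask(probs, high)

print(mask_low.float().mean().item(), mask_high.float().mean().item())
```

If the thresholds are supposed to change during training, printing `cls_thresholds` alongside `max_prob` at a few steps may show whether they stay at or below 1/C. Note that `cls_thresholds[hard_label]` is an equivalent way to write the lookup.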