diff --git a/cortex/_lib/viz.py b/cortex/_lib/viz.py index b080e47..a0282b5 100644 --- a/cortex/_lib/viz.py +++ b/cortex/_lib/viz.py @@ -389,15 +389,15 @@ def save_heatmap(X, out_file=None, caption='', title='', image_id=0): def save_scatter(points, out_file=None, labels=None, caption='', title='', image_id=0): + names = data.DATA_HANDLER.get_label_names() if labels is not None: Y = (labels + 1.5).astype(int) + Y = Y - min(Y) + 1 + if len(names) != max(Y): + names = [names[y] for y in Y] else: Y = None - - names = data.DATA_HANDLER.get_label_names() - Y = Y - min(Y) + 1 - if len(names) != max(Y): - names = ['{}'.format(i + 1) for i in range(max(Y))] + names = None visualizer.scatter( X=points,