From 8a7a192b4195f53ef57f84bd6d378e447e2abaaf Mon Sep 17 00:00:00 2001 From: kaneziki <614674490@qq.com> Date: Thu, 1 Feb 2024 20:46:34 +0800 Subject: [PATCH] add xidx and yidx --- dtw/dtw.py | 12 ++++++++++++ dtw/dtwPlot.py | 14 +++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/dtw/dtw.py b/dtw/dtw.py index 292e896..2a4376a 100644 --- a/dtw/dtw.py +++ b/dtw/dtw.py @@ -128,6 +128,7 @@ def plot(self, type="alignment", **kwargs): def dtw(x, y=None, + xidx=None, yidx=None, dist_method="euclidean", step_pattern="symmetric2", window_type=None, @@ -199,6 +200,10 @@ def dtw(x, y=None, query vector *or* local cost matrix y : reference vector, unused if `x` given as cost matrix +xidx: + time-indexed sequence corresponding to the query vector +yidx: + time-indexed sequence corresponding to the reference vector dist_method : pointwise (local) distance function to use. step_pattern : @@ -388,6 +393,13 @@ def dtw(x, y=None, seed=precm, win_args=window_args) gcm = DTW(gcm) # turn into an object, use dot to access properties + + if xidx is None or yidx is None: + gcm.xidx = numpy.arange(len(x)) + gcm.yidx = numpy.arange(len(y)) + else: + gcm.xidx = numpy.array(xidx) + gcm.yidx = numpy.array(yidx) gcm.N = n gcm.M = m diff --git a/dtw/dtwPlot.py b/dtw/dtwPlot.py index a5b19a0..1ec40aa 100644 --- a/dtw/dtwPlot.py +++ b/dtw/dtwPlot.py @@ -170,8 +170,8 @@ def dtwPlotTwoWay(d, xts=None, yts=None, ax.set_xlabel(xlab) ax.set_ylabel(ylab) - ax.plot(xtimes, numpy.array(xts), color='k', **kwargs) - ax.plot(ytimes, numpy.array(yts) - offset, **kwargs) # Plot with offset applied + ax.plot(d.xidx, numpy.array(xts), color='k', **kwargs) + ax.plot(d.yidx, numpy.array(yts) - offset, **kwargs) # Plot with offset applied if offset != 0: # Create an offset axis @@ -191,8 +191,8 @@ def dtwPlotTwoWay(d, xts=None, yts=None, col = [] for i in idx: - col.append([(d.index1[i], xts[d.index1[i]]), - (d.index2[i], -offset + yts[d.index2[i]])]) + col.append([(d.xidx[d.index1[i]], xts[d.index1[i]]), + (d.yidx[d.index2[i]], -offset + yts[d.index2[i]])]) lc = mc.LineCollection(col, linewidths=1, linestyles=":", colors=match_col) ax.add_collection(lc) @@ -285,14 +285,14 @@ def dtwPlotThreeWay(d, xts=None, yts=None, ax = plt.subplot(gs[1]) axq = plt.subplot(gs[3]) - axq.plot(nn1, xts) # query, horizontal, bottom + axq.plot(d.xidx, xts) # query, horizontal, bottom axq.set_xlabel(xlab) - axr.plot(yts, mm1) # ref, vertical + axr.plot(yts, d.yidx) # ref, vertical axr.invert_xaxis() axr.set_ylabel(ylab) - ax.plot(d.index1, d.index2) + ax.plot(d.xidx[d.index1], d.yidx[d.index2]) if match_indices is None: idx = []