Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions dtw/dtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def plot(self, type="alignment", **kwargs):


def dtw(x, y=None,
xidx=None, yidx=None,
dist_method="euclidean",
step_pattern="symmetric2",
window_type=None,
Expand Down Expand Up @@ -199,6 +200,10 @@ def dtw(x, y=None,
query vector *or* local cost matrix
y :
reference vector, unused if `x` given as cost matrix
xidx:
time-indexed sequence corresponding to the query vector
yidx:
time-indexed sequence corresponding to the reference vector
dist_method :
pointwise (local) distance function to use.
step_pattern :
Expand Down Expand Up @@ -388,6 +393,13 @@ def dtw(x, y=None,
seed=precm,
win_args=window_args)
gcm = DTW(gcm) # turn into an object, use dot to access properties

if xidx is None or yidx is None:
gcm.xidx = numpy.arange(len(x))
gcm.yidx = numpy.arange(len(y))
else:
gcm.xidx = numpy.array(xidx)
gcm.yidx = numpy.array(yidx)

gcm.N = n
gcm.M = m
Expand Down
14 changes: 7 additions & 7 deletions dtw/dtwPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def dtwPlotTwoWay(d, xts=None, yts=None,
ax.set_xlabel(xlab)
ax.set_ylabel(ylab)

ax.plot(xtimes, numpy.array(xts), color='k', **kwargs)
ax.plot(ytimes, numpy.array(yts) - offset, **kwargs) # Plot with offset applied
ax.plot(d.xidx, numpy.array(xts), color='k', **kwargs)
ax.plot(d.yidx, numpy.array(yts) - offset, **kwargs) # Plot with offset applied

if offset != 0:
# Create an offset axis
Expand All @@ -191,8 +191,8 @@ def dtwPlotTwoWay(d, xts=None, yts=None,

col = []
for i in idx:
col.append([(d.index1[i], xts[d.index1[i]]),
(d.index2[i], -offset + yts[d.index2[i]])])
col.append([(d.xidx[d.index1[i]], xts[d.index1[i]]),
(d.yidx[d.index2[i]], -offset + yts[d.index2[i]])])

lc = mc.LineCollection(col, linewidths=1, linestyles=":", colors=match_col)
ax.add_collection(lc)
Expand Down Expand Up @@ -285,14 +285,14 @@ def dtwPlotThreeWay(d, xts=None, yts=None,
ax = plt.subplot(gs[1])
axq = plt.subplot(gs[3])

axq.plot(nn1, xts) # query, horizontal, bottom
axq.plot(d.xidx, xts) # query, horizontal, bottom
axq.set_xlabel(xlab)

axr.plot(yts, mm1) # ref, vertical
axr.plot(yts, d.yidx) # ref, vertical
axr.invert_xaxis()
axr.set_ylabel(ylab)

ax.plot(d.index1, d.index2)
ax.plot(d.xidx[d.index1], d.yidx[d.index2])

if match_indices is None:
idx = []
Expand Down