Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 58 additions & 32 deletions sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ def iou(bb_test,bb_gt):
+ (bb_gt[2]-bb_gt[0])*(bb_gt[3]-bb_gt[1]) - wh)
return(o)

def colinearity(det,hist):
'''
det - current detection
hist - last 2 mean detections
'''
dims = det[2:4] - det[:2]
diag = np.sqrt(sum(dims**2))
a = det[:2] + dims/2 - hist[-2]
b = hist[-1] - hist[-2]
len1 = np.sqrt(sum(a*a))
len2 = np.sqrt(sum(b*b))
ratio = len2/float(len1)
maxdist = diag*(min(dims)/max(dims)+1)
maxval = b.dot(b)
a *= ratio
return a.dot(b)/float(maxval) if maxval and maxdist > len1 else 0

def convert_bbox_to_z(bbox):
"""
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
Expand Down Expand Up @@ -100,8 +117,9 @@ def __init__(self,bbox):
self.hits = 0
self.hit_streak = 0
self.age = 0
self.cthist = [self.kf.x[:2].ravel()]

def update(self,bbox):
def update(self, bbox, n):
"""
Updates the state vector with observed bbox.
"""
Expand All @@ -110,6 +128,8 @@ def update(self,bbox):
self.hits += 1
self.hit_streak += 1
self.kf.update(convert_bbox_to_z(bbox))
self.cthist.append(bbox[:2] + (bbox[2:4] - bbox[:2]) / 2)
self.cthist = self.cthist[-n:]

def predict(self):
"""
Expand All @@ -121,6 +141,7 @@ def predict(self):
self.age += 1
if(self.time_since_update>0):
self.hit_streak = 0
self.kf.P *= 1.2 # we may be lost, increase uncertainty and responsiveness
self.time_since_update += 1
self.history.append(convert_x_to_bbox(self.kf.x))
return self.history[-1]
Expand All @@ -131,49 +152,39 @@ def get_state(self):
"""
return convert_x_to_bbox(self.kf.x)

def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
def associate_detections_to_trackers(detections, trackers, cost_fn = iou, threshold = 0.33):
"""
Assigns detections to tracked object (both represented as bounding boxes)

Returns 3 lists of matches, unmatched_detections and unmatched_trackers
"""
if(len(trackers)==0):
return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)
iou_matrix = np.zeros((len(detections),len(trackers)),dtype=np.float32)
lendet = len(detections)
lentrk = len(trackers)

if(lentrk==0):
return np.empty((0,2),dtype=int), np.arange(lendet), np.array([],dtype=int)
cost_matrix = np.zeros((lendet,lentrk),dtype=np.float32)

for d,det in enumerate(detections):
for t,trk in enumerate(trackers):
iou_matrix[d,t] = iou(det,trk)
matched_indices = linear_assignment(-iou_matrix)
cost_matrix[d,t] = cost_fn(det,trk)
cost_matrix[cost_matrix < threshold] = 0.
matched_indices = linear_assignment(-cost_matrix)

costs = cost_matrix[tuple(matched_indices.T)] # select values from cost matrix by matched indices
matches = matched_indices[np.where(costs)[0]] # remove zero values from matches
unmatched_detections = np.where(np.in1d(range(lendet), matches[:,0], invert=True))[0]
unmatched_trackers = np.where(np.in1d(range(lentrk), matches[:,1], invert=True))[0]

unmatched_detections = []
for d,det in enumerate(detections):
if(d not in matched_indices[:,0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t,trk in enumerate(trackers):
if(t not in matched_indices[:,1]):
unmatched_trackers.append(t)

#filter out matched with low IOU
matches = []
for m in matched_indices:
if(iou_matrix[m[0],m[1]]<iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1,2))
if(len(matches)==0):
matches = np.empty((0,2),dtype=int)
else:
matches = np.concatenate(matches,axis=0)

return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
return matches, unmatched_detections, unmatched_trackers



class Sort(object):
def __init__(self,max_age=1,min_hits=3):
def __init__(self,max_age=10,min_hits=0):
"""
Sets key parameters for SORT
"""
Expand All @@ -182,10 +193,11 @@ def __init__(self,max_age=1,min_hits=3):
self.trackers = []
self.frame_count = 0

def update(self,dets):
def update(self, dets, cnum = 3):
"""
Params:
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
cnum - number of center positions to average
Requires: this method must be called once for each frame even with empty detections.
Returns the a similar array, where the last column is the object ID.

Expand All @@ -194,6 +206,7 @@ def update(self,dets):
self.frame_count += 1
#get predicted locations from existing trackers.
trks = np.zeros((len(self.trackers),5))
ctmean = []
to_del = []
ret = []
for t,trk in enumerate(trks):
Expand All @@ -206,11 +219,24 @@ def update(self,dets):
self.trackers.pop(t)
matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks)

for t in unmatched_trks:
cnt = np.array(self.trackers[t].cthist)
cnt = np.array([np.convolve(cnt[:,i], np.ones((cnum,))/cnum, mode='valid') for i in (0,1)]).T
if cnt.shape[0] == 1: # fix same len
cnt = np.concatenate((cnt,cnt),axis=0)
ctmean.append(cnt)

rematch, new_dets, lost_trks = associate_detections_to_trackers(dets[unmatched_dets],ctmean,colinearity,0.6)
rematch = np.array([unmatched_dets[rematch[:,0]], unmatched_trks[rematch[:,1]]]).T
matched = np.concatenate((matched, rematch.reshape(-1,2)))
unmatched_dets = unmatched_dets[new_dets]
unmatched_trks = unmatched_trks[lost_trks]

#update matched trackers with assigned detections
for t,trk in enumerate(self.trackers):
if(t not in unmatched_trks):
d = matched[np.where(matched[:,1]==t)[0],0]
trk.update(dets[d,:][0])
trk.update(dets[d,:][0], cnum+1)

#create and initialise new trackers for unmatched detections
for i in unmatched_dets:
Expand All @@ -219,8 +245,8 @@ def update(self,dets):
i = len(self.trackers)
for trk in reversed(self.trackers):
d = trk.get_state()[0]
if((trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits)):
ret.append(np.concatenate((d,[trk.id+1])).reshape(1,-1)) # +1 as MOT benchmark requires positive
if((trk.time_since_update < self.max_age) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits)):
ret.append(np.concatenate((d,[trk.id+1],[trk.time_since_update])).reshape(1,-1)) # +1 as MOT benchmark requires positive
i -= 1
#remove dead tracklet
if(trk.time_since_update > self.max_age):
Expand Down