Dear authors,
I have a question regarding `sample_neg_protos`:
```python
def sample_neg_protos(self, im2cluster, cluster2cluster, pos_proto_id, prot_logits, n, cluster_results):
    """
    Sampling negative prototypes given pos_proto_id and layer
    Args:
        im2cluster: [N_bs]
        pos_proto_id: [N_bs] actually im2cluster[index]
        proto_dist_mat: [N_bs, N_l] used for sampling strategy.
        prot_logits: [N_l, N_{l+1}] proto logits of current layer
    """
    all_proto_id = [i for i in range(im2cluster.max())]
    neg_proto_id = set(all_proto_id)-set(pos_proto_id.tolist())
    neg_proto_id = torch.LongTensor(list(neg_proto_id)).to(pos_proto_id.device)
    upper_pos_proto_id = cluster2cluster[pos_proto_id] # [N_q]
    densities = cluster_results['density'][n+1] / cluster_results['density'][n+1].mean() * self.T
    sampling_prob = 1 - (prot_logits / densities).softmax(-1)[neg_proto_id, :][:, upper_pos_proto_id].t()
    neg_sampler = torch.distributions.bernoulli.Bernoulli(sampling_prob.clamp(0.0001, 0.999))
    selected_mask = neg_sampler.sample() #[N_q, N_neg]
    return selected_mask
```
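To make the shapes concrete, here is a dependency-free Python sketch of the sampling step for a single layer. It is an illustration, not the repository's code: the helpers `softmax` and `sample_neg_protos_sketch`, the `parent_of` list (a stand-in for whatever `cluster2cluster` entry the function receives), and the injectable `rng` are hypothetical, and negatives are drawn from all `N_l` ids rather than from `range(im2cluster.max())`. What it shows is that the parent id is used as a column index into the `N_l x N_{l+1}` probability matrix, so it has to be a layer-(n+1) id.

```python
import math
import random


def softmax(row):
    """Numerically stable softmax over a single list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sample_neg_protos_sketch(pos_proto_id, parent_of, prot_logits, densities,
                             temperature, rng=random):
    """Hypothetical pure-Python version of the sampling step.

    pos_proto_id: layer-n prototype id for each query (length N_q).
    parent_of:    layer-(n+1) parent id for each layer-n id (stand-in for
                  the cluster2cluster entry the function receives).
    prot_logits:  N_l x N_{l+1} nested list of prototype logits.
    densities:    length-N_{l+1} list of layer-(n+1) densities.
    Returns (neg_ids, mask) where mask is N_q x N_neg, 1 = keep negative.
    """
    pos_set = set(pos_proto_id)
    # Negatives: every layer-n id that is not a positive in this batch.
    neg_ids = [i for i in range(len(prot_logits)) if i not in pos_set]

    # Per-column temperature: density / mean density * T.
    mean_d = sum(densities) / len(densities)
    temps = [d / mean_d * temperature for d in densities]

    # Softmax of logits / temperature over the layer-(n+1) axis, row by row.
    probs = [softmax([x / t for x, t in zip(row, temps)]) for row in prot_logits]

    mask = []
    for pid in pos_proto_id:
        parent = parent_of[pid]  # used as a column index, so must be < N_{l+1}
        row_mask = []
        for j in neg_ids:
            # Negatives strongly assigned to the positive's parent are sampled rarely.
            p = min(max(1.0 - probs[j][parent], 0.0001), 0.999)
            # Bernoulli draw: 1 with probability p.
            row_mask.append(1 if rng.random() < p else 0)
        mask.append(row_mask)
    return neg_ids, mask
```

With this scheme, a negative that is strongly assigned to the positive's own parent (a likely sibling) gets a sampling probability near the 0.0001 floor, while negatives under other parents are kept with probability near 0.999.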
My question is about `upper_pos_proto_id`. It seems that `cluster2cluster` stores the clustering result from layer n-1 to layer n. Since `upper_pos_proto_id` indexes the last dimension of `prot_logits` ([N_l, N_{l+1}]), it should refer to the next layer, n+1, so it seems the lookup should be `result["cluster2cluster"][n+1][pos_id]`.
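One way to check this reading is to make the layer convention explicit. The sketch below assumes, as the question does, that `cluster2cluster[k]` maps layer-(k-1) ids to layer-k ids; the helper `parent_ids`, its `n_upper` argument, and the toy hierarchy are hypothetical and not taken from the repository. It looks up the parents in entry `n + 1` and checks that they are valid column indices for `prot_logits`.

```python
def parent_ids(cluster2cluster, n, pos_proto_id, n_upper):
    """Look up layer-(n+1) parents of layer-n prototype ids.

    Hypothetical convention (the reading in the question): cluster2cluster[k]
    maps layer-(k-1) ids to layer-k ids, so the parents live in entry n + 1.
    n_upper is N_{l+1}, the number of columns of prot_logits.
    """
    mapping = cluster2cluster[n + 1]
    parents = [mapping[i] for i in pos_proto_id]
    # Parents index the columns of prot_logits, so each must be in [0, n_upper).
    bad = [p for p in parents if not 0 <= p < n_upper]
    if bad:
        raise ValueError(f"ids {bad} are not valid layer-{n + 1} prototype ids")
    return parents


if __name__ == "__main__":
    # Toy hierarchy: 6 layer-0 prototypes -> 3 layer-1 -> 2 layer-2.
    cluster2cluster = {
        1: [2, 0, 1, 1, 0, 2],  # layer-0 id -> layer-1 id
        2: [1, 0, 1],           # layer-1 id -> layer-2 id
    }
    n = 1
    pos = [0, 2]  # layer-1 prototype ids
    print(parent_ids(cluster2cluster, n, pos, n_upper=2))  # [1, 1]
    # Looking the same ids up in entry n instead treats layer-1 ids as
    # layer-0 ids and yields [2, 1]; id 2 is not a valid column when N_{l+1} = 2.
    print([cluster2cluster[n][i] for i in pos])  # [2, 1]
```

The range check only catches ids that fall outside the layer-(n+1) range; a lookup in the wrong entry can also return in-range ids with the wrong meaning, so it complements rather than replaces checking the convention in the code that builds `cluster2cluster`.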