Issue regarding sample_neg_protos() #8

@kongjianqiu0908

Dear authors,

I have a question regarding the following function:

def sample_neg_protos(self, im2cluster, cluster2cluster, pos_proto_id, prot_logits, n, cluster_results):
    """
    Sampling negative prototypes given pos_proto_id and layer

    Args:
        im2cluster: [N_bs]
        pos_proto_id: [N_bs] actually im2cluster[index]
        proto_dist_mat: [N_bs, N_l] used for sampling strategy.
        prot_logits: [N_l, N_{l+1}] proto logits of current layer
    """
    all_proto_id = [i for i in range(im2cluster.max())]
    neg_proto_id = set(all_proto_id)-set(pos_proto_id.tolist())
    neg_proto_id = torch.LongTensor(list(neg_proto_id)).to(pos_proto_id.device)
    upper_pos_proto_id = cluster2cluster[pos_proto_id] # [N_q]
    densities = cluster_results['density'][n+1] / cluster_results['density'][n+1].mean() * self.T
    sampling_prob = 1 - (prot_logits / densities).softmax(-1)[neg_proto_id, :][:, upper_pos_proto_id].t()
    neg_sampler = torch.distributions.bernoulli.Bernoulli(sampling_prob.clamp(0.0001, 0.999))
    selected_mask = neg_sampler.sample() #[N_q, N_neg]
    return selected_mask
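
For readers following the question, the sampling logic above can be re-expressed in plain Python on small nested lists. This is a minimal sketch, not the repository's implementation: the helper names (softmax_row, negative_proto_ids, sample_neg_mask) are hypothetical, densities is assumed to already be the normalized, temperature-scaled per-prototype vector of layer n+1, and Python's random module stands in for torch.distributions.bernoulli.Bernoulli.

```python
import math
import random


def softmax_row(row):
    """Numerically stable softmax over one list of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def negative_proto_ids(im2cluster, pos_proto_id):
    """Layer-n prototype ids that are not the positive prototype of any query.

    Note: this sketch uses range(max + 1) so the largest cluster id is included;
    the quoted code uses range(im2cluster.max()), which stops one id short.
    """
    all_ids = set(range(max(im2cluster) + 1))
    return sorted(all_ids - set(pos_proto_id))


def sample_neg_mask(prot_logits, densities, neg_proto_id, upper_pos_proto_id, rng=random):
    """Return an [N_q][N_neg] 0/1 mask of sampled negative prototypes (hypothetical helper).

    prot_logits: N_l rows of N_{l+1} logits (layer-n proto -> layer-(n+1) proto)
    densities: N_{l+1} per-prototype temperatures (assumed pre-normalized and scaled)
    upper_pos_proto_id: for each query, a layer-(n+1) prototype id
    """
    # Divide each column by its density, then softmax over the layer-(n+1) axis.
    probs = [softmax_row([x / d for x, d in zip(row, densities)]) for row in prot_logits]
    mask = []
    for u in upper_pos_proto_id:  # one row per query, matching the .t() in the original
        row = []
        for j in neg_proto_id:
            # Negatives likely to share the query's upper prototype get a low keep probability.
            p = min(max(1.0 - probs[j][u], 0.0001), 0.999)  # same clamp as the original
            row.append(1 if rng.random() < p else 0)  # Bernoulli(p) draw
        mask.append(row)
    return mask


if __name__ == "__main__":
    im2cluster = [0, 1, 2, 3, 3]  # toy layer-n assignments for 5 images
    pos_proto_id = [0, 1]         # positive prototypes of 2 queries
    prot_logits = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [0.5, 1.5]]  # 4 layer-n x 2 layer-(n+1)
    densities = [1.0, 1.0]
    neg = negative_proto_ids(im2cluster, pos_proto_id)  # [2, 3]
    mask = sample_neg_mask(prot_logits, densities, neg, [0, 1], rng=random.Random(0))
    print(neg, mask)
```

Because sample_neg_mask reads probs[j][u], every entry of upper_pos_proto_id must be a valid layer-(n+1) column index, which is what the indexing question hinges on.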

My question is about upper_pos_proto_id. It seems that cluster2cluster stores the clustering result from layer n-1 to layer n. In that case, upper_pos_proto_id should come from the next layer, n+1, so it seems that the lookup should be result["cluster2cluster"][n+1][pos_id].
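
To make the indexing question concrete: upper_pos_proto_id is used to select columns of prot_logits, which the docstring describes as [N_l, N_{l+1}], so it has to contain layer-(n+1) prototype ids, consistent with the density lookup at index n+1. The toy sketch below assumes a hypothetical per-layer layout in which cluster2cluster[k][i] is the layer-k parent of layer-(k-1) prototype i (the layout described in the question; the repository's actual structure is not shown here), and compares the two possible lookups. The helper upper_ids is hypothetical.

```python
# Hypothetical layout (as described in the question): cluster2cluster[k][i] is the
# layer-k prototype that layer-(k-1) prototype i is assigned to.
cluster2cluster = {
    1: [0, 0, 1, 1, 2, 2],  # layer 0 (6 prototypes) -> layer 1 (3 prototypes)
    2: [0, 1, 1],           # layer 1 (3 prototypes) -> layer 2 (2 prototypes)
}


def upper_ids(cluster2cluster, layer_key, pos_proto_id):
    """Look up parent ids for pos_proto_id using the mapping stored at layer_key."""
    mapping = cluster2cluster[layer_key]
    return [mapping[i] for i in pos_proto_id]


n = 1                  # pos_proto_id holds layer-1 prototype ids
pos_proto_id = [1, 2]

# Indexing with n treats the ids as layer-0 ids and returns layer-1 ids.
print(upper_ids(cluster2cluster, n, pos_proto_id))      # [0, 1]
# Indexing with n + 1 returns the layer-2 parents, which match prot_logits' columns.
print(upper_ids(cluster2cluster, n + 1, pos_proto_id))  # [1, 1]
```

In this toy case both lookups return in-range column indices, so using the wrong layer would not necessarily raise an error; it would silently select different columns of prot_logits.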

Metadata

Assignees

No one assigned

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
