This repository was archived by the owner on May 21, 2025. It is now read-only.

How to do cross-graph attention? #35

@llan-ml

Hi everyone,

I'm trying to re-implement the GMN model from DeepMind in JAX. This model computes a similarity score between two graphs, and needs cross-graph node-wise attention weights between the two graphs. Since different graphs can have different numbers of nodes, the cross-graph attention has to be computed pair by pair.
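For concreteness, here is a minimal NumPy sketch of what the per-pair cross attention computes, assuming dot-product similarity (the helper name `compute_cross_attention` matches the one called in the TF snippet below; in JAX, `np` would become `jax.numpy`):

```python
import numpy as np

def softmax(a, axis):
    # Numerically stable softmax along the given axis.
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def compute_cross_attention(x, y):
    """Cross attention for one graph pair, dot-product similarity.

    x: [n, d] node features of graph 1, y: [m, d] node features of graph 2.
    Returns attention_x [n, d] and attention_y [m, d].
    """
    s = x @ y.T               # [n, m] pairwise similarities sim(x_i, y_j)
    a_x = softmax(s, axis=1)  # a_{i->j}: each x_i attends over all y_j
    a_y = softmax(s, axis=0)  # a_{j->i}: each y_j attends over all x_i
    attention_x = a_x @ y     # [n, d]
    attention_y = a_y.T @ x   # [m, d]
    return attention_x, attention_y
```

The softmax normalizations follow the a_{i->j} / a_{j->i} formulas from the docstring below: rows of `s` are normalized for x attending to y, columns for y attending to x.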

The original implementation is written in TensorFlow; given a batched node representation matrix, it uses tf.dynamic_partition to split it into a list of per-graph node representation matrices, as follows:

def batch_block_pair_attention(data,
                               block_idx,
                               n_blocks,
                               similarity='dotproduct'):
  """Compute batched attention between pairs of blocks.

  This function partitions the batch data into blocks according to block_idx.
  For each pair of blocks, x = data[block_idx == 2i], and
  y = data[block_idx == 2i+1], we compute

  x_i attend to y_j:
  a_{i->j} = exp(sim(x_i, y_j)) / sum_j exp(sim(x_i, y_j))
  y_j attend to x_i:
  a_{j->i} = exp(sim(x_i, y_j)) / sum_i exp(sim(x_i, y_j))

  and

  attention_x = sum_j a_{i->j} y_j
  attention_y = sum_i a_{j->i} x_i.

  Args:
    data: NxD float tensor.
    block_idx: N-dim int tensor.
    n_blocks: integer.
    similarity: a string, the similarity metric.

  Returns:
    attention_output: NxD float tensor, each x_i replaced by attention_x_i.

  Raises:
    ValueError: if n_blocks is not an integer or not a multiple of 2.
  """
  if not isinstance(n_blocks, int):
    raise ValueError('n_blocks (%s) has to be an integer.' % str(n_blocks))

  if n_blocks % 2 != 0:
    raise ValueError('n_blocks (%d) must be a multiple of 2.' % n_blocks)

  sim = get_pairwise_similarity(similarity)

  results = []

  # This is probably better than doing boolean_mask for each i
  partitions = tf.dynamic_partition(data, block_idx, n_blocks)

  # It is rather complicated to allow n_blocks to be a tf tensor and do this
  # in a dynamic loop, and probably unnecessary to do so.  Therefore we
  # restrict n_blocks to be an integer constant here and use a plain for loop.
  for i in range(0, n_blocks, 2):
    x = partitions[i]
    y = partitions[i + 1]
    attention_x, attention_y = compute_cross_attention(x, y, sim)
    results.append(attention_x)
    results.append(attention_y)

  results = tf.concat(results, axis=0)
  # the shape of the first dimension is lost after concat, reset it back
  results.set_shape(data.shape)
  return results

However, JAX does not provide an equivalent of tf.dynamic_partition: its output shapes depend on the values of block_idx, which does not fit JAX's static-shape tracing under jit.
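One possible workaround (not from the original post; a NumPy sketch that should translate line-for-line to `jax.numpy`, with the function name chosen here for illustration) is to skip the partitioning entirely and mask the pairwise similarities: every node attends over the whole batch, but only entries in its partner block survive the softmax. Since block 2i pairs with block 2i+1, the partner block id is simply `block_idx XOR 1`. All shapes stay static, so the computation is jit-compatible:

```python
import numpy as np

def batch_block_pair_attention_masked(data, block_idx):
    """Static-shape version of batch_block_pair_attention (dot-product sim).

    data: [N, D] float array, block_idx: [N] int array.
    Block 2i is paired with block 2i+1; each node attends only over
    nodes in its partner block, enforced by a mask instead of a split.
    """
    sim = data @ data.T                            # [N, N] all-pairs similarities
    partner = block_idx ^ 1                        # partner block id of each node
    mask = partner[:, None] == block_idx[None, :]  # True where j is in i's partner block
    logits = np.where(mask, sim, -np.inf)          # exclude non-partner entries
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)                       # masked entries become exactly 0
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ data                          # [N, D] attended representations
```

Under jit, `-np.inf` in the forward pass is fine, though a large negative constant (e.g. -1e9) is a common substitute if gradients through the mask cause NaN issues.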

@jg8610 Do you have any advice on how to do the cross-graph attention in JAX and Jraph, or do you have similar cases internally? Thanks!
