Skip to content
This repository was archived by the owner on May 21, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions jraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
from jraph._src.utils import unpad_with_graphs
from jraph._src.utils import with_zero_out_padding_outputs
from jraph._src.utils import zero_out_padding

from jraph._src.utils import get_node_permuted_graph
from jraph._src.utils import get_edge_permuted_graph

__version__ = "0.0.6.dev0"

Expand All @@ -93,7 +94,8 @@
"partition_softmax", "concatenated_args",
"get_fully_connected_graph", "dynamically_batch",
"with_zero_out_padding_outputs", "zero_out_padding",
"sparse_matrix_to_graphs_tuple")
"sparse_matrix_to_graphs_tuple",
"get_node_permuted_graph", "get_edge_permuted_graph")

# _________________________________________
# / Please don't use symbols in `_src` they \
Expand Down
134 changes: 134 additions & 0 deletions jraph/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import jax.tree_util as tree
from jraph._src import graph as gn_graph
Expand Down Expand Up @@ -787,6 +788,139 @@ def _get_mask(padding_length, full_length):
return jnp.arange(full_length, dtype=jnp.int32) < valid_length


def _get_valid_permutation(rng_key:jnp.array,
n_elements:jnp.array,
):
"""Helper function to create individual permutations of elements.
For example, this works with nodes (n_elements = graph.n_nodes)
and edges (n_elements = graph.n_edges). The result is a permutation,
but only elements in the subgraphs are permuted. This leaves batched
and padded graphs intact.

TODO(02/20/24)InnocentBug, at the moment, I don't know how to make this jittable.
"""
node_keys = random.split(rng_key, int(jnp.sum(n_elements)))
permutation = []
for i in range(len(n_elements)):
# Permutation of length and idx of the local element
local_permutation = random.permutation(node_keys[i], n_elements[i])
# Adjust permutation index with the element index offset
adjusted_permutation = local_permutation + jnp.sum(n_elements[:i])
# Stack the global permutation
permutation += [adjusted_permutation]

return jnp.concatenate(permutation).astype(int)


def get_node_permuted_graph(graph: gn_graph.GraphsTuple,
rng_key: Optional[jnp.array] = None,
permutation: Optional[jnp.array] = None, # with integer dtype
return_permutation: Optional[bool] = False,
) -> gn_graph.GraphsTuple:
"""Permutes the order of nodes in the graph.

Args:
graph: ``GraphsTuple`` graph to be permuted it can be batched and/or padded.
rng_key: random key to obtain permutations. If rng_key is specified a random
permutation is computed, this random permutation permutes nodes only
inside individual batched (padded) graphs, so they can still be unbatched,
or unpadded as usual. Either `rng_key`, or `permutation` has to be specified.
permutation: an array with permutation of nodes. This gives explicit control over the
permutation, however it also comes with the risk unpadding or unbatching
no longer works as expected.
A safe permutation array looks like this
[ permutation(0, n), permutation(n, n+m), ...], where n is the number of nodes
in the first graph, and m the number nodes in the second graph etc.
return_permutation: boolean to indicate if the applied permutation sequence should be returned.

Returns:
A copy of the original graph, but with permuted nodes, senders, and receivers.
Raises: Runtime error if rng_key and permutation are specified.
"""

# If the graph, doesn't have nodes specified, we return the original.
if graph.nodes is None:
return graph

if rng_key is not None and permutation is not None:
raise RuntimeError("Either specify rng_key or permutation, not both.")

if rng_key is not None:
permutation = _get_valid_permutation(rng_key, graph.n_node)

# A bunch of checks, that make sure the permutation is actually valid.
assert int(jnp.sum(graph.n_node)) == int(len(graph.nodes))
assert int(jnp.max(permutation)) + 1 == int(len(graph.nodes))
assert len(jnp.unique(permutation)) == len(graph.nodes)

inverse_permutation = jnp.argsort(permutation)

# Perfrom the actual permutation of the nodes.
permuted_graph = gn_graph.GraphsTuple(nodes = graph.nodes[permutation],
edges = graph.edges,
receivers = inverse_permutation[graph.receivers.astype(int)],
senders = inverse_permutation[graph.senders.astype(int)],
globals = graph.globals,
n_node = graph.n_node,
n_edge = graph.n_edge,)
if return_permutation:
return permuted_graph, permutation
return permuted_graph


def get_edge_permuted_graph(graph: gn_graph.GraphsTuple,
rng_key: Optional[jnp.array] = None,
permutation: Optional[jnp.array] = None, # with integer dtype
return_permutation: bool = False,
) -> gn_graph.GraphsTuple:
"""Permutes the order of edges in the graph.

Args:
graph: ``GraphsTuple`` graph to be permuted it can be batched and/or padded.
rng_key: random key to obtain permutations. If rng_key is specified a random
permutation is computed, this random permutation permutes edges only
inside individual batched (padded) graphs, so they can still be unbatched,
or unpadded as usual. Either `rng_key`, or `permutation` has to be specified.
permutation: an array with permutation for the edges. This gives explicit control over the
permutation, however it also comes with the risk unpadding or unbatching
no longer works as expected.
A safe permutation array looks like this
[ permutation(0, n), permutation(n, n+m), ...], where n is the number of edges
in the first graph, and m the number edges in the second graph etc.
return_permutation: boolean to indicate if the applied permutation sequence should be returned.
Returns:
A copy of the original graph, but with permuted edges, senders, and receivers.
Raises: Runtime error if rng_key and permutation are specified.
"""

# If the graph, doesn't have edges specified, we return the original.
if graph.edges is None:
return graph

if rng_key is not None and permutation is not None:
raise RuntimeError("Either specify rng_key or permutation, not both.")

if rng_key is not None:
permutation = _get_valid_permutation(rng_key, graph.n_edge)

# A bunch of checks, that make sure the permutation is actually valid.
assert int(jnp.sum(graph.n_edge)) == int(len(graph.edges)) # Since nodes are present, this should add up
assert int(jnp.max(permutation))+1 == int(len(graph.edges))
assert int(len(jnp.unique(permutation))) == int(len(graph.edges))

# Perfrom the actual permutation of the nodes.
permuted_graph = gn_graph.GraphsTuple(nodes = graph.nodes,
edges = graph.edges[permutation],
receivers = graph.receivers[permutation],
senders = graph.senders[permutation],
globals = graph.globals,
n_node = graph.n_node,
n_edge = graph.n_edge,)
if return_permutation:
return permuted_graph, permutation
return permuted_graph


def concatenated_args(
update: Optional[Callable[..., ArrayTree]] = None,
*,
Expand Down
87 changes: 84 additions & 3 deletions jraph/_src/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl.testing import parameterized
import jax
from jax.lib import xla_bridge
from jax import random
import jax.numpy as jnp
import jax.tree_util as tree
from jraph._src import graph
Expand All @@ -30,10 +31,14 @@
def _get_random_graph(max_n_graph=10,
include_node_features=True,
include_edge_features=True,
include_globals=True):
include_globals=True,
min_nodes_per_graph=0,
max_nodes_per_graph=10,
min_edges_per_graph=0,
max_edges_per_graph=20):
n_graph = np.random.randint(1, max_n_graph + 1)
n_node = np.random.randint(0, 10, n_graph)
n_edge = np.random.randint(0, 20, n_graph)
n_node = np.random.randint(min_nodes_per_graph, max_nodes_per_graph, n_graph)
n_edge = np.random.randint(min_edges_per_graph, max_edges_per_graph, n_graph)
# We cannot have any edges if there are no nodes.
n_edge[n_node == 0] = 0

Expand Down Expand Up @@ -897,6 +902,82 @@ def test_fully_connected_graph_order_edges(self, add_self_edges):
np.testing.assert_array_equal(graph_batch.receivers, [0, 0, 1, 1, 2, 2])


def test_permute_nodes(self):
# Create a ranomdly batched graph
graph_a = _get_random_graph(max_n_graph=1,
min_nodes_per_graph=25,
max_nodes_per_graph=50,
min_edges_per_graph=50,
max_edges_per_graph=75,
include_node_features=True,
include_edge_features=True)
graph_b = _get_random_graph(max_n_graph=1,
min_nodes_per_graph=25,
max_nodes_per_graph=75,
min_edges_per_graph=50,
max_edges_per_graph=125,
include_node_features=True,
include_edge_features=True)

key = random.PRNGKey(0)

batched_ab = utils.batch([graph_a, graph_b])

# Apply a node permutation
key, subkey = random.split(key)
batched_node_mutated_ab, node_mutation = utils.get_node_permuted_graph(batched_ab, rng_key=subkey, return_permutation=True)
# We can use argsort to invert a permutation
inverted_node_mutation = jnp.argsort(node_mutation)

node_mutated_a, node_mutated_b = utils.unbatch(batched_node_mutated_ab)

# After permutation the graphs shouldn't be equal.
np.testing.assert_raises(AssertionError,
lambda :jax.tree_util.tree_map(np.testing.assert_allclose,
graph_a,
node_mutated_a))
np.testing.assert_raises(AssertionError,
lambda :jax.tree_util.tree_map(np.testing.assert_allclose,
graph_b,
node_mutated_b))
# But if we take the receivers, and senders look up of node features, they are the same
np.testing.assert_allclose(graph_a.nodes[graph_a.receivers], node_mutated_a.nodes[node_mutated_a.receivers])
np.testing.assert_allclose(graph_a.nodes[graph_a.senders], node_mutated_a.nodes[node_mutated_a.senders])
np.testing.assert_allclose(graph_b.nodes[graph_b.receivers], node_mutated_b.nodes[node_mutated_b.receivers])
np.testing.assert_allclose(graph_b.nodes[graph_b.senders], node_mutated_b.nodes[node_mutated_b.senders])


# Apply an edge permutation
key, subkey = random.split(key)
batched_edge_and_node_mutated_ab, edge_mutation = utils.get_edge_permuted_graph(batched_node_mutated_ab, rng_key=subkey, return_permutation=True)
inverted_edge_mutation = jnp.argsort(edge_mutation)

edge_mutated_a, edge_mutated_b = utils.unbatch(batched_edge_and_node_mutated_ab)

# After permutation the graphs shouldn't be equal.
# Here we test the the edge mutated once against the node mutated once
np.testing.assert_raises(AssertionError,
lambda :jax.tree_util.tree_map(np.testing.assert_allclose,
edge_mutated_a,
node_mutated_a))
np.testing.assert_raises(AssertionError,
lambda :jax.tree_util.tree_map(np.testing.assert_allclose,
edge_mutated_b,
node_mutated_b))


# Now we invert both permutations, (the order doesn't matter) and recover the original graphs

invert_node_graph = utils.get_node_permuted_graph(batched_edge_and_node_mutated_ab, permutation=inverted_node_mutation)
invert_edge_graph = utils.get_edge_permuted_graph(invert_node_graph, permutation=inverted_edge_mutation)

recover_a, recover_b = utils.unbatch(invert_edge_graph)
jax.tree_util.tree_map(np.testing.assert_allclose, graph_a, recover_a)
jax.tree_util.tree_map(np.testing.assert_allclose, graph_b, recover_b)




class ConcatenatedArgsWrapperTest(parameterized.TestCase):

@parameterized.parameters(
Expand Down