diff --git a/jraph/__init__.py b/jraph/__init__.py
index 9aeea9d..66b00ce 100644
--- a/jraph/__init__.py
+++ b/jraph/__init__.py
@@ -70,6 +70,8 @@
 from jraph._src.utils import unpad_with_graphs
 from jraph._src.utils import with_zero_out_padding_outputs
 from jraph._src.utils import zero_out_padding
+from jraph._src.utils import get_node_permuted_graph
+from jraph._src.utils import get_edge_permuted_graph
 
 
 __version__ = "0.0.6.dev0"
@@ -93,7 +95,8 @@
            "partition_softmax", "concatenated_args",
            "get_fully_connected_graph", "dynamically_batch",
            "with_zero_out_padding_outputs", "zero_out_padding",
-           "sparse_matrix_to_graphs_tuple")
+           "sparse_matrix_to_graphs_tuple",
+           "get_node_permuted_graph", "get_edge_permuted_graph")
 
 #  _________________________________________
 # / Please don't use symbols in `_src` they \
diff --git a/jraph/_src/utils.py b/jraph/_src/utils.py
index 90a7ed4..fcd6ea6 100644
--- a/jraph/_src/utils.py
+++ b/jraph/_src/utils.py
@@ -18,6 +18,7 @@
 
 import jax
 from jax import lax
+from jax import random
 import jax.numpy as jnp
 import jax.tree_util as tree
 from jraph._src import graph as gn_graph
@@ -787,6 +788,195 @@ def _get_mask(padding_length, full_length):
   return jnp.arange(full_length, dtype=jnp.int32) < valid_length
 
 
+def _get_valid_permutation(rng_key: jnp.ndarray,
+                           n_elements: jnp.ndarray) -> jnp.ndarray:
+  """Returns a random permutation that only shuffles within each graph.
+
+  This works for nodes (`n_elements = graph.n_node`) and for edges
+  (`n_elements = graph.n_edge`). Elements are only permuted within their own
+  graph, so batched and padded graphs stay intact and can still be unbatched
+  or unpadded.
+
+  The sizes in `n_elements` are converted to Python integers because they
+  determine the output shapes, so this function cannot be traced by `jit`.
+  TODO(InnocentBug): find a way to make this jittable.
+
+  Args:
+    rng_key: random key used to draw the permutations.
+    n_elements: number of elements in each graph, e.g. `graph.n_node`.
+
+  Returns:
+    An integer array with `sum(n_elements)` entries of the form
+    `[permutation(0, n), permutation(n, n + m), ...]`.
+  """
+  sizes = jnp.asarray(n_elements).ravel().tolist()
+  if not sizes:
+    return jnp.zeros((0,), dtype=int)
+  # One independent key per graph.
+  keys = random.split(rng_key, len(sizes))
+  blocks = []
+  offset = 0
+  for key, size in zip(keys, sizes):
+    # Shuffle the local indices, then shift them to the graph's global range.
+    blocks.append(random.permutation(key, size) + offset)
+    offset += size
+  return jnp.concatenate(blocks).astype(int)
+
+
+def _resolve_permutation(rng_key: Optional[jnp.ndarray],
+                         permutation: Optional[jnp.ndarray],
+                         n_elements: jnp.ndarray,
+                         features: ArrayTree,
+                         kind: str) -> jnp.ndarray:
+  """Returns the validated permutation to apply to `features`.
+
+  Args:
+    rng_key: random key; if given, a random permutation is drawn that only
+      permutes elements within their own graph.
+    permutation: explicit permutation of all elements.
+    n_elements: number of elements in each graph, e.g. `graph.n_node`.
+    features: the node or edge features that will be permuted.
+    kind: either "node" or "edge", only used in error messages.
+
+  Returns:
+    An integer array that is a permutation of `range(sum(n_elements))`.
+
+  Raises:
+    RuntimeError: if both or neither of `rng_key` and `permutation` are given.
+    ValueError: if the features do not have `sum(n_elements)` entries, or if
+      `permutation` is not an integer permutation of all elements.
+  """
+  if rng_key is not None and permutation is not None:
+    raise RuntimeError("Either specify rng_key or permutation, not both.")
+  if rng_key is None and permutation is None:
+    raise RuntimeError("Either rng_key or permutation has to be specified.")
+
+  total = int(jnp.sum(jnp.asarray(n_elements)))
+  for leaf in tree.tree_leaves(features):
+    if leaf.shape[0] != total:
+      raise ValueError(
+          f"The {kind} features have {leaf.shape[0]} entries, but n_{kind} "
+          f"sums to {total}.")
+
+  if rng_key is not None:
+    return _get_valid_permutation(rng_key, n_elements)
+
+  permutation = jnp.asarray(permutation)
+  if not jnp.issubdtype(permutation.dtype, jnp.integer):
+    raise ValueError("`permutation` has to have an integer dtype.")
+  # A valid permutation contains every index exactly once.
+  if not bool(jnp.array_equal(jnp.sort(permutation), jnp.arange(total))):
+    raise ValueError(
+        f"`permutation` has to be a permutation of range({total}).")
+  return permutation
+
+
+def get_node_permuted_graph(graph: gn_graph.GraphsTuple,
+                            rng_key: Optional[jnp.ndarray] = None,
+                            permutation: Optional[jnp.ndarray] = None,
+                            return_permutation: bool = False):
+  """Permutes the order of nodes in the graph.
+
+  Args:
+    graph: ``GraphsTuple`` to be permuted, it can be batched and/or padded.
+    rng_key: random key to obtain a random permutation. The random permutation
+      only permutes nodes inside the individual batched (padded) graphs, so
+      the result can still be unbatched or unpadded as usual. Exactly one of
+      `rng_key` and `permutation` has to be specified.
+    permutation: an integer array with an explicit permutation of the nodes.
+      This gives full control over the permutation, but comes with the risk
+      that unpadding or unbatching no longer works as expected. A safe
+      permutation looks like `[permutation(0, n), permutation(n, n + m), ...]`
+      where `n` is the number of nodes in the first graph, `m` the number of
+      nodes in the second graph, etc.
+    return_permutation: whether the applied permutation is returned as well.
+
+  Returns:
+    A copy of the original graph with permuted nodes and accordingly
+    relabelled senders and receivers. If `return_permutation` is set, a tuple
+    of this graph and the applied permutation is returned instead. A graph
+    without node features is returned unchanged, in that case the returned
+    permutation is the `permutation` argument.
+
+  Raises:
+    RuntimeError: if both or neither of `rng_key` and `permutation` are given.
+    ValueError: if the node features do not match `graph.n_node`, or if
+      `permutation` is not an integer permutation of all nodes.
+  """
+  # Without node features there is nothing to permute.
+  if graph.nodes is None:
+    return (graph, permutation) if return_permutation else graph
+
+  permutation = _resolve_permutation(
+      rng_key, permutation, graph.n_node, graph.nodes, "node")
+  # `inverse_permutation[i]` is the new index of the node that was at `i`.
+  inverse_permutation = jnp.argsort(permutation)
+
+  def relabel(indices):
+    return inverse_permutation[indices.astype(int)]
+
+  # `tree_map` leaves missing (`None`) senders and receivers untouched.
+  permuted_graph = graph._replace(
+      nodes=tree.tree_map(lambda x: x[permutation], graph.nodes),
+      receivers=tree.tree_map(relabel, graph.receivers),
+      senders=tree.tree_map(relabel, graph.senders))
+  if return_permutation:
+    return permuted_graph, permutation
+  return permuted_graph
+
+
+def get_edge_permuted_graph(graph: gn_graph.GraphsTuple,
+                            rng_key: Optional[jnp.ndarray] = None,
+                            permutation: Optional[jnp.ndarray] = None,
+                            return_permutation: bool = False):
+  """Permutes the order of edges in the graph.
+
+  Args:
+    graph: ``GraphsTuple`` to be permuted, it can be batched and/or padded.
+    rng_key: random key to obtain a random permutation. The random permutation
+      only permutes edges inside the individual batched (padded) graphs, so
+      the result can still be unbatched or unpadded as usual. Exactly one of
+      `rng_key` and `permutation` has to be specified.
+    permutation: an integer array with an explicit permutation of the edges.
+      This gives full control over the permutation, but comes with the risk
+      that unpadding or unbatching no longer works as expected. A safe
+      permutation looks like `[permutation(0, n), permutation(n, n + m), ...]`
+      where `n` is the number of edges in the first graph, `m` the number of
+      edges in the second graph, etc.
+    return_permutation: whether the applied permutation is returned as well.
+
+  Returns:
+    A copy of the original graph with permuted edges, senders and receivers.
+    If `return_permutation` is set, a tuple of this graph and the applied
+    permutation is returned instead. A graph without edge features is returned
+    unchanged, in that case the returned permutation is the `permutation`
+    argument.
+
+  Raises:
+    RuntimeError: if both or neither of `rng_key` and `permutation` are given.
+    ValueError: if the edge features do not match `graph.n_edge`, or if
+      `permutation` is not an integer permutation of all edges.
+  """
+  # Without edge features there is nothing to permute.
+  if graph.edges is None:
+    return (graph, permutation) if return_permutation else graph
+
+  permutation = _resolve_permutation(
+      rng_key, permutation, graph.n_edge, graph.edges, "edge")
+
+  def permute(values):
+    return values[permutation]
+
+  # `tree_map` leaves missing (`None`) senders and receivers untouched.
+  permuted_graph = graph._replace(
+      edges=tree.tree_map(permute, graph.edges),
+      receivers=tree.tree_map(permute, graph.receivers),
+      senders=tree.tree_map(permute, graph.senders))
+  if return_permutation:
+    return permuted_graph, permutation
+  return permuted_graph
+
+
 def concatenated_args(
     update: Optional[Callable[..., ArrayTree]] = None,
     *,
diff --git a/jraph/_src/utils_test.py b/jraph/_src/utils_test.py
index 6016161..140da2b 100644
--- a/jraph/_src/utils_test.py
+++ b/jraph/_src/utils_test.py
@@ -20,6 +20,7 @@
 from absl.testing import parameterized
 import jax
+from jax import random
 from jax.lib import xla_bridge
 import jax.numpy as jnp
 import jax.tree_util as tree
 from jraph._src import graph
@@ -30,10 +31,14 @@
 def _get_random_graph(max_n_graph=10,
                       include_node_features=True,
                       include_edge_features=True,
-                      include_globals=True):
+                      include_globals=True,
+                      min_nodes_per_graph=0,
+                      max_nodes_per_graph=10,
+                      min_edges_per_graph=0,
+                      max_edges_per_graph=20):
   n_graph = np.random.randint(1, max_n_graph + 1)
-  n_node = np.random.randint(0, 10, n_graph)
-  n_edge = np.random.randint(0, 20, n_graph)
+  n_node = np.random.randint(min_nodes_per_graph, max_nodes_per_graph, n_graph)
+  n_edge = np.random.randint(min_edges_per_graph, max_edges_per_graph, n_graph)
   # We cannot have any edges if there are no nodes.
   n_edge[n_node == 0] = 0
 
@@ -897,6 +902,78 @@ def test_fully_connected_graph_order_edges(self, add_self_edges):
     np.testing.assert_array_equal(graph_batch.receivers, [0, 0, 1, 1, 2, 2])
 
+  def test_permute_nodes_and_edges(self):
+    # Create a randomly batched graph.
+    graph_a = _get_random_graph(
+        max_n_graph=1,
+        min_nodes_per_graph=25,
+        max_nodes_per_graph=50,
+        min_edges_per_graph=50,
+        max_edges_per_graph=75)
+    graph_b = _get_random_graph(
+        max_n_graph=1,
+        min_nodes_per_graph=25,
+        max_nodes_per_graph=75,
+        min_edges_per_graph=50,
+        max_edges_per_graph=125)
+    batched_ab = utils.batch([graph_a, graph_b])
+    node_key, edge_key = random.split(random.PRNGKey(0))
+
+    # Apply a node permutation.
+    node_permuted_ab, node_permutation = utils.get_node_permuted_graph(
+        batched_ab, rng_key=node_key, return_permutation=True)
+    node_permuted_a, node_permuted_b = utils.unbatch(node_permuted_ab)
+
+    for original, permuted in ((graph_a, node_permuted_a),
+                               (graph_b, node_permuted_b)):
+      # After the permutation the graphs shouldn't be equal.
+      with self.assertRaises(AssertionError):
+        tree.tree_map(np.testing.assert_allclose, original, permuted)
+      # But looking up the node features of the receivers and senders still
+      # gives the same result.
+      np.testing.assert_allclose(original.nodes[original.receivers],
+                                 permuted.nodes[permuted.receivers])
+      np.testing.assert_allclose(original.nodes[original.senders],
+                                 permuted.nodes[permuted.senders])
+
+    # Apply an edge permutation on top of the node permutation.
+    edge_permuted_ab, edge_permutation = utils.get_edge_permuted_graph(
+        node_permuted_ab, rng_key=edge_key, return_permutation=True)
+    edge_permuted_a, edge_permuted_b = utils.unbatch(edge_permuted_ab)
+
+    # The edge permuted graphs shouldn't be equal to the node permuted ones.
+    with self.assertRaises(AssertionError):
+      tree.tree_map(np.testing.assert_allclose, edge_permuted_a,
+                    node_permuted_a)
+    with self.assertRaises(AssertionError):
+      tree.tree_map(np.testing.assert_allclose, edge_permuted_b,
+                    node_permuted_b)
+
+    # Applying the inverse permutations recovers the original graphs. The
+    # inverse of a permutation is its argsort.
+    recovered_ab = utils.get_node_permuted_graph(
+        edge_permuted_ab, permutation=jnp.argsort(node_permutation))
+    recovered_ab = utils.get_edge_permuted_graph(
+        recovered_ab, permutation=jnp.argsort(edge_permutation))
+    recovered_a, recovered_b = utils.unbatch(recovered_ab)
+    tree.tree_map(np.testing.assert_allclose, graph_a, recovered_a)
+    tree.tree_map(np.testing.assert_allclose, graph_b, recovered_b)
+
+  def test_permute_validates_arguments(self):
+    graph_tuple = _get_random_graph(
+        max_n_graph=1, min_nodes_per_graph=2, min_edges_per_graph=2)
+    key = random.PRNGKey(0)
+    # `[0, 0]` is never a valid permutation, whatever the graph size is.
+    invalid_permutation = jnp.zeros((2,), dtype=jnp.int32)
+    for permute_fn in (utils.get_node_permuted_graph,
+                       utils.get_edge_permuted_graph):
+      with self.assertRaises(RuntimeError):
+        permute_fn(graph_tuple)
+      with self.assertRaises(RuntimeError):
+        permute_fn(graph_tuple, rng_key=key, permutation=invalid_permutation)
+      with self.assertRaises(ValueError):
+        permute_fn(graph_tuple, permutation=invalid_permutation)
+
 
 class ConcatenatedArgsWrapperTest(parameterized.TestCase):
 
   @parameterized.parameters(