From 3f64992ac3ec183741b729bf481f06178a1b4ef7 Mon Sep 17 00:00:00 2001 From: Pegerto Fernandez Date: Tue, 11 Mar 2025 14:23:52 +0000 Subject: [PATCH] eigen_bt_calculation --- src/graphpro/annotations.py | 25 +++++++++++++++++++-- src/graphpro/util/energy.py | 27 +++++++++++++++++++++++ test/graphpro/annnotations/energy_test.py | 9 ++++++-- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/graphpro/annotations.py b/src/graphpro/annotations.py index 65b4091..a3c31df 100644 --- a/src/graphpro/annotations.py +++ b/src/graphpro/annotations.py @@ -12,7 +12,7 @@ from graphpro.util.dssp import compute_dssp, DSSP_CLASS from graphpro.util.polarity import POLARITY_CLASSES, residue_polarity from graphpro.util.conservation import ConservationScoreClient -from graphpro.util.energy import compute_bt_potential +from graphpro.util.energy import compute_bt_potential, compute_eigen_centrality class NodeTargetBinaryAttribute(NodeTarget): """ Binary target, creates a binary one_hot encoding of the property @@ -232,4 +232,25 @@ def generate(self, G, atom_group): def encode(self, G: Graph) -> torch.tensor: scores = [G.node_attr(n)[self.attr_name] if self.attr_name in G.node_attr(n) else 0 for n in G.nodes()] - return F.normalize(torch.tensor([scores], dtype=torch.float).T, dim=(0,1)) \ No newline at end of file + return F.normalize(torch.tensor([scores], dtype=torch.float).T, dim=(0,1)) + +class BTEigenCentrality(NodeAnnotation): + """ Computes the residue energy contribution based on BT potential to the graph + centrality. + """ + def __init__(self, attr_name: str = 'bt_eigen_centrality', chain: str = None): + """ Attribute name + """ + self.attr_name = attr_name + self.chain = chain + + def generate(self, G, atom_group): + res_ids, eigen_potential = compute_eigen_centrality(atom_group, self.chain) + for i,resid in enumerate(res_ids): + node_id = G.get_node_by_resid(resid) + G.node_attr_add(node_id, self.attr_name, eigen_potential[i]) + + def encode(self, G: Graph) -> torch.tensor: + scores = [G.node_attr(n)[self.attr_name] if self.attr_name in G.node_attr(n) else 0 for n in G.nodes()] + return F.normalize(torch.tensor([scores], dtype=torch.float).T, dim=(0,1)) + \ No newline at end of file diff --git a/src/graphpro/util/energy.py b/src/graphpro/util/energy.py index ca45bf2..3eaa4ae 100644 --- a/src/graphpro/util/energy.py +++ b/src/graphpro/util/energy.py @@ -65,4 +65,31 @@ def compute_bt_potential(atom_group, chain, cutoff=6, epsilon=1): potential[j,i] = energy eigen_value, _ = LA.eig(potential) + return res_ids, eigen_value + + +def compute_eigen_centrality(atom_group, chain, cutoff=6, epsilon=1): + from scipy.spatial import distance + import networkx as nx + + ca_position = atom_group.c_alphas_positions(chain) + residues = atom_group.c_alphas_residues(chain) + dist = distance.squareform(distance.pdist(ca_position)) + potential = np.zeros((len(dist), len(dist))) + res_ids = [res['resid'] for res in residues] + + for i in range(len(dist)): + for j in range(i + 1, len(dist)): + resname_i = residues[i]['resname'] + resname_j = residues[j]['resname'] + + V_ij = bt_potential(resname_i, resname_j) + r_ij = dist[i,j] + + if r_ij < cutoff: + energy = np.exp(-V_ij) + potential[i,j] = energy + potential[j,i] = energy + G = nx.from_numpy_array(potential) + eigen_value = nx.eigenvector_centrality_numpy(G) return res_ids, eigen_value \ No newline at end of file diff --git a/test/graphpro/annnotations/energy_test.py b/test/graphpro/annnotations/energy_test.py index 7747c83..72f02a4 100644 --- a/test/graphpro/annnotations/energy_test.py +++ b/test/graphpro/annnotations/energy_test.py @@ -3,7 +3,7 @@ from graphpro import md_analisys from graphpro.graphgen import ContactMap -from graphpro.annotations import BTPotential +from graphpro.annotations import BTPotential, BTEigenCentrality from MDAnalysis.tests.datafiles import PDB_small @@ -19,4 +19,9 @@ def test_gnm_encoding(): data = G.to_data(node_encoders=[BTPotential()]) assert data.x.size() == (214, 1) - assert data.x.dtype == torch.float \ No newline at end of file + assert data.x.dtype == torch.float + +def test_bt_potential_calculation(): + G = md_analisys(u1).generate(ContactMap(cutoff=6), [BTEigenCentrality()]) + assert len(G.nodes()) == 214 + assert round(G.node_attr(90)['bt_eigen_centrality'],5) == 0.00022