Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@ api/data/

# Test files
tests/.env
.pytest_cache
.pytest_cache

# Temp
temp/
79 changes: 46 additions & 33 deletions api/draw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def mol_to_image(
[a.SetAtomMapNum(0) for a in mol.GetAtoms()]
if show_atom_indices:
mol.UpdatePropertyCache(False)
if abbreviate and not highlight_atoms and not highlight_bonds and clear_map:
if abbreviate and not highlight_atoms and not highlight_bonds and clear_map and not show_atom_indices:
mol.UpdatePropertyCache(False)
mol = rdAbbreviations.CondenseMolAbbreviations(mol, ABBREVIATIONS)
if update:
Expand Down Expand Up @@ -494,7 +494,7 @@ def molecule_smiles_to_image(
"""
if not highlight and split:
mols = [Chem.MolFromSmarts(smi) if show_atom_indices else Chem.MolFromSmiles(smi) for smi in smiles.split(".")]
images = [mol_to_image(mol, svg=svg, transparent=transparent, abbreviate=abbreviate, reference=reference, **kwargs) for mol in mols]
images = [mol_to_image(mol, svg=svg, transparent=transparent, abbreviate=abbreviate, reference=reference, show_atom_indices=show_atom_indices, **kwargs) for mol in mols]
return combine_images_horizontally(images, transparent=transparent, return_png=return_png)

mol = Chem.MolFromSmarts(smiles) if show_atom_indices else Chem.MolFromSmiles(smiles)
Expand Down Expand Up @@ -539,60 +539,71 @@ def molecule_smiles_to_image(
print("Unable to determine atom highlights using specified reacting atoms. Drawing without highlights.")
tb.print_exc()

return mol_to_image(mol, svg=svg, transparent=transparent, return_png=return_png, abbreviate=abbreviate, reference=reference, **kwargs)
return mol_to_image(mol, svg=svg, transparent=transparent, return_png=return_png, abbreviate=abbreviate, reference=reference, show_atom_indices=show_atom_indices, **kwargs)


def determine_highlight_colors(mol, frag_map, frag_idx=None):
"""
Determine highlight colors for reactants and products based on atom map.

Adapted from RDKit MolDraw2D.DrawReaction
https://github.com/rdkit/rdkit/blob/Release_2020_09_5/Code/GraphMol/MolDraw2D/MolDraw2D.cpp#L547

If ``frag_idx`` is specified:
* The molecule is considered a reactant fragment
* All atoms in the molecule will be highlighted using the same color
* ``frag_map`` will be updated in place

Otherwise:
* The molecule is considered a product fragment
* Atoms will be colored based the contents of ``frag_map``
* ``frag_map`` will not be altered
More robust behavior:
- When frag_idx is provided (building phase): assign every mapped atom's mapno
to frag_idx in frag_map (overwriting if already present).
- When frag_idx is None (drawing phase): only highlight atoms whose mapno exists
in frag_map (skip unknown map numbers instead of raising KeyError).

Args:
mol (Chem.Mol): molecule object to analyze
frag_map (dict): mapping from atom map number to fragment index
frag_idx (int, optional): current fragment index
frag_map (dict[int, int]): mapping from atom map number -> fragment index
frag_idx (int, optional): current fragment index when building frag_map

Returns:
dict: containing highlight kwargs for ``mol_to_image``
dict: highlight kwargs for mol_to_image
"""
highlight_atoms = []
highlight_bonds = []
highlight_atom_colors = {}
highlight_bond_colors = {}

for atom in mol.GetAtoms():
mapno = atom.GetAtomMapNum()
if mapno:
idx = atom.GetIdx()
if frag_idx is not None:
frag_map[mapno] = frag_idx
highlight_atoms.append(idx)
highlight_atom_colors[idx] = HIGHLIGHT_COLORS[frag_map[mapno] % len(HIGHLIGHT_COLORS)]
for atom2 in atom.GetNeighbors():
if atom2.GetIdx() < idx and highlight_atom_colors.get(atom2.GetIdx()) == highlight_atom_colors[idx]:
bond_idx = mol.GetBondBetweenAtoms(idx, atom2.GetIdx()).GetIdx()
highlight_bonds.append(bond_idx)
highlight_bond_colors[bond_idx] = highlight_atom_colors[idx]
if not mapno:
continue # RDKit uses 0 for "no mapping"

# Build-phase: record map number -> fragment index
if frag_idx is not None:
frag_map[mapno] = frag_idx

# Draw-phase: if we don't know this map number, skip it (don't crash)
frag = frag_map.get(mapno)
if frag is None:
continue

idx = atom.GetIdx()
color = HIGHLIGHT_COLORS[frag % len(HIGHLIGHT_COLORS)]

highlight_atoms.append(idx)
highlight_atom_colors[idx] = color

# Highlight bonds between same-colored neighboring atoms
for atom2 in atom.GetNeighbors():
j = atom2.GetIdx()
if j < idx and highlight_atom_colors.get(j) == color:
bond = mol.GetBondBetweenAtoms(idx, j)
if bond is None:
continue
bond_idx = bond.GetIdx()
highlight_bonds.append(bond_idx)
highlight_bond_colors[bond_idx] = color

return {
"highlight_atoms": highlight_atoms,
"highlight_bonds": highlight_bonds,
"highlight_atom_colors": highlight_atom_colors,
"highlight_bond_colors": highlight_bond_colors,
}




def reaction_smiles_to_image(smiles, svg=True, transparent=True, return_png=True, retro=False, highlight=False, align=False, plus=True, update=True, show_atom_indices=False,**kwargs):
"""
Create image of the provided reaction SMILES string. Omits agents.
Expand Down Expand Up @@ -637,10 +648,12 @@ def reaction_smiles_to_image(smiles, svg=True, transparent=True, return_png=True
if plus and i > 0:
images.append(draw_plus(svg=svg, transparent=transparent))
smiles = Chem.MolToSmarts(mol) if show_atom_indices else Chem.MolToSmiles(mol)
if smiles.count(".") > 1:
# Only split when NOT highlighting; splitting breaks highlight atom indices
if (not highlight) and smiles.count(".") > 1:
images.append(molecule_smiles_to_image(smiles, svg=svg, transparent=transparent, **kwargs))
else:
images.append(mol_to_image(mol, svg=svg, transparent=transparent, update=update, show_atom_indices=show_atom_indices, **kwargs))
images.append(mol_to_image(mol, svg=svg, transparent=transparent, update=update,
show_atom_indices=show_atom_indices, **kwargs))

images.append(draw_arrow(retro=retro, svg=svg, transparent=transparent))

Expand Down
2 changes: 1 addition & 1 deletion api/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
dependencies:
- svgutils
- python=3.12
- rdkit=2024.3.5
- rdkit=2025.03.2
- networkx
- pip
- pip:
Expand Down
47 changes: 30 additions & 17 deletions api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ class Availability(BaseModel):
class InputFile(BaseModel):
synth_graph: Optional[SynthGraph] = None
evidence_synth_graph: Optional[SynthGraph] = None
predictive_synth_graph: Optional[SynthGraph] = None
predicted_synth_graph: Optional[SynthGraph] = None
predictive_synth_graph: Optional[SynthGraph] = None # Deprecated: Use predicted_synth_graph
routes: Optional[List[Route]] = None
availability: Optional[list[Availability]] = None

Expand Down Expand Up @@ -521,6 +522,7 @@ async def rxsmiles_to_svg_endpoint(rxsmiles: str = 'CCO.CC(=O)O>>CC(=O)OCC.O', h
<text x="10" y="50" font-size="32" fill="black">Unable to generate reaction SVG</text>
</svg>
""".strip()
logger.error(f"Error generating SVG for rxsmiles {rxsmiles}: {e}")

if base64_encode:
svg = base64.b64encode(svg.encode('utf-8')).decode('utf-8')
Expand Down Expand Up @@ -669,7 +671,7 @@ def flatten_dict(d, parent_key='', sep='_'):
return dict(items)


def convert_to_cytoscape_json(aicp_graph, synth_graph_key="synth_graph", convert_route=False, predicted_route=False, route_index=0):
def convert_to_cytoscape_json(aicp_graph, synth_graph_key="synth_graph", convert_route=False, is_predicted=False, route_index=0):
# Try to get the specified graph, with fallback logic similar to frontend
synth_graph = None

Expand All @@ -679,10 +681,12 @@ def convert_to_cytoscape_json(aicp_graph, synth_graph_key="synth_graph", convert
synth_graph = aicp_graph["synth_graph"]
elif "evidence_synth_graph" in aicp_graph and aicp_graph["evidence_synth_graph"] is not None:
synth_graph = aicp_graph["evidence_synth_graph"]
elif "predicted_synth_graph" in aicp_graph and aicp_graph["predicted_synth_graph"] is not None:
synth_graph = aicp_graph["predicted_synth_graph"]
elif "predictive_synth_graph" in aicp_graph and aicp_graph["predictive_synth_graph"] is not None:
synth_graph = aicp_graph["predictive_synth_graph"]
synth_graph = aicp_graph["predictive_synth_graph"] # Backward compatibility
else:
raise ValueError(f"No synthesis graph found. Looked for: {synth_graph_key}, synth_graph, evidence_synth_graph, predictive_synth_graph")
raise ValueError(f"No synthesis graph found. Looked for: {synth_graph_key}, synth_graph, evidence_synth_graph, predicted_synth_graph, predictive_synth_graph")

if convert_route:
routes = aicp_graph.get("routes", [])
Expand Down Expand Up @@ -724,7 +728,7 @@ def convert_to_cytoscape_json(aicp_graph, synth_graph_key="synth_graph", convert
]
aggregated_yield = "N/A"

if predicted_route:
if is_predicted:
for node in filtered_nodes:
node_type = node["data"].get("node_type", "")
if isinstance(node_type, str) and node_type.lower() == "substance":
Expand Down Expand Up @@ -758,11 +762,11 @@ def convert_to_cytoscape_json(aicp_graph, synth_graph_key="synth_graph", convert
# Generate name based on whether this is a route or full graph
if convert_route:
# Route name: Include reaction steps, index, and type
route_type = "Predicted" if predicted_route else "Evidence"
route_type = "Predicted" if is_predicted else "Evidence"
cytoscape_name = f"{target_inchikey}_SD_{reaction_steps} - Route {route_index} - {route_type}"
else:
# Full graph name: Simple format
graph_type = "Predicted Graph" if predicted_route else "Evidence Graph"
graph_type = "Predicted Graph" if is_predicted else "Evidence Graph"
cytoscape_name = f"{target_inchikey} - {graph_type}"

return {
Expand Down Expand Up @@ -820,7 +824,7 @@ def send_to_cytoscape(
layout_type: str = "hierarchical",
send_all_routes: bool = True,
synth_graph_key: str = "synth_graph",
predicted_route: bool = False,
is_predicted: bool = False,
convert_route: bool = False,
route_index: int = 0,
):
Expand All @@ -834,7 +838,7 @@ def send_to_cytoscape(
- layout_type (str, optional): Layout algorithm name (e.g., "hierarchical"). Defaults to "hierarchical".
- send_all_routes (bool, optional): If True, sends all routes as separate networks. If False, uses single route/graph mode. Defaults to True.
- synth_graph_key (str, optional): Key in the input JSON containing the synthesis graph. Defaults to "synth_graph". Auto-detects if not present.
- predicted_route (bool, optional): If True, relabels substance nodes (e.g., by InChIKey) and marks network as predicted. Only used when send_all_routes=False.
- is_predicted (bool, optional): If True, relabels substance nodes (e.g., by InChIKey) and marks network as predicted. Only used when send_all_routes=False.
- convert_route (bool, optional): If True, filters the graph to a single route. Only used when send_all_routes=False. Defaults to False.
- route_index (int, optional): Index into the 'routes' array to select a route. Only used when send_all_routes=False and convert_route=True. Defaults to 0.

Expand Down Expand Up @@ -863,7 +867,7 @@ def send_to_cytoscape(
network_json,
synth_graph_key,
convert_route,
predicted_route,
is_predicted,
route_index
)
return _send_single_network_to_cytoscape(converted_json, layout_type)
Expand All @@ -876,14 +880,16 @@ def send_to_cytoscape(
routes = network_json.get("routes", [])
results = []

# First, send all available full graphs (synth_graph, evidence_synth_graph, predictive_synth_graph)
# First, send all available full graphs (synth_graph, evidence_synth_graph, predicted_synth_graph)
graph_types = []
if "synth_graph" in network_json and network_json["synth_graph"] is not None:
graph_types.append(("synth_graph", False, "Evidence Synthesis Graph"))
if "evidence_synth_graph" in network_json and network_json["evidence_synth_graph"] is not None:
graph_types.append(("evidence_synth_graph", False, "Evidence Synthesis Graph"))
if "predictive_synth_graph" in network_json and network_json["predictive_synth_graph"] is not None:
graph_types.append(("predictive_synth_graph", True, "Predictive Synthesis Graph"))
if "predicted_synth_graph" in network_json and network_json["predicted_synth_graph"] is not None:
graph_types.append(("predicted_synth_graph", True, "Predicted Synthesis Graph"))
elif "predictive_synth_graph" in network_json and network_json["predictive_synth_graph"] is not None:
graph_types.append(("predictive_synth_graph", True, "Predicted Synthesis Graph")) # Backward compatibility

# Send each full graph
for graph_key, is_predicted, graph_name in graph_types:
Expand All @@ -894,7 +900,7 @@ def send_to_cytoscape(
network_json,
graph_key,
convert_route=False,
predicted_route=is_predicted,
is_predicted=is_predicted,
route_index=0
)

Expand Down Expand Up @@ -940,13 +946,20 @@ def send_to_cytoscape(
logger.info(f"Processing route {idx}: {route_name}")

# Use the correct graph key based on route type
route_graph_key = "predictive_synth_graph" if is_predicted else synth_graph_key
# Check predicted_synth_graph first, fallback to predictive_synth_graph for backward compatibility
if is_predicted:
if "predicted_synth_graph" in network_json and network_json["predicted_synth_graph"] is not None:
route_graph_key = "predicted_synth_graph"
else:
route_graph_key = "predictive_synth_graph" # Backward compatibility
else:
route_graph_key = synth_graph_key

converted_json = convert_to_cytoscape_json(
network_json,
route_graph_key,
convert_route=True,
predicted_route=is_predicted,
is_predicted=is_predicted,
route_index=idx
)

Expand Down Expand Up @@ -1168,7 +1181,7 @@ async def _convert_to_aicp(request: ConvertToAicpRequest) -> dict:

# Return converted graph
return {
"predictive_synth_graph": {
"predicted_synth_graph": {
"nodes": nodes,
"edges": edges
},
Expand Down
2 changes: 1 addition & 1 deletion data/hybrid_routes_example.json
Original file line number Diff line number Diff line change
Expand Up @@ -2707,7 +2707,7 @@
}
]
},
"predictive_synth_graph": {
"predicted_synth_graph": {
"nodes": [
{
"node_label": "45427382-fde3-4163-8315-468e47f1ebab",
Expand Down
Loading
Loading