Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/cactus/cactus_progressive_config.xml
Original file line number Diff line number Diff line change
Expand Up @@ -515,12 +515,16 @@
<!-- ancestor_quality_fraction: min fraction of children of ancestor in candidateSet in order for the ancestor to be an outgroup candidate -->
<!-- max_num_outgroups: maximum number of outgroups per alignment job-->
<!-- extra_chrom_outgroups: if max_num_outgroups is not sufficient to satisfy sex chromosome requirements of an ancestor while still using nearest species, allow up to this many extra (when -1 set to the number of unique sex chromosomes for the ancestor). -->
<!-- topological: if 1, sort outgroup candidates by (LCA_height, distance) instead of just distance. This prefers topologically closer outgroups over those with short branch lengths. -->
<!-- overlap_penalty: scale factor for branches already covered by an outgroup. After selecting an outgroup, all branches on the path from the ancestor to that outgroup are scaled by this factor for subsequent outgroup selection. Values > 1 penalize shared paths (e.g., 2.0 doubles the effective distance). 0 means disabled. -->
<outgroup
strategy="greedyLeavesPreference"
threshold="0"
ancestor_quality_fraction="0.75"
max_num_outgroups="2"
extra_chrom_outgroups="-1"
topological="1"
overlap_penalty="2"
/>

<!-- default_internal_node_prefix: internal nodes of the tree are labeled with this prefix then a BFS order number -->
Expand Down
179 changes: 166 additions & 13 deletions src/cactus/progressive/outgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self):
self.dag = None
self.dm = None
self.dmDirected = None
self.lcaMap = None
self.root = None
self.ogMap = None
self.mcTree = None
Expand All @@ -58,6 +59,8 @@ def importTree(self, mcTree, rootId = None):
# undirected distance matrix
self.dm = dict(NX.algorithms.shortest_paths.weighted.\
all_pairs_dijkstra_path_length(graph))
# LCA map: (node1, node2) -> lowest common ancestor
self.lcaMap = dict(NX.tree_all_pairs_lowest_common_ancestor(self.dag, root=self.root))
# the results: map tree events to outgroups
self.ogMap = defaultdict(list)

Expand Down Expand Up @@ -101,6 +104,56 @@ def onSamePath(self, source, sink):
return True
return False

def getPath(self, source, sink, graph=None):
    """Return the list of nodes on the path from source to sink.

    The search runs on an undirected view of self.dag (built on demand
    when no graph is supplied). Returns an empty list when the two
    nodes are not connected.
    """
    search_graph = NX.Graph(self.dag) if graph is None else graph
    try:
        return NX.shortest_path(search_graph, source, sink)
    except NX.NetworkXNoPath:
        return []

def getPathEdges(self, source, sink, graph=None):
    """Return the edges along the source->sink path.

    Each edge is a (node1, node2) tuple; an empty list is returned when
    the path has fewer than two nodes (i.e. no path, or source == sink).
    """
    nodes = self.getPath(source, sink, graph)
    # consecutive node pairs along the path are exactly its edges;
    # zip against the 1-shifted path yields [] for paths shorter than 2
    return list(zip(nodes, nodes[1:]))

def computeAdjustedDistance(self, source, sink, edge_scales, graph=None):
    """Compute the source->sink distance with per-edge scale factors applied.

    edge_scales: dict mapping frozenset({node1, node2}) -> scale factor.
    Edges absent from edge_scales implicitly have scale factor 1.0.

    When no path of two or more nodes exists, falls back to the
    precomputed (unscaled) all-pairs distance matrix, returning
    float('inf') if the sink is unreachable there as well.
    """
    if graph is None:
        graph = NX.Graph(self.dag)
    path = self.getPath(source, sink, graph)
    if len(path) < 2:
        # no usable path: fall back to the unscaled distance matrix
        return self.dm[source].get(sink, float('inf'))

    total_dist = 0.0
    for n1, n2 in zip(path, path[1:]):
        # prefer the explicit edge weight stored on the graph
        edge_data = graph.get_edge_data(n1, n2)
        if edge_data and 'weight' in edge_data:
            weight = edge_data['weight']
        else:
            # fallback: recover the weight from the distance matrix.
            # Chained .get handles both a missing n1 row and a missing
            # n2 entry (previously a redundantly-guarded conditional).
            weight = abs(self.dm.get(n1, {}).get(n2, 0))

        # apply any accumulated overlap-penalty scaling for this edge
        scale = edge_scales.get(frozenset({n1, n2}), 1.0)
        total_dist += weight * scale

    return total_dist

# fill up a dictionary of node id -> height in tree where
# leaves have height = 0
def heightTable(self):
Expand All @@ -120,7 +173,7 @@ def rec(node):
htable[node] = max([htable[i] for i in children]) + 1
rec(self.root)
return htable

# check the candidate using the set and fraction
def inCandidateSet(self, node, candidateChildFrac):
if self.candidateMap is None or len(self.candidateMap) == 0:
Expand Down Expand Up @@ -289,25 +342,49 @@ def refine_og_chroms(self, node_chroms, max_outgroups, extra_chrom_outgroups):
# maxNumOutgroups : max number of outgroups to put in each entry of self.ogMap
# extraChromOutgroups : number of extra outgroups that can be added to attempt
# to satisfy chromosomes.
# topological : if True, sort outgroup candidates by (LCA_height, distance)
# instead of just distance. This prefers topologically closer outgroups
# (those sharing a more recent common ancestor) over those that happen
# to have short branch lengths.
# overlapPenalty : scale factor for branches already covered by an outgroup.
# After selecting an outgroup, all branches on the path from source to
# that outgroup are scaled by this factor for subsequent outgroup
# computations. This encourages subsequent outgroups to come from
# different parts of the tree. Default 0 means disabled.
def greedy(self, threshold = None, candidateSet = None,
candidateChildFrac = 2., maxNumOutgroups = 1,
extraChromOutgroups = -1):
extraChromOutgroups = -1, topological = False,
overlapPenalty = 0.0):
# compute height table early (needed for topological sorting and threshold)
htable = self.heightTable()

# sort the (undirected) distance map
orderedPairs = []
for source, sinks in list(self.dm.items()):
for sink, dist in list(sinks.items()):
if source != self.root and sink != self.root:
orderedPairs.append((dist, (source, sink)))
orderedPairs.sort(key = lambda x: x[0])

if topological:
# Sort by (LCA_height, distance) tuple to prefer topologically
# closer outgroups over those with short branch lengths.
# Lower LCA height means a more recent common ancestor.
def topo_key(x):
dist, (source, sink) = x
# lcaMap may have either (source, sink) or (sink, source) as key
lca = self.lcaMap.get((source, sink), self.lcaMap.get((sink, source)))
return (htable[lca], dist)
orderedPairs.sort(key=topo_key)
else:
orderedPairs.sort(key = lambda x: x[0])

finished = set()
self.candidateMap = dict()
if candidateSet is not None:
assert isinstance(candidateSet, set)
for candidate in candidateSet:
self.candidateMap[candidate] = True

htable = self.heightTable()

# convert the input (leaf) chroms to id-space
node_chroms = defaultdict(set)
if self.chrom_map:
Expand Down Expand Up @@ -343,11 +420,60 @@ def greedy(self, threshold = None, candidateSet = None,
source = candidate[1][0]
ordered_pairs_by_source[source].append(candidate)

# track scaled edges for overlap penalty feature
# key: source node id, value: dict of frozenset({n1, n2}) -> cumulative scale factor
edge_scales_by_source = defaultdict(dict)

# cache undirected graph for path computations (will be updated as dag changes)
undirected_graph = NX.Graph(self.dag)

def update_edge_scales(source, sink, scale_factor):
"""Scale all edges on the path from source to sink by scale_factor."""
path_edges = self.getPathEdges(source, sink, undirected_graph)
for n1, n2 in path_edges:
edge_key = frozenset({n1, n2})
current_scale = edge_scales_by_source[source].get(edge_key, 1.0)
edge_scales_by_source[source][edge_key] = current_scale * scale_factor

def get_adjusted_dist(source, sink):
"""Get distance from source to sink, adjusted for overlap penalty."""
if source not in edge_scales_by_source or not edge_scales_by_source[source]:
return self.dm[source].get(sink, float('inf'))
return self.computeAdjustedDistance(source, sink, edge_scales_by_source[source], undirected_graph)

def get_adjusted_topo_key(source, sink, original_dist):
"""Get topological sort key using adjusted distance."""
lca = self.lcaMap.get((source, sink), self.lcaMap.get((sink, source)))
adj_dist = get_adjusted_dist(source, sink)
return (htable[lca], adj_dist)

# visit the tree bottom up
for node in self.mcTree.postOrderTraversal():
# visit the candidates in order of increasing distance
orderedPairs = ordered_pairs_by_source[node]
for candidate in orderedPairs:

# When overlap penalty is enabled, re-sort remaining candidates after each
# outgroup is added. Track whether we need to re-sort.
needs_resort = False

def resort_by_adjusted_distance(pairs, source):
"""Re-sort candidate pairs by adjusted distance for overlap penalty."""
if topological:
return sorted(pairs, key=lambda x: get_adjusted_topo_key(source, x[1][1], x[0]))
else:
return sorted(pairs, key=lambda x: get_adjusted_dist(source, x[1][1]))

candidate_idx = 0
while candidate_idx < len(orderedPairs):
# Re-sort remaining candidates only when flagged (after adding an outgroup)
source = orderedPairs[candidate_idx][1][0]
if needs_resort and overlapPenalty > 0:
remaining = orderedPairs[candidate_idx:]
orderedPairs = orderedPairs[:candidate_idx] + resort_by_adjusted_distance(remaining, source)
needs_resort = False

candidate = orderedPairs[candidate_idx]
candidate_idx += 1
# source is the ancestor we're trying to find outgroups for
source = candidate[1][0]
# sink is the candidate outgroup
Expand Down Expand Up @@ -400,19 +526,33 @@ def greedy(self, threshold = None, candidateSet = None,
existingOutgroupDist = dict(self.ogMap[sourceName])
assert existingOutgroupDist[sinkName] == dist
continue

self.ogMap[sourceName].append((sinkName, dist))
# Update edge scales for overlap penalty and flag for re-sort
if overlapPenalty > 0:
update_edge_scales(source, sink, overlapPenalty)
needs_resort = True
source_satisfied[source] = self.check_chrom_satisfied(source, node_chroms)
if len(self.ogMap[sourceName]) >= maxNumOutgroups and source_satisfied[source]:
finished.add(source)
else:
self.dag.remove_edge(source, sink)

# Since we could be adding to the ogMap instead of creating
# it, sort the outgroups by distance again. Sorting the
# outgroups is critical for the multiple-outgroups code to
# work well.
for node, outgroups in list(self.ogMap.items()):
self.ogMap[node] = sorted(outgroups, key=lambda x: x[1])
# it, sort the outgroups again. Sorting the outgroups is critical
# for the multiple-outgroups code to work well.
if topological:
# Sort by (LCA_height, distance) to maintain topological preference
def topo_sort_key(source_name, og_name, og_dist):
source_id = self.mcTree.getNodeId(source_name)
sink_id = self.mcTree.getNodeId(og_name)
lca = self.lcaMap.get((source_id, sink_id), self.lcaMap.get((sink_id, source_id)))
return (htable[lca], og_dist)
for node, outgroups in list(self.ogMap.items()):
self.ogMap[node] = sorted(outgroups, key=lambda x: topo_sort_key(node, x[0], x[1]))
else:
for node, outgroups in list(self.ogMap.items()):
self.ogMap[node] = sorted(outgroups, key=lambda x: x[1])

# the chromosome specification logic trumps the maximum number of outgroups
# so we do a second pass to reconcile them (greedily) as best as possible
Expand All @@ -434,9 +574,20 @@ def main():
help="Maximum number of outgroups to provide if necessitated by --chromInfo", type=int, default=-1)
parser.add_option("--chromInfo", dest="chromInfo",
help="File mapping genomes to sex chromosome lists")
parser.add_option("--topological", dest="topological", action="store_true",
default=False, help="Sort outgroup candidates by (LCA_height, distance) "
"instead of just distance. This prefers topologically closer outgroups "
"(sharing a more recent common ancestor) over those with short branch lengths.")
parser.add_option("--overlapPenalty", dest="overlapPenalty", type='float',
default=0.0, help="Scale factor for branches already covered by an outgroup. "
"After selecting an outgroup, all branches on the path from the ancestor to "
"that outgroup are scaled by this factor for subsequent outgroup selection. "
"This encourages subsequent outgroups to come from different parts of the tree. "
"Values > 1 penalize shared paths (e.g., 2.0 doubles the effective distance). "
"(default: 0, disabled)")
parser.add_option("--configFile", dest="configFile",
help="Specify cactus configuration file",
default=os.path.join(cactusRootPath(), "cactus_progressive_config.xml"))
default=os.path.join(cactusRootPath(), "cactus_progressive_config.xml"))
options, args = parser.parse_args()

options.binariesMode = 'local'
Expand Down Expand Up @@ -466,7 +617,9 @@ def main():
outgroup.greedy(threshold=options.threshold, candidateSet=candidates,
candidateChildFrac=1.1,
maxNumOutgroups=options.maxNumOutgroups,
extraChromOutgroups=options.extraChromOutgroups)
extraChromOutgroups=options.extraChromOutgroups,
topological=options.topological,
overlapPenalty=options.overlapPenalty)

try:
NX.drawing.nx_agraph.write_dot(outgroup.dag, args[1])
Expand Down
12 changes: 9 additions & 3 deletions src/cactus/progressive/progressive_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,26 @@ def compute_outgroups(mc_tree, config_wrapper, outgroup_candidates = set(), root
candidateSet=outgroup_candidates,
candidateChildFrac=config_wrapper.getOutgroupAncestorQualityFraction(),
maxNumOutgroups=config_wrapper.getMaxNumOutgroups(),
extraChromOutgroups=config_wrapper.getExtraChromOutgroups())
extraChromOutgroups=config_wrapper.getExtraChromOutgroups(),
topological=config_wrapper.getOutgroupTopological(),
overlapPenalty=config_wrapper.getOutgroupOverlapPenalty())
if leaves_only:
# use all leaves as outgroups, unless outgroup candidates are given
outgroup.greedy(threshold=config_wrapper.getOutgroupThreshold(),
candidateSet=set([mc_tree.getName(n) for n in mc_tree.getLeaves()]),
candidateChildFrac=2.0,
maxNumOutgroups=config_wrapper.getMaxNumOutgroups(),
extraChromOutgroups=config_wrapper.getExtraChromOutgroups())
extraChromOutgroups=config_wrapper.getExtraChromOutgroups(),
topological=config_wrapper.getOutgroupTopological(),
overlapPenalty=config_wrapper.getOutgroupOverlapPenalty())
elif config_wrapper.getOutgroupStrategy() != 'none':
outgroup.greedy(threshold=config_wrapper.getOutgroupThreshold(),
candidateSet=None,
candidateChildFrac=config_wrapper.getOutgroupAncestorQualityFraction(),
maxNumOutgroups=config_wrapper.getMaxNumOutgroups(),
extraChromOutgroups=config_wrapper.getExtraChromOutgroups())
extraChromOutgroups=config_wrapper.getExtraChromOutgroups(),
topological=config_wrapper.getOutgroupTopological(),
overlapPenalty=config_wrapper.getOutgroupOverlapPenalty())

if not include_dists:
for k, v in outgroup.ogMap.items():
Expand Down
18 changes: 18 additions & 0 deletions src/cactus/shared/configWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ def getExtraChromOutgroups(self):
extraChromOutgroups = int(ogElem.attrib["extra_chrom_outgroups"])
return extraChromOutgroups

def getOutgroupTopological(self):
    """Return True iff the outgroup config element enables topological
    candidate sorting (attribute topological="1"); defaults to False."""
    ogElem = self.getOutgroupElem()
    if ogElem is None:
        return False
    attrs = ogElem.attrib
    if "strategy" not in attrs or "topological" not in attrs:
        return False
    return attrs["topological"] == "1"

def getOutgroupOverlapPenalty(self):
    """Return the overlap_penalty scale factor from the outgroup config
    element, or 0.0 (disabled) when it is not configured."""
    ogElem = self.getOutgroupElem()
    if ogElem is None:
        return 0.0
    attrs = ogElem.attrib
    if "strategy" in attrs and "overlap_penalty" in attrs:
        return float(attrs["overlap_penalty"])
    return 0.0

def getDefaultInternalNodePrefix(self):
decompElem = self.getDecompositionElem()
prefix = self.defaultInternalNodePrefix
Expand Down