diff --git a/src/cactus/cactus_progressive_config.xml b/src/cactus/cactus_progressive_config.xml index f1e8d5694..0745ea6ab 100644 --- a/src/cactus/cactus_progressive_config.xml +++ b/src/cactus/cactus_progressive_config.xml @@ -515,12 +515,16 @@ + + diff --git a/src/cactus/progressive/outgroup.py b/src/cactus/progressive/outgroup.py index 32a199d83..c25d77937 100644 --- a/src/cactus/progressive/outgroup.py +++ b/src/cactus/progressive/outgroup.py @@ -32,6 +32,7 @@ def __init__(self): self.dag = None self.dm = None self.dmDirected = None + self.lcaMap = None self.root = None self.ogMap = None self.mcTree = None @@ -58,6 +59,8 @@ def importTree(self, mcTree, rootId = None): # undirected distance matrix self.dm = dict(NX.algorithms.shortest_paths.weighted.\ all_pairs_dijkstra_path_length(graph)) + # LCA map: (node1, node2) -> lowest common ancestor + self.lcaMap = dict(NX.tree_all_pairs_lowest_common_ancestor(self.dag, root=self.root)) # the results: map tree events to otugroups self.ogMap = defaultdict(list) @@ -101,6 +104,56 @@ def onSamePath(self, source, sink): return True return False + def getPath(self, source, sink, graph=None): + """Get the path (list of nodes) between source and sink in the undirected tree.""" + if graph is None: + graph = NX.Graph(self.dag) + try: + return NX.shortest_path(graph, source, sink) + except NX.NetworkXNoPath: + return [] + + def getPathEdges(self, source, sink, graph=None): + """Get the edges on the path between source and sink. + Returns list of (node1, node2) tuples representing edges.""" + path = self.getPath(source, sink, graph) + if len(path) < 2: + return [] + edges = [] + for i in range(len(path) - 1): + edges.append((path[i], path[i+1])) + return edges + + def computeAdjustedDistance(self, source, sink, edge_scales, graph=None): + """Compute distance between source and sink, accounting for scaled edges. + + edge_scales: dict mapping frozenset({node1, node2}) -> scale factor + Edges not in edge_scales have scale factor 1.0. + """ + if graph is None: + graph = NX.Graph(self.dag) + path = self.getPath(source, sink, graph) + if len(path) < 2: + return self.dm[source].get(sink, float('inf')) + + total_dist = 0.0 + for i in range(len(path) - 1): + n1, n2 = path[i], path[i+1] + # Get original edge weight + edge_data = graph.get_edge_data(n1, n2) + if edge_data and 'weight' in edge_data: + weight = edge_data['weight'] + else: + # Fallback: compute from distance matrix + weight = abs(self.dm[n1].get(n2, 0) if n2 in self.dm.get(n1, {}) else 0) + + # Apply scale factor if this edge has been scaled + edge_key = frozenset({n1, n2}) + scale = edge_scales.get(edge_key, 1.0) + total_dist += weight * scale + + return total_dist + # fill up a dictionary of node id -> height in tree where # leaves have height = 0 def heightTable(self): @@ -120,7 +173,7 @@ def rec(node): htable[node] = max([htable[i] for i in children]) + 1 rec(self.root) return htable - + # check the candidate using the set and and fraction def inCandidateSet(self, node, candidateChildFrac): if self.candidateMap is None or len(self.candidateMap) == 0: @@ -289,16 +342,42 @@ def refine_og_chroms(self, node_chroms, max_outgroups, extra_chrom_outgroups): # maxNumOutgroups : max number of outgroups to put in each entry of self.ogMap # extraChromOutgroups : number of extra outgroups that can be added to attemp # to satisfy chromosomes. + # topological : if True, sort outgroup candidates by (LCA_height, distance) + # instead of just distance. This prefers topologically closer outgroups + # (those sharing a more recent common ancestor) over those that happen + # to have short branch lengths. + # overlapPenalty : scale factor for branches already covered by an outgroup. + # After selecting an outgroup, all branches on the path from source to + # that outgroup are scaled by this factor for subsequent outgroup + # computations. This encourages subsequent outgroups to come from + # different parts of the tree. Default 0 means disabled. def greedy(self, threshold = None, candidateSet = None, candidateChildFrac = 2., maxNumOutgroups = 1, - extraChromOutgroups = -1): + extraChromOutgroups = -1, topological = False, + overlapPenalty = 0.0): + # compute height table early (needed for topological sorting and threshold) + htable = self.heightTable() + # sort the (undirected) distance map orderedPairs = [] for source, sinks in list(self.dm.items()): for sink, dist in list(sinks.items()): if source != self.root and sink != self.root: orderedPairs.append((dist, (source, sink))) - orderedPairs.sort(key = lambda x: x[0]) + + if topological: + # Sort by (LCA_height, distance) tuple to prefer topologically + # closer outgroups over those with short branch lengths. + # Lower LCA height means a more recent common ancestor. + def topo_key(x): + dist, (source, sink) = x + # lcaMap may have either (source, sink) or (sink, source) as key + lca = self.lcaMap.get((source, sink), self.lcaMap.get((sink, source))) + return (htable[lca], dist) + orderedPairs.sort(key=topo_key) + else: + orderedPairs.sort(key = lambda x: x[0]) + finished = set() self.candidateMap = dict() if candidateSet is not None: @@ -306,8 +385,6 @@ def greedy(self, threshold = None, candidateSet = None, for candidate in candidateSet: self.candidateMap[candidate] = True - htable = self.heightTable() - # convert the input (leaf) chroms to id-space node_chroms = defaultdict(set) if self.chrom_map: @@ -343,11 +420,60 @@ def greedy(self, threshold = None, candidateSet = None, source = candidate[1][0] ordered_pairs_by_source[source].append(candidate) + # track scaled edges for overlap penalty feature + # key: source node id, value: dict of frozenset({n1, n2}) -> cumulative scale factor + edge_scales_by_source = defaultdict(dict) + + # cache undirected graph for path computations (will be updated as dag changes) + undirected_graph = NX.Graph(self.dag) + + def update_edge_scales(source, sink, scale_factor): + """Scale all edges on the path from source to sink by scale_factor.""" + path_edges = self.getPathEdges(source, sink, undirected_graph) + for n1, n2 in path_edges: + edge_key = frozenset({n1, n2}) + current_scale = edge_scales_by_source[source].get(edge_key, 1.0) + edge_scales_by_source[source][edge_key] = current_scale * scale_factor + + def get_adjusted_dist(source, sink): + """Get distance from source to sink, adjusted for overlap penalty.""" + if source not in edge_scales_by_source or not edge_scales_by_source[source]: + return self.dm[source].get(sink, float('inf')) + return self.computeAdjustedDistance(source, sink, edge_scales_by_source[source], undirected_graph) + + def get_adjusted_topo_key(source, sink, original_dist): + """Get topological sort key using adjusted distance.""" + lca = self.lcaMap.get((source, sink), self.lcaMap.get((sink, source))) + adj_dist = get_adjusted_dist(source, sink) + return (htable[lca], adj_dist) + # visit the tree bottom up for node in self.mcTree.postOrderTraversal(): # visit the candidates in order of increasing distance orderedPairs = ordered_pairs_by_source[node] - for candidate in orderedPairs: + + # When overlap penalty is enabled, re-sort remaining candidates after each + # outgroup is added. Track whether we need to re-sort. + needs_resort = False + + def resort_by_adjusted_distance(pairs, source): + """Re-sort candidate pairs by adjusted distance for overlap penalty.""" + if topological: + return sorted(pairs, key=lambda x: get_adjusted_topo_key(source, x[1][1], x[0])) + else: + return sorted(pairs, key=lambda x: get_adjusted_dist(source, x[1][1])) + + candidate_idx = 0 + while candidate_idx < len(orderedPairs): + # Re-sort remaining candidates only when flagged (after adding an outgroup) + source = orderedPairs[candidate_idx][1][0] + if needs_resort and overlapPenalty > 0: + remaining = orderedPairs[candidate_idx:] + orderedPairs = orderedPairs[:candidate_idx] + resort_by_adjusted_distance(remaining, source) + needs_resort = False + + candidate = orderedPairs[candidate_idx] + candidate_idx += 1 # source is the ancestor we're trying to find outgroups for source = candidate[1][0] # sink it the candidate outgroup @@ -400,7 +526,12 @@ def greedy(self, threshold = None, candidateSet = None, existingOutgroupDist = dict(self.ogMap[sourceName]) assert existingOutgroupDist[sinkName] == dist continue + self.ogMap[sourceName].append((sinkName, dist)) + # Update edge scales for overlap penalty and flag for re-sort + if overlapPenalty > 0: + update_edge_scales(source, sink, overlapPenalty) + needs_resort = True source_satisfied[source] = self.check_chrom_satisfied(source, node_chroms) if len(self.ogMap[sourceName]) >= maxNumOutgroups and source_satisfied[source]: finished.add(source) @@ -408,11 +539,20 @@ def greedy(self, threshold = None, candidateSet = None, self.dag.remove_edge(source, sink) # Since we could be adding to the ogMap instead of creating - # it, sort the outgroups by distance again. Sorting the - # outgroups is critical for the multiple-outgroups code to - # work well. - for node, outgroups in list(self.ogMap.items()): - self.ogMap[node] = sorted(outgroups, key=lambda x: x[1]) + # it, sort the outgroups again. Sorting the outgroups is critical + # for the multiple-outgroups code to work well. + if topological: + # Sort by (LCA_height, distance) to maintain topological preference + def topo_sort_key(source_name, og_name, og_dist): + source_id = self.mcTree.getNodeId(source_name) + sink_id = self.mcTree.getNodeId(og_name) + lca = self.lcaMap.get((source_id, sink_id), self.lcaMap.get((sink_id, source_id))) + return (htable[lca], og_dist) + for node, outgroups in list(self.ogMap.items()): + self.ogMap[node] = sorted(outgroups, key=lambda x: topo_sort_key(node, x[0], x[1])) + else: + for node, outgroups in list(self.ogMap.items()): + self.ogMap[node] = sorted(outgroups, key=lambda x: x[1]) # the chromosome specification logic trumps the maximum number of outgroups # so we do a second pass to reconcile them (greedily) as best as possible @@ -434,9 +574,20 @@ def main(): help="Maximum number of outgroups to provide if necessitated by --chromInfo", type=int, default=-1) parser.add_option("--chromInfo", dest="chromInfo", help="File mapping genomes to sex chromosome lists") + parser.add_option("--topological", dest="topological", action="store_true", + default=False, help="Sort outgroup candidates by (LCA_height, distance) " + "instead of just distance. This prefers topologically closer outgroups " + "(sharing a more recent common ancestor) over those with short branch lengths.") + parser.add_option("--overlapPenalty", dest="overlapPenalty", type='float', + default=0.0, help="Scale factor for branches already covered by an outgroup. " + "After selecting an outgroup, all branches on the path from the ancestor to " + "that outgroup are scaled by this factor for subsequent outgroup selection. " + "This encourages subsequent outgroups to come from different parts of the tree. " + "Values > 1 penalize shared paths (e.g., 2.0 doubles the effective distance). " + "(default: 0, disabled)") parser.add_option("--configFile", dest="configFile", help="Specify cactus configuration file", - default=os.path.join(cactusRootPath(), "cactus_progressive_config.xml")) + default=os.path.join(cactusRootPath(), "cactus_progressive_config.xml")) options, args = parser.parse_args() options.binariesMode = 'local' @@ -466,7 +617,9 @@ def main(): outgroup.greedy(threshold=options.threshold, candidateSet=candidates, candidateChildFrac=1.1, maxNumOutgroups=options.maxNumOutgroups, - extraChromOutgroups=options.extraChromOutgroups) + extraChromOutgroups=options.extraChromOutgroups, + topological=options.topological, + overlapPenalty=options.overlapPenalty) try: NX.drawing.nx_agraph.write_dot(outgroup.dag, args[1]) diff --git a/src/cactus/progressive/progressive_decomposition.py b/src/cactus/progressive/progressive_decomposition.py index 19f085766..381fc7fea 100644 --- a/src/cactus/progressive/progressive_decomposition.py +++ b/src/cactus/progressive/progressive_decomposition.py @@ -77,20 +77,26 @@ def compute_outgroups(mc_tree, config_wrapper, outgroup_candidates = set(), root candidateSet=outgroup_candidates, candidateChildFrac=config_wrapper.getOutgroupAncestorQualityFraction(), maxNumOutgroups=config_wrapper.getMaxNumOutgroups(), - extraChromOutgroups=config_wrapper.getExtraChromOutgroups()) + extraChromOutgroups=config_wrapper.getExtraChromOutgroups(), + topological=config_wrapper.getOutgroupTopological(), + overlapPenalty=config_wrapper.getOutgroupOverlapPenalty()) if leaves_only: # use all leaves as outgroups, unless outgroup candidates are given outgroup.greedy(threshold=config_wrapper.getOutgroupThreshold(), candidateSet=set([mc_tree.getName(n) for n in mc_tree.getLeaves()]), candidateChildFrac=2.0, maxNumOutgroups=config_wrapper.getMaxNumOutgroups(), - extraChromOutgroups=config_wrapper.getExtraChromOutgroups()) + extraChromOutgroups=config_wrapper.getExtraChromOutgroups(), + topological=config_wrapper.getOutgroupTopological(), + overlapPenalty=config_wrapper.getOutgroupOverlapPenalty()) elif config_wrapper.getOutgroupStrategy() != 'none': outgroup.greedy(threshold=config_wrapper.getOutgroupThreshold(), candidateSet=None, candidateChildFrac=config_wrapper.getOutgroupAncestorQualityFraction(), maxNumOutgroups=config_wrapper.getMaxNumOutgroups(), - extraChromOutgroups=config_wrapper.getExtraChromOutgroups()) + extraChromOutgroups=config_wrapper.getExtraChromOutgroups(), + topological=config_wrapper.getOutgroupTopological(), + overlapPenalty=config_wrapper.getOutgroupOverlapPenalty()) if not include_dists: for k, v in outgroup.ogMap.items(): diff --git a/src/cactus/shared/configWrapper.py b/src/cactus/shared/configWrapper.py index 3207cca21..0af7c0d57 100644 --- a/src/cactus/shared/configWrapper.py +++ b/src/cactus/shared/configWrapper.py @@ -108,6 +108,24 @@ def getExtraChromOutgroups(self): extraChromOutgroups = int(ogElem.attrib["extra_chrom_outgroups"]) return extraChromOutgroups + def getOutgroupTopological(self): + ogElem = self.getOutgroupElem() + topological = False + if (ogElem is not None and\ + "strategy" in ogElem.attrib and\ + "topological" in ogElem.attrib): + topological = ogElem.attrib["topological"] == "1" + return topological + + def getOutgroupOverlapPenalty(self): + ogElem = self.getOutgroupElem() + overlapPenalty = 0.0 + if (ogElem is not None and\ + "strategy" in ogElem.attrib and\ + "overlap_penalty" in ogElem.attrib): + overlapPenalty = float(ogElem.attrib["overlap_penalty"]) + return overlapPenalty + def getDefaultInternalNodePrefix(self): decompElem = self.getDecompositionElem() prefix = self.defaultInternalNodePrefix