diff --git a/src/deepxtrace/diagnose.py b/src/deepxtrace/diagnose.py index e75051b..a9b8c55 100644 --- a/src/deepxtrace/diagnose.py +++ b/src/deepxtrace/diagnose.py @@ -117,6 +117,7 @@ class Diagnose: DEEPEP_DIAGNOSE_THRESHOLD_COL: determine threshold for abnormal columns. Default 3.0. DEEPEP_DIAGNOSE_THRESHOLD_ROW: determine threshold for abnormal rows. Default 3.0. DEEPEP_DIAGNOSE_THRESHOLD_POINT: determine threshold for abnormal individual points. Default 5.0. + DEEPEP_DIAGNOSE_EXCLUDING_ZEROS: controls whether excluding zeros in diagnose_matrix. Default 0. """ @@ -161,6 +162,8 @@ def __init__( os.getenv( "DEEPEP_DIAGNOSE_THRESHOLD_POINT", 5.0)) + self.excluding_zeros = int( + os.getenv("DEEPEP_DIAGNOSE_EXCLUDING_ZEROS", 0)) # Initialize the diagnose self.group = group @@ -306,7 +309,7 @@ def _setup_logger_internal( @staticmethod def diagnose_matrix( mat, thres_col=3.0, thres_row=3.0, thres_point=5.0, - suppress_points_in_strong_rowscols=True + suppress_points_in_strong_rowscols=True, excluding_zeros=0 ): """ Detect abnormal columns, rows, and individual points in a 2D wait-time matrix. @@ -344,8 +347,17 @@ def diagnose_matrix( ] # 3. Check for abnormal single points - # z_all = (mat - mat.mean()) / (mat.std() + 1e-8) - z_all = mat / (mat.mean() + 1e-8) + if excluding_zeros == 0: + # z_all = (mat - mat.mean()) / (mat.std() + 1e-8) + z_all = mat / (mat.mean() + 1e-8) + elif excluding_zeros == 1: + nonzero_values = mat[mat != 0] + if len(nonzero_values) > 0: + mean_val = nonzero_values.mean() + z_all = mat / (mean_val + 1e-8) + else: + z_all = mat + # Get all positions with z-score > threshold abnormal_points = [ [i, j, mat[i, j], z_all[i, j]] @@ -362,6 +374,7 @@ def diagnose_matrix( [i, j, v, z] for [i, j, v, z] in abnormal_points if i not in strong_rows and j not in strong_cols ] + # 4. Return for automatic processing return { "abnormal_cols": abnormal_cols, @@ -436,7 +449,7 @@ def _gather_diagnose_stats_internal( stats_arr = torch.stack(self.gather_tensor, dim=0).numpy() for i, name in enumerate(["Dispatch", "Combine"]): res = Diagnose.diagnose_matrix( - stats_arr[:, i, :], thres_col=self.thres_col, thres_row=self.thres_row, thres_point=self.thres_point) + stats_arr[:, i, :], thres_col=self.thres_col, thres_row=self.thres_row, thres_point=self.thres_point, excluding_zeros=self.excluding_zeros) results.append(res) self.logger.info( f"[Diagnose] InstanceID: {self.instance_id} EPSize: {self.group_size}, diagnose: {res}, {name} Wait Recv Cost Per Token Matrix[src_rank, dst_rank]") diff --git a/tests/test_diagnose.py b/tests/test_diagnose.py index c8c4ea9..4cf2ff7 100644 --- a/tests/test_diagnose.py +++ b/tests/test_diagnose.py @@ -70,6 +70,41 @@ def setUp(self): [15, 17, 12, 18, 13, 13, 15, 14], ]) + self.mc2_layered = np.array([ + [169, 537, 530, 294, 173, 128, 139, 140, + 40, 0, 0, 0, 0, 0, 0, 0], + [1617, 196, 207, 170, 187, 151, 887, 174, + 0, 34, 0, 0, 0, 0, 0, 0], + [1626, 210, 194, 186, 174, 162, 864, 160, + 0, 0, 31, 0, 0, 0, 0, 0], + [1635, 324, 341, 186, 178, 153, 866, 169, + 0, 0, 0, 34, 0, 0, 0, 0], + [1635, 543, 534, 302, 176, 125, 847, 140, + 0, 0, 0, 0, 33, 0, 0, 0], + [1712, 681, 671, 401, 232, 102, 877, 132, + 0, 0, 0, 0, 0, 37, 0, 0], + [997, 656, 643, 382, 235, 172, 107, 146, 0, + 0, 0, 0, 0, 0, 42, 0], + [1918, 941, 931, 652, 448, 314, 1064, 199, + 0, 0, 0, 0, 0, 0, 0, 42], + [1480, 0, 0, 0, 0, 0, 0, 0, 167, 239, 343, + 154, 148, 150, 155, 143], + [0, 46, 0, 0, 0, 0, 0, 0, 1599, 169, 237, + 156, 149, 146, 860, 140], + [0, 0, 48, 0, 0, 0, 0, 0, 1610, 161, 168, + 159, 150, 161, 846, 145], + [0, 0, 0, 41, 0, 0, 0, 0, 1687, 320, 452, + 82, 139, 166, 875, 136], + [0, 0, 0, 0, 42, 0, 0, 0, 1802, 481, 616, + 242, 168, 214, 918, 166], + [0, 0, 0, 0, 0, 35, 0, 0, 1746, 417, 559, + 226, 171, 185, 903, 151], + [0, 0, 0, 0, 0, 0, 738, 0, 1011, 393, 529, + 171, 150, 162, 176, 154], + [0, 0, 0, 0, 0, 0, 0, 36, 1866, 555, 693, + 325, 211, 222, 965, 180] + ]) + def test_diagnose_row(self): res = ds.Diagnose.diagnose_matrix(self.abnormal_row) self.assertEqual( @@ -105,6 +140,31 @@ def test_diagnose_point(self): 'abnormal_cols': [], 'abnormal_rows': [], 'abnormal_points': [ [3, 4, 125, 7.279344854723584]]}) + def test_mc2_layered(self): + res = ds.Diagnose.diagnose_matrix( + mat=self.mc2_layered, excluding_zeros=0) + self.assertEqual( + res, { + 'abnormal_cols': [ + [ + 0, 799.3125, 3.2102414457222475]], 'abnormal_rows': [], 'abnormal_points': [ + [ + 9, 8, 1599, 6.421988986422549], [ + 10, 8, 1610, 6.466167772445468], [ + 11, 8, 1687, 6.775419274605904], [ + 12, 8, 1802, 7.237288401209152], [ + 13, 8, 1746, 7.012378217819744], [ + 15, 8, 1866, 7.494328610797046]]}) + + res = ds.Diagnose.diagnose_matrix( + mat=self.mc2_layered, excluding_zeros=1) + self.assertEqual(res, + {'abnormal_cols': [[0, + 799.3125, + 3.2102414457222475]], + 'abnormal_rows': [], + 'abnormal_points': []}) + if __name__ == '__main__': unittest.main()