Skip to content
Merged
21 changes: 17 additions & 4 deletions src/deepxtrace/diagnose.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Diagnose:
DEEPEP_DIAGNOSE_THRESHOLD_COL: determine threshold for abnormal columns. Default 3.0.
DEEPEP_DIAGNOSE_THRESHOLD_ROW: determine threshold for abnormal rows. Default 3.0.
DEEPEP_DIAGNOSE_THRESHOLD_POINT: determine threshold for abnormal individual points. Default 5.0.
DEEPEP_DIAGNOSE_EXCLUDING_ZEROS: controls whether zero entries are excluded from the mean used for abnormal-point detection in diagnose_matrix (1 to exclude, 0 to include). Default 0.

"""

Expand Down Expand Up @@ -161,6 +162,8 @@ def __init__(
os.getenv(
"DEEPEP_DIAGNOSE_THRESHOLD_POINT",
5.0))
self.excluding_zeros = int(
os.getenv("DEEPEP_DIAGNOSE_EXCLUDING_ZEROS", 0))

# Initialize the diagnose
self.group = group
Expand Down Expand Up @@ -306,7 +309,7 @@ def _setup_logger_internal(
@staticmethod
def diagnose_matrix(
mat, thres_col=3.0, thres_row=3.0, thres_point=5.0,
suppress_points_in_strong_rowscols=True
suppress_points_in_strong_rowscols=True, excluding_zeros=0
):
"""
Detect abnormal columns, rows, and individual points in a 2D wait-time matrix.
Expand Down Expand Up @@ -344,8 +347,17 @@ def diagnose_matrix(
]

# 3. Check for abnormal single points
# z_all = (mat - mat.mean()) / (mat.std() + 1e-8)
z_all = mat / (mat.mean() + 1e-8)
if excluding_zeros == 0:
# z_all = (mat - mat.mean()) / (mat.std() + 1e-8)
z_all = mat / (mat.mean() + 1e-8)
elif excluding_zeros == 1:
nonzero_values = mat[mat != 0]
if len(nonzero_values) > 0:
mean_val = nonzero_values.mean()
z_all = mat / (mean_val + 1e-8)
else:
z_all = mat

# Get all positions with z-score > threshold
abnormal_points = [
[i, j, mat[i, j], z_all[i, j]]
Expand All @@ -362,6 +374,7 @@ def diagnose_matrix(
[i, j, v, z] for [i, j, v, z] in abnormal_points
if i not in strong_rows and j not in strong_cols
]

# 4. Return for automatic processing
return {
"abnormal_cols": abnormal_cols,
Expand Down Expand Up @@ -436,7 +449,7 @@ def _gather_diagnose_stats_internal(
stats_arr = torch.stack(self.gather_tensor, dim=0).numpy()
for i, name in enumerate(["Dispatch", "Combine"]):
res = Diagnose.diagnose_matrix(
stats_arr[:, i, :], thres_col=self.thres_col, thres_row=self.thres_row, thres_point=self.thres_point)
stats_arr[:, i, :], thres_col=self.thres_col, thres_row=self.thres_row, thres_point=self.thres_point, excluding_zeros=self.excluding_zeros)
results.append(res)
self.logger.info(
f"[Diagnose] InstanceID: {self.instance_id} EPSize: {self.group_size}, diagnose: {res}, {name} Wait Recv Cost Per Token Matrix[src_rank, dst_rank]")
Expand Down
60 changes: 60 additions & 0 deletions tests/test_diagnose.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,41 @@ def setUp(self):
[15, 17, 12, 18, 13, 13, 15, 14],
])

self.mc2_layered = np.array([
[169, 537, 530, 294, 173, 128, 139, 140,
40, 0, 0, 0, 0, 0, 0, 0],
[1617, 196, 207, 170, 187, 151, 887, 174,
0, 34, 0, 0, 0, 0, 0, 0],
[1626, 210, 194, 186, 174, 162, 864, 160,
0, 0, 31, 0, 0, 0, 0, 0],
[1635, 324, 341, 186, 178, 153, 866, 169,
0, 0, 0, 34, 0, 0, 0, 0],
[1635, 543, 534, 302, 176, 125, 847, 140,
0, 0, 0, 0, 33, 0, 0, 0],
[1712, 681, 671, 401, 232, 102, 877, 132,
0, 0, 0, 0, 0, 37, 0, 0],
[997, 656, 643, 382, 235, 172, 107, 146, 0,
0, 0, 0, 0, 0, 42, 0],
[1918, 941, 931, 652, 448, 314, 1064, 199,
0, 0, 0, 0, 0, 0, 0, 42],
[1480, 0, 0, 0, 0, 0, 0, 0, 167, 239, 343,
154, 148, 150, 155, 143],
[0, 46, 0, 0, 0, 0, 0, 0, 1599, 169, 237,
156, 149, 146, 860, 140],
[0, 0, 48, 0, 0, 0, 0, 0, 1610, 161, 168,
159, 150, 161, 846, 145],
[0, 0, 0, 41, 0, 0, 0, 0, 1687, 320, 452,
82, 139, 166, 875, 136],
[0, 0, 0, 0, 42, 0, 0, 0, 1802, 481, 616,
242, 168, 214, 918, 166],
[0, 0, 0, 0, 0, 35, 0, 0, 1746, 417, 559,
226, 171, 185, 903, 151],
[0, 0, 0, 0, 0, 0, 738, 0, 1011, 393, 529,
171, 150, 162, 176, 154],
[0, 0, 0, 0, 0, 0, 0, 36, 1866, 555, 693,
325, 211, 222, 965, 180]
])

def test_diagnose_row(self):
res = ds.Diagnose.diagnose_matrix(self.abnormal_row)
self.assertEqual(
Expand Down Expand Up @@ -105,6 +140,31 @@ def test_diagnose_point(self):
'abnormal_cols': [], 'abnormal_rows': [], 'abnormal_points': [
[3, 4, 125, 7.279344854723584]]})

def test_mc2_layered(self):
    """Check diagnose_matrix on the mc2_layered fixture in both zero modes.

    With excluding_zeros=0 the result is expected to contain abnormal
    column 0 plus six abnormal points in column 8. With excluding_zeros=1
    the same abnormal column is expected, but no abnormal points. No
    abnormal rows are expected in either mode.
    """
    # The abnormal-column expectation is identical in both modes.
    expected_cols = [[0, 799.3125, 3.2102414457222475]]

    # Points expected only when zeros take part in the mean.
    expected_points_with_zeros = [
        [9, 8, 1599, 6.421988986422549],
        [10, 8, 1610, 6.466167772445468],
        [11, 8, 1687, 6.775419274605904],
        [12, 8, 1802, 7.237288401209152],
        [13, 8, 1746, 7.012378217819744],
        [15, 8, 1866, 7.494328610797046],
    ]

    including_zeros = ds.Diagnose.diagnose_matrix(
        mat=self.mc2_layered, excluding_zeros=0)
    self.assertEqual(
        including_zeros,
        {
            'abnormal_cols': expected_cols,
            'abnormal_rows': [],
            'abnormal_points': expected_points_with_zeros,
        },
    )

    without_zeros = ds.Diagnose.diagnose_matrix(
        mat=self.mc2_layered, excluding_zeros=1)
    self.assertEqual(
        without_zeros,
        {
            'abnormal_cols': expected_cols,
            'abnormal_rows': [],
            'abnormal_points': [],
        },
    )


# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
unittest.main()