From a357d06d0f6e97ec4ab2464c54efa038f8092994 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 7 Feb 2026 21:13:05 +0800 Subject: [PATCH 1/2] update time update --- swift/trainers/patcher.py | 14 ++++++++------ swift/utils/utils.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/swift/trainers/patcher.py b/swift/trainers/patcher.py index 7ce7ab40d6..5bdf8688f0 100644 --- a/swift/trainers/patcher.py +++ b/swift/trainers/patcher.py @@ -25,22 +25,22 @@ def get_max_reserved_memory() -> float: return sum(mems) / 1024**3 -def add_train_message(logs, state, start_time) -> None: +def add_train_message(logs, state, start_time, start_step) -> None: logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}' train_percentage = state.global_step / state.max_steps if state.max_steps else 0. logs['percentage'] = f'{train_percentage * 100:.2f}%' elapsed = time.time() - start_time logs['elapsed_time'] = format_time(elapsed) + train_speed = elapsed / (state.global_step - start_step) if train_percentage != 0: - logs['remaining_time'] = format_time(elapsed / train_percentage - elapsed) + logs['remaining_time'] = format_time((state.max_steps - state.global_step) * train_speed) for k, v in logs.items(): if isinstance(v, float): logs[k] = round(logs[k], 8) state.max_memory = max(getattr(state, 'max_memory', 0), get_max_reserved_memory()) if state.max_memory: logs['memory(GiB)'] = round(state.max_memory, 2) - - logs['train_speed(iter/s)'] = round(state.global_step / elapsed, 6) + logs['train_speed(s/it)'] = round(train_speed, 6) class ProgressCallbackNew(ProgressCallback): @@ -48,6 +48,7 @@ class ProgressCallbackNew(ProgressCallback): def on_train_begin(self, args, state, control, **kwargs): if state.is_world_process_zero: self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True) + self.start_step = state.global_step self.current_step = 0 self.start_time = time.time() @@ -61,7 +62,7 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader self.prediction_bar.update() def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): - add_train_message(logs, state, self.start_time) + add_train_message(logs, state, self.start_time, self.start_step) if not is_pai_training_job() and state.is_world_process_zero: jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') append_to_jsonl(jsonl_path, logs) @@ -100,10 +101,11 @@ class PrinterCallbackNew(PrinterCallback): def on_train_begin(self, args, state, control, **kwargs): self.start_time = time.time() + self.start_step = state.global_step return super().on_train_begin(args, state, control, **kwargs) def on_log(self, args, state, control, logs=None, **kwargs): - add_train_message(logs, state, self.start_time) + add_train_message(logs, state, self.start_time, self.start_step) if not is_pai_training_job() and state.is_world_process_zero: jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') append_to_jsonl(jsonl_path, logs) diff --git a/swift/utils/utils.py b/swift/utils/utils.py index 943c65467e..17e977936e 100644 --- a/swift/utils/utils.py +++ b/swift/utils/utils.py @@ -87,7 +87,7 @@ def format_time(seconds): days = int(seconds // (24 * 3600)) hours = int((seconds % (24 * 3600)) // 3600) minutes = int((seconds % 3600) // 60) - seconds = int(seconds % 60) + seconds = round(seconds % 60, 2) if days > 0: time_str = f'{days}d {hours}h {minutes}m {seconds}s' From f2e271d7b5627698f369cc1fea395ee7aac696f7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 7 Feb 2026 21:21:33 +0800 Subject: [PATCH 2/2] update --- swift/trainers/patcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/swift/trainers/patcher.py b/swift/trainers/patcher.py index 5bdf8688f0..349b71c1ca 100644 --- a/swift/trainers/patcher.py +++ b/swift/trainers/patcher.py @@ -31,7 +31,8 @@ def add_train_message(logs, state, start_time, start_step) -> None: logs['percentage'] = f'{train_percentage * 100:.2f}%' elapsed = time.time() - start_time logs['elapsed_time'] = format_time(elapsed) - train_speed = elapsed / (state.global_step - start_step) + n_steps = state.global_step - start_step + train_speed = elapsed / n_steps if n_steps > 0 else 0.0 if train_percentage != 0: logs['remaining_time'] = format_time((state.max_steps - state.global_step) * train_speed) for k, v in logs.items():