diff --git a/requirements/framework.txt b/requirements/framework.txt index 874ae2bb46..5fe0c918c5 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -34,7 +34,7 @@ sortedcontainers>=1.5.9 tensorboard tiktoken tqdm -transformers>=4.33,<4.58 +transformers>=4.33 transformers_stream_generator trl>=0.15,<0.25 uvicorn diff --git a/swift/callbacks/perf_log.py b/swift/callbacks/perf_log.py index 6e0ea64c24..9c71f3967b 100644 --- a/swift/callbacks/perf_log.py +++ b/swift/callbacks/perf_log.py @@ -5,7 +5,7 @@ import torch from transformers import TrainerControl, TrainerState -from swift.utils import empty_cache, get_logger +from swift.utils import empty_cache, get_current_device, get_device_count, get_env_args, get_logger from .base import TrainerCallback if TYPE_CHECKING: @@ -43,7 +43,6 @@ def __init__(self, args: 'TrainingArguments', trainer: 'Trainer'): self.step_start_time = None def on_init_end(self, args: 'TrainingArguments', state: TrainerState, control: TrainerControl, **kwargs): - from swift.utils import get_current_device, get_device_count, get_env_args # Top priority. Specify by ENV tflops = get_env_args('DEVICE_TFLOPS', int, None) diff --git a/swift/template/register.py b/swift/template/register.py index 27e4fce332..9e58983f06 100644 --- a/swift/template/register.py +++ b/swift/template/register.py @@ -25,7 +25,7 @@ def _read_args_json_template_type(model_dir): return from swift.arguments import BaseArguments args = BaseArguments.from_pretrained(model_dir) - return args.template_type + return args.template def get_template_meta(model_info: 'ModelInfo',