diff --git a/python/fate_flow/components/components/upload.py b/python/fate_flow/components/components/upload.py index bd02a366b..ae723c994 100644 --- a/python/fate_flow/components/components/upload.py +++ b/python/fate_flow/components/components/upload.py @@ -25,7 +25,7 @@ from fate_flow.entity.types import JsonMetricArtifactType, EngineType from fate_flow.manager.outputs.data import DatasetManager from fate_flow.runtime.system_settings import STANDALONE_DATA_HOME, ENGINES -from fate_flow.utils.file_utils import get_fate_flow_directory +from fate_flow.utils.file_utils import get_fate_flow_directory, transform_local_file, file_delete from fate_flow.utils.io_utils import URI @@ -135,6 +135,7 @@ def run(self, parameters: UploadParam, outputs: IOMeta.OutputMeta = None, job_id self.parameters = parameters logging.info(self.parameters.to_dict()) storage_address = self.parameters.storage_address + parameters.file, is_cache = transform_local_file(parameters.file) if not os.path.isabs(parameters.file): parameters.file = os.path.join( get_fate_flow_directory(), parameters.file @@ -208,6 +209,8 @@ def run(self, parameters: UploadParam, outputs: IOMeta.OutputMeta = None, job_id logging.info("table name: {}, table namespace: {}".format(name, namespace)) if outputs: self.save_outputs(job_id, outputs, data_table_count) + if is_cache: + file_delete(parameters.file) return {"name": name, "namespace": namespace, "count": data_table_count, "data_meta": self.data_meta} def save_data_table(self, job_id): diff --git a/python/fate_flow/utils/file_utils.py b/python/fate_flow/utils/file_utils.py index 413ed3b2e..c9a9f97b3 100644 --- a/python/fate_flow/utils/file_utils.py +++ b/python/fate_flow/utils/file_utils.py @@ -16,8 +16,12 @@ import json import os +import uuid +from pathlib import Path +import pandas as pd from ruamel import yaml +from sqlalchemy import create_engine from fate_flow.runtime.env import is_in_virtualenv @@ -74,6 +78,50 @@ def get_fate_flow_directory(*args): return fate_flow_dir +def transform_local_file(file): + """ + Args: + file (str): + values like : + mysql://user:password@host_ip:host_port/db/table + file:///path/to/local_file.csv + /path/to/local_file.csv + """ + def _find_positions(s): + last_at_index = s.rfind('@') + first_colon_index = s.find(':') + last_slash_index = s.rfind('/') + second_last_slash_index = s.rfind('/', 0, last_slash_index) + return last_at_index, first_colon_index, second_last_slash_index + + if file.startswith('mysql://'): + db_info_str = file[8:] + db_info = db_info_str.split('/') + table_name = db_info[-1] + dbname = db_info[-2] + last_at, first_colon, second_last_slash = _find_positions(db_info_str) + username = db_info_str[0:first_colon] + password = db_info_str[first_colon + 1: last_at] + host = db_info_str[last_at + 1:second_last_slash] + database_url = f"mysql+pymysql://{username}:{password}@{host}/{dbname}" + engine = create_engine(database_url) + df = pd.read_sql_table(table_name, con=engine) + file = f"/tmp/data_{uuid.uuid4()}.csv" + df.to_csv(file, index=False) + return file, True + + elif file.startswith('file://'): + return file[7:], False + + else: + return file, False + + +def file_delete(file): + file_path = Path(file) + file_path.unlink(missing_ok=True) + + def load_yaml_conf(conf_path): if not os.path.isabs(conf_path): conf_path = os.path.join(get_fate_flow_directory(), conf_path) diff --git a/python/requirements-flow.txt b/python/requirements-flow.txt index c8dfefe6b..7e3ed0096 100644 --- a/python/requirements-flow.txt +++ b/python/requirements-flow.txt @@ -23,4 +23,5 @@ shortuuid cos-python-sdk-v5==1.9.27 typing-extensions==4.8.0 boto3 +sqlalchemy==2.0.31 pyarrow==15.0.1