diff --git a/pyproject.toml b/pyproject.toml
index e5ec75e..0a19547 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,9 @@ dependencies = [
     "ruff>=0.14.13",
 ]
 
+[project.scripts]
+idmapping = "cdm_data_loader_utils.parsers.uniprot.idmapping:cli"
+
 [dependency-groups]
 dev = [
     "hypothesis>=6.148.9",
diff --git a/src/cdm_data_loader_utils/parsers/uniprot/idmapping.py b/src/cdm_data_loader_utils/parsers/uniprot/idmapping.py
index 0b95267..649ad5c 100644
--- a/src/cdm_data_loader_utils/parsers/uniprot/idmapping.py
+++ b/src/cdm_data_loader_utils/parsers/uniprot/idmapping.py
@@ -117,7 +117,7 @@ def read_and_write(spark: SparkSession, pipeline_run: PipelineRun, id_mapping_ts
 @click.option(
     "--source",
     required=True,
-    help="Full path to the source directory containing ID mapping file(s). Does not need to specify the Bucket (i.e. cdm-lake) but should specify everything else.",
+    help="Full path to the source directory containing ID mapping file(s). S3 paths must include the s3a:// prefix; otherwise, the path is assumed to be local.",
 )
 @click.option(
     "--namespace",
@@ -130,7 +130,7 @@ def read_and_write(spark: SparkSession, pipeline_run: PipelineRun, id_mapping_ts
     default=None,
     help="Tenant warehouse to save processed data to; defaults to saving data to the user warehouse if a tenant is not specified",
 )
-def main(source: str, namespace: str, tenant_name: str | None) -> None:
+def cli(source: str, namespace: str, tenant_name: str | None) -> None:
     """Run the UniProt ID Mapping importer.
 
     :param source: full path to the source directory containing ID mapping file(s)
@@ -141,7 +141,13 @@ def main(source: str, namespace: str, tenant_name: str | None) -> None:
     :type tenant_name: str | None
     """
     (spark, delta_ns) = set_up_workspace(APP_NAME, namespace, tenant_name)
-    for file in list_remote_dir_contents(source):
+    bucket_list = []
+    if source.startswith("s3a://"):
+        # S3 source: list the bucket contents directly
+        bucket_list = list_remote_dir_contents(source)
+    # TODO: handle other locations (e.g. local paths)
+
+    for file in bucket_list:
         # file names are in the 'Key' value
         # 'tenant-general-warehouse/kbase/datasets/uniprot/id_mapping/id_mapping_part_001.tsv.gz'
         pipeline_run = PipelineRun(str(uuid4()), APP_NAME, file["Key"], delta_ns)
@@ -149,4 +155,4 @@ def main(source: str, namespace: str, tenant_name: str | None) -> None:
 
 
 if __name__ == "__main__":
-    main()
+    cli()
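
As a quick sanity check for the rename and the new console-script wiring, here is a minimal, hypothetical smoke test using Click's CliRunner; it is not part of the patch, and the import path cdm_data_loader_utils.parsers.uniprot.idmapping is inferred from the file location shown in the diff.

# Hypothetical smoke test (not in this patch) for the renamed `cli` command.
# Assumes the package is installed and the module path matches the file location.
from click.testing import CliRunner

from cdm_data_loader_utils.parsers.uniprot.idmapping import cli


def test_cli_help() -> None:
    """`--help` should exit cleanly and document the --source option."""
    runner = CliRunner()
    result = runner.invoke(cli, ["--help"])
    assert result.exit_code == 0
    assert "--source" in result.output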