3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -18,5 +18,6 @@
     "python.testing.unittestEnabled": false,
     "python.testing.pytestArgs": [
         "tests"
-    ]
+    ],
+    "editor.formatOnSave": true
 }
3 changes: 1 addition & 2 deletions src/cdm_data_loader_utils/audit/checkpoint.py
@@ -5,7 +5,7 @@
 from pyspark.sql import functions as sf

 from cdm_data_loader_utils.audit.schema import (
     AUDIT_SCHEMA,
[CI annotation on line 8 (GitHub Actions / Run code lint checks), Ruff (F401): src/cdm_data_loader_utils/audit/checkpoint.py:8:5: `cdm_data_loader_utils.audit.schema.AUDIT_SCHEMA` imported but unused]
     CHECKPOINT,
     LAST_ENTRY_ID,
     PIPELINE,
@@ -50,11 +50,10 @@
         sf.lit(last_entry_id).alias(LAST_ENTRY_ID),
         sf.current_timestamp().alias(UPDATED),
     )
-    updates = spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[CHECKPOINT])

     (
         delta.alias("t")
-        .merge(updates.alias("s"), current_run_expr())
+        .merge(df.alias("s"), current_run_expr())
         .whenMatchedUpdate(set={val: f"s.{val}" for val in [STATUS, RECORDS_PROCESSED, LAST_ENTRY_ID, UPDATED]})
         .whenNotMatchedInsertAll()
         .execute()
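Note: with the `spark.createDataFrame(df.rdd, ...)` round trip gone, the checkpoint upsert now merges the selected DataFrame straight into the Delta table. A minimal sketch of the same upsert pattern using delta-spark's `DeltaTable` API; the table name, merge condition, and column names here are illustrative, not taken from the repo:

from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()

# One-row DataFrame describing the current run's checkpoint state.
updates = spark.range(1).select(
    sf.lit("uniprot_idmapping").alias("pipeline"),
    sf.lit("run-123").alias("run_id"),
    sf.lit(42).alias("last_entry_id"),
    sf.current_timestamp().alias("updated"),
)

target = DeltaTable.forName(spark, "audit_ns.checkpoint")
(
    target.alias("t")
    .merge(updates.alias("s"), "t.pipeline = s.pipeline AND t.run_id = s.run_id")
    .whenMatchedUpdate(set={c: f"s.{c}" for c in ["last_entry_id", "updated"]})
    .whenNotMatchedInsertAll()
    .execute()
)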
3 changes: 1 addition & 2 deletions src/cdm_data_loader_utils/audit/metrics.py
@@ -5,7 +5,7 @@
 from pyspark.sql import functions as sf

 from cdm_data_loader_utils.audit.schema import (
     AUDIT_SCHEMA,
[CI annotation on line 8 (GitHub Actions / Run code lint checks), Ruff (F401): src/cdm_data_loader_utils/audit/metrics.py:8:5: `cdm_data_loader_utils.audit.schema.AUDIT_SCHEMA` imported but unused]
     METRICS,
     N_INVALID,
     N_READ,
@@ -75,7 +75,6 @@
         sf.lit(metrics.validation_errors).alias(VALIDATION_ERRORS),
         sf.current_timestamp().alias(UPDATED),
     )
-    updates = spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[METRICS])

     target = DeltaTable.forName(
         spark,
@@ -85,7 +84,7 @@
     (
         target.alias("t")
         .merge(
-            updates.alias("s"),
+            df.alias("s"),
             current_run_expr(),
         )
         .whenMatchedUpdate(set={k: f"s.{k}" for k in [N_READ, N_VALID, N_INVALID, VALIDATION_ERRORS, UPDATED]})
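If source-side type enforcement is ever wanted again (what the removed `spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[METRICS])` call provided), it can be had without the RDD round trip by casting the selected columns against the schema. A hedged sketch: `cast_to_schema` is a hypothetical helper, assuming the entries of `AUDIT_SCHEMA` are `StructType` objects:

from pyspark.sql import DataFrame
from pyspark.sql import functions as sf
from pyspark.sql.types import StructType


def cast_to_schema(df: DataFrame, schema: StructType) -> DataFrame:
    # Reorder and cast columns to match the target schema without
    # rebuilding the DataFrame from its underlying RDD.
    return df.select(*(sf.col(f.name).cast(f.dataType) for f in schema.fields))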
5 changes: 1 addition & 4 deletions src/cdm_data_loader_utils/audit/run.py
@@ -5,7 +5,7 @@
 from pyspark.sql import functions as sf

 from cdm_data_loader_utils.audit.schema import (
     AUDIT_SCHEMA,
[CI annotation on line 8 (GitHub Actions / Run code lint checks), Ruff (F401): src/cdm_data_loader_utils/audit/run.py:8:5: `cdm_data_loader_utils.audit.schema.AUDIT_SCHEMA` imported but unused]
     END_TIME,
     ERROR,
     PIPELINE,
@@ -48,10 +48,7 @@
         sf.lit(sf.lit(None).cast("timestamp")).alias(END_TIME),
         sf.lit(None).cast("string").alias(ERROR),
     )
-
-    spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[RUN]).write.format("delta").mode("append").saveAsTable(
-        f"{run.namespace}.{RUN}"
-    )
+    df.write.format("delta").mode("append").saveAsTable(f"{run.namespace}.{RUN}")


 def complete_run(spark: SparkSession, run: PipelineRun, records_processed: int) -> None:
2 changes: 2 additions & 0 deletions src/cdm_data_loader_utils/core/constants.py
@@ -2,3 +2,5 @@

 INVALID_DATA_FIELD_NAME = "__invalid_data__"
 D = "delta"
+
+CDM_LAKE_S3 = "s3a://cdm-lake"
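For reference, the new constant is what the later changes in this PR build on: `minio.py` derives the bare bucket name from it and `idmapping.py` prefixes object keys with it. A quick illustration (the key is the example path already used in the idmapping comments):

CDM_LAKE_S3 = "s3a://cdm-lake"

# Bare bucket name for S3/minio clients (as done in utils/minio.py).
S3_BUCKET = CDM_LAKE_S3.removeprefix("s3a://")  # "cdm-lake"

# Fully qualified path for Spark reads (as done in parsers/uniprot/idmapping.py).
key = "tenant-general-warehouse/kbase/datasets/uniprot/id_mapping/id_mapping_part_001.tsv.gz"
file_path = f"{CDM_LAKE_S3}/{key}"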
17 changes: 8 additions & 9 deletions src/cdm_data_loader_utils/parsers/uniprot/idmapping.py
@@ -31,7 +31,7 @@
 from pyspark.sql import functions as sf
 from pyspark.sql.types import StringType, StructField

-from cdm_data_loader_utils.core.constants import INVALID_DATA_FIELD_NAME
+from cdm_data_loader_utils.core.constants import CDM_LAKE_S3, INVALID_DATA_FIELD_NAME
 from cdm_data_loader_utils.core.pipeline_run import PipelineRun
 from cdm_data_loader_utils.readers.dsv import read
 from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger
@@ -117,7 +117,7 @@ def read_and_write(spark: SparkSession, pipeline_run: PipelineRun, id_mapping_ts
 @click.option(
     "--source",
     required=True,
-    help="Full path to the source directory containing ID mapping file(s). S3 buckets should include the full path with s3a:// prefix; otherwise, the path will be assumed to be local.",
+    help="Full path to the source directory containing ID mapping file(s). Files are assumed to be in the CDM s3 minio bucket, and the s3a://cdm-lake prefix may be omitted.",
 )
 @click.option(
     "--namespace",
@@ -141,17 +141,16 @@ def cli(source: str, namespace: str, tenant_name: str | None) -> None:
     :type tenant_name: str | None
     """
     (spark, delta_ns) = set_up_workspace(APP_NAME, namespace, tenant_name)
-    bucket_list = []
-    if "://" in source and source.startswith("s3a://"):
-        # we're golden
-        bucket_list = list_remote_dir_contents(source)
-    # TODO: other locations

+    # TODO: other locations / local files?
+    bucket_list = list_remote_dir_contents(source.removeprefix("s3a://cdm-lake/"))
     for file in bucket_list:
         # file names are in the 'Key' value
         # 'tenant-general-warehouse/kbase/datasets/uniprot/id_mapping/id_mapping_part_001.tsv.gz'
-        pipeline_run = PipelineRun(str(uuid4()), APP_NAME, file["Key"], delta_ns)
-        read_and_write(spark, pipeline_run, file["Key"])
+        file_path = f"{CDM_LAKE_S3}/{file['Key']}"
+        pipeline_run = PipelineRun(str(uuid4()), APP_NAME, file_path, delta_ns)
+        logger.info("Reading in mappings from %s", file_path)
+        read_and_write(spark, pipeline_run, file_path)


 if __name__ == "__main__":
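One detail worth noting about the new `--source` handling: `str.removeprefix` is a no-op when the prefix is absent, so both fully qualified and bucket-relative paths normalise to the same key. A quick illustration with made-up values:

# Both forms of --source resolve to the same bucket-relative key.
for source in (
    "s3a://cdm-lake/tenant-general-warehouse/kbase/datasets/uniprot/id_mapping",
    "tenant-general-warehouse/kbase/datasets/uniprot/id_mapping",
):
    print(source.removeprefix("s3a://cdm-lake/"))
# prints the same relative path twice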
25 changes: 24 additions & 1 deletion src/cdm_data_loader_utils/utils/cdm_logger.py
@@ -16,6 +16,7 @@
 MAX_LOG_FILE_SIZE = 2**30  # 1 GiB
 MAX_LOG_BACKUPS = 5

+__LOGGER = None

 # TODO: adopt logging config, set just once
 LOGGING_CONFIG = {
@@ -46,6 +47,28 @@

 def get_cdm_logger(
     logger_name: str | None = None, log_level: str | None = None, log_dir: str | None = None
 ) -> logging.Logger:
+    """Retrieve the logger, initialising it if necessary.
+
+    If the logger name is not set, the default name "cdm_data_loader" will be used.
+
+    :param logger_name: name for the logger, defaults to None
+    :type logger_name: str | None, optional
+    :param log_level: logger level, defaults to None
+    :type log_level: str | None, optional
+    :param log_dir: directory to save log files to, optional. If no directory is specified, logs will just be emitted to the console.
+    :type log_dir: str | None
+    :return: initialised logger
+    :rtype: logging.Logger
+    """
+    global __LOGGER
+    if not __LOGGER:
+        __LOGGER = init_logger(logger_name, log_level, log_dir)
+    return __LOGGER
+
+
+def init_logger(
+    logger_name: str | None = None, log_level: str | None = None, log_dir: str | None = None
+) -> logging.Logger:
     """Initialise the logger for the module.

@@ -62,6 +85,7 @@ def get_cdm_logger(
     """
     if not logger_name:
         logger_name = DEFAULT_LOGGER_NAME
+
     # Always get the same logger by name
     logger = logging.getLogger(logger_name)

@@ -89,7 +113,6 @@ def get_cdm_logger(
             LOG_FILENAME, maxBytes=MAX_LOG_FILE_SIZE, backupCount=MAX_LOG_BACKUPS
         )
         logger.addHandler(file_handler)
-
     return logger

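For users of the new API: `get_cdm_logger` lazily initialises a module-level singleton, so only the first call's arguments take effect and later calls return the already-configured logger. A usage sketch (the logger name and level are illustrative):

from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger

# First call initialises the shared logger.
logger = get_cdm_logger("uniprot_idmapping", log_level="INFO")
logger.info("pipeline starting")

# Later calls return the same object; their arguments are ignored
# because the module-level __LOGGER is already set.
assert get_cdm_logger() is logger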
3 changes: 2 additions & 1 deletion src/cdm_data_loader_utils/utils/minio.py
@@ -7,9 +7,10 @@
 import tqdm
 from berdl_notebook_utils.berdl_settings import get_settings

+from cdm_data_loader_utils.core.constants import CDM_LAKE_S3
 from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger

-S3_BUCKET = "cdm-lake"
+S3_BUCKET = CDM_LAKE_S3.removeprefix("s3a://")

 # Get credentials from environment variables (automatically set in JupyterHub)

2 changes: 1 addition & 1 deletion src/cdm_data_loader_utils/utils/spark_delta.py
@@ -155,7 +155,7 @@ def write_delta(spark: SparkSession, sdf: DataFrame, delta_ns: str, table: str,
         raise ValueError(msg)

     db_table = f"{delta_ns}.{table}"
-    if sdf is None or not isinstance(sdf, DataFrame) or sdf.rdd.isEmpty():
+    if sdf is None or not isinstance(sdf, DataFrame) or sdf.isEmpty():
         logger.warning("No data to write to %s", db_table)
         return

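Background on the `isEmpty` change: `DataFrame.isEmpty()` (available since PySpark 3.3) works on the DataFrame plan itself, so Spark typically only has to find a single row, whereas `sdf.rdd.isEmpty()` first forces a conversion to an RDD. A small guard in the same spirit as `write_delta`; the helper name is illustrative:

from pyspark.sql import DataFrame


def has_rows(sdf: DataFrame | None) -> bool:
    # isEmpty() stops after finding one row and avoids the
    # DataFrame -> RDD conversion that sdf.rdd.isEmpty() triggers.
    return isinstance(sdf, DataFrame) and not sdf.isEmpty()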