2 changes: 1 addition & 1 deletion cosmotech/coal/__init__.py
@@ -5,4 +5,4 @@
# etc., to any person is prohibited unless it has been previously and
# specifically authorized by written means by Cosmo Tech.

__version__ = "2.0.0"
__version__ = "2.1.0-rc1"
49 changes: 34 additions & 15 deletions cosmotech/coal/cosmotech_api/apis/dataset.py
@@ -35,30 +35,49 @@ def __init__(
LOGGER.debug(T("coal.cosmotech_api.initialization.dataset_api_initialized"))

def download_dataset(self, dataset_id) -> Dataset:
LOGGER.debug(f"Downloading dataset {dataset_id}")
dataset = self.get_dataset(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
)
# send dataset files under dataset id folder
destination = Path(self.configuration.cosmotech.dataset_absolute_path) / dataset_id
for part in dataset.parts:
self._download_part(dataset_id, part, destination)
return dataset

dataset_dir = self.configuration.cosmotech.dataset_absolute_path
dataset_dir_path = Path(dataset_dir) / dataset_id
def download_parameter(self, dataset_id) -> Dataset:
LOGGER.debug(f"Downloading dataset {dataset_id}")
dataset = self.get_dataset(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
)
# send parameters file under parameters_name folder
destination = Path(self.configuration.cosmotech.parameters_absolute_path) / dataset_id
for part in dataset.parts:
part_file_path = dataset_dir_path / part.source_name
part_file_path.parent.mkdir(parents=True, exist_ok=True)
data_part = self.download_dataset_part(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
dataset_part_id=part.id,
)
with open(part_file_path, "wb") as binary_file:
binary_file.write(data_part)
LOGGER.debug(
T("coal.services.dataset.part_downloaded").format(part_name=part.source_name, file_path=part_file_path)
)
part_dst = destination / part.name
self._download_part(dataset_id, part, part_dst)
return dataset

def _download_part(self, dataset_id, dataset_part, destination):
part_file_path = destination / dataset_part.source_name
part_file_path.parent.mkdir(parents=True, exist_ok=True)
data_part = self.download_dataset_part(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
dataset_part_id=dataset_part.id,
)
with open(part_file_path, "wb") as binary_file:
binary_file.write(data_part)
LOGGER.debug(
T("coal.services.dataset.part_downloaded").format(
part_name=dataset_part.source_name, file_path=part_file_path
)
)

@staticmethod
def path_to_parts(_path, part_type) -> list[tuple[str, Path, DatasetPartTypeEnum]]:
if (_path := Path(_path)).is_dir():
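The refactor above moves the per-part download loop into a shared _download_part helper: download_dataset writes parts directly under the dataset-id folder, while download_parameter nests each part one level deeper under its part name. A minimal sketch of the resulting call pattern, assuming a Configuration whose cosmotech.* fields are set as in this PR (all IDs and paths below are hypothetical):

from cosmotech.coal.utils.configuration import Configuration
from cosmotech.coal.cosmotech_api.apis.dataset import DatasetApi

conf = Configuration()
conf.cosmotech.organization_id = "org-123"
conf.cosmotech.workspace_id = "ws-456"
conf.cosmotech.dataset_absolute_path = "/tmp/datasets"
conf.cosmotech.parameters_absolute_path = "/tmp/parameters"

api = DatasetApi(conf)
# Each part lands at /tmp/datasets/d-abc/<part.source_name>
api.download_dataset("d-abc")
# Each part lands at /tmp/parameters/d-def/<part.name>/<part.source_name>
api.download_parameter("d-def")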
42 changes: 23 additions & 19 deletions cosmotech/coal/cosmotech_api/apis/runner.py
@@ -29,43 +29,47 @@ def __init__(

def get_runner_metadata(
self,
organization_id: str,
workspace_id: str,
runner_id: str,
runner_id: Optional[str] = None,
include: Optional[list[str]] = None,
exclude: Optional[list[str]] = None,
) -> dict[str, Any]:
runner = self.get_runner(organization_id, workspace_id, runner_id)
runner = self.get_runner(
self.configuration.cosmotech.organization_id,
self.configuration.cosmotech.workspace_id,
runner_id or self.configuration.cosmotech.runner_id,
)

return runner.model_dump(by_alias=True, exclude_none=True, include=include, exclude=exclude, mode="json")

def download_runner_data(
self,
organization_id: str,
workspace_id: str,
runner_id: str,
parameter_folder: str,
dataset_folder: Optional[str] = None,
download_datasets: Optional[str] = None,
):
LOGGER.info(T("coal.cosmotech_api.runner.starting_download"))

# Get runner data
runner_data = self.get_runner(organization_id, workspace_id, runner_id)
runner = self.get_runner(
self.configuration.cosmotech.organization_id,
self.configuration.cosmotech.workspace_id,
self.configuration.cosmotech.runner_id,
)

# Skip if no parameters found
if not runner_data.parameters_values:
if not runner.parameters_values:
LOGGER.warning(T("coal.cosmotech_api.runner.no_parameters"))
else:
LOGGER.info(T("coal.cosmotech_api.runner.loaded_data"))
parameters = Parameters(runner_data)
parameters.write_parameters_to_json(parameter_folder)
parameters = Parameters(runner)
parameters.write_parameters_to_json(self.configuration.cosmotech.parameters_absolute_path)

# Download datasets if requested
if dataset_folder:
datasets_ids = runner_data.datasets.bases
if runner.datasets.parameter:
ds_api = DatasetApi(self.configuration)
ds_api.download_parameter(runner.datasets.parameter)

if datasets_ids:
LOGGER.info(T("coal.cosmotech_api.runner.downloading_datasets").format(count=len(datasets_ids)))
# Download datasets if requested
if download_datasets:
LOGGER.info(T("coal.cosmotech_api.runner.downloading_datasets").format(count=len(runner.datasets.bases)))
if runner.datasets.bases:
ds_api = DatasetApi(self.configuration)
for dataset_id in datasets_ids:
for dataset_id in runner.datasets.bases:
ds_api.download_dataset(dataset_id)
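
With organization, workspace, and runner ids now read from the configuration instead of positional arguments, the download entry point shrinks to a single flag. A sketch under the same assumptions as above (hypothetical IDs and paths):

from cosmotech.coal.utils.configuration import Configuration
from cosmotech.coal.cosmotech_api.apis.runner import RunnerApi

conf = Configuration()
conf.cosmotech.organization_id = "org-123"
conf.cosmotech.workspace_id = "ws-456"
conf.cosmotech.runner_id = "r-789"
conf.cosmotech.parameters_absolute_path = "/tmp/parameters"
conf.cosmotech.dataset_absolute_path = "/tmp/datasets"

runner_api = RunnerApi(conf)
# Parameter values and the parameter dataset are handled when present;
# base datasets are fetched only when download_datasets is truthy.
runner_api.download_runner_data(download_datasets=True)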
5 changes: 1 addition & 4 deletions cosmotech/coal/postgresql/runner.py
@@ -38,8 +38,6 @@ def send_runner_metadata_to_postgresql(
# Get runner metadata
_runner_api = RunnerApi(configuration)
runner = _runner_api.get_runner_metadata(
configuration.cosmotech.organization_id,
configuration.cosmotech.workspace_id,
configuration.cosmotech.runner_id,
)

@@ -66,6 +64,7 @@ def send_runner_metadata_to_postgresql(
DO
UPDATE SET name = EXCLUDED.name, last_csm_run_id = EXCLUDED.last_csm_run_id;
"""
LOGGER.debug(runner)
curs.execute(
sql_upsert,
(
@@ -97,8 +96,6 @@ def remove_runner_metadata_from_postgresql(
# Get runner metadata
_runner_api = RunnerApi(configuration)
runner = _runner_api.get_runner_metadata(
configuration.cosmotech.organization_id,
configuration.cosmotech.workspace_id,
configuration.cosmotech.runner_id,
)

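Both call sites above now rely on get_runner_metadata reading the organization and workspace from the configuration, with runner_id itself optional. A short sketch (the explicit override id is hypothetical):

_runner_api = RunnerApi(configuration)
# runner_id falls back to configuration.cosmotech.runner_id when omitted.
metadata = _runner_api.get_runner_metadata()
metadata = _runner_api.get_runner_metadata(runner_id="r-override")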
10 changes: 2 additions & 8 deletions cosmotech/coal/postgresql/store.py
@@ -101,18 +101,12 @@ def dump_store_to_postgresql_from_conf(
_s_time = perf_counter()
target_table_name = f"{_psql.table_prefix}{table_name}"
LOGGER.info(T("coal.services.database.table_entry").format(table=target_table_name))
if fk_id:
_s.execute_query(
f"""
ALTER TABLE {table_name}
ADD csm_run_id TEXT NOT NULL
DEFAULT ('{fk_id}')
"""
)
data = _s.get_table(table_name)
if not len(data):
LOGGER.info(T("coal.services.database.no_rows"))
continue
if fk_id:
data = data.append_column("csm_run_id", [[fk_id] * data.num_rows])
_dl_time = perf_counter()
rows = _psql.send_pyarrow_table_to_postgresql(
data,
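The rewrite above stops mutating the store table with ALTER TABLE and instead appends the run id to the in-memory Arrow table just before sending, after the empty-table check. A self-contained sketch of that append_column call on toy data:

import pyarrow as pa

data = pa.table({"id": [1, 2, 3], "value": [10.0, 20.0, 30.0]})  # stands in for _s.get_table(table_name)
fk_id = "run-abc"  # hypothetical run id

# Append a single chunk repeating fk_id for every row, as in the diff.
data = data.append_column("csm_run_id", [[fk_id] * data.num_rows])
print(data.column("csm_run_id").to_pylist())  # ['run-abc', 'run-abc', 'run-abc']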
1 change: 0 additions & 1 deletion cosmotech/coal/store/output/postgres_channel.py
@@ -6,7 +6,6 @@
)
from cosmotech.coal.postgresql.store import dump_store_to_postgresql_from_conf
from cosmotech.coal.store.output.channel_interface import ChannelInterface
from cosmotech.coal.utils.configuration import Configuration, Dotdict


class PostgresChannel(ChannelInterface):
4 changes: 2 additions & 2 deletions cosmotech/coal/utils/configuration.py
@@ -26,7 +26,7 @@ def __getattr__(self, key):
_r = _r.__getattr__(_p)
return _r
except (KeyError, AttributeError):
LOGGER.warning("dotdict Ref {_v} doesn't exist")
LOGGER.warning(f"dotdict Ref {_v} doesn't exist")
raise ReferenceKeyError(_v)
return _v

@@ -190,7 +190,7 @@ def safe_get(self, key, default=None):
_r = _r.__getattr__(_k)
return _r
except (KeyError, AttributeError) as err:
LOGGER.warning(err)
LOGGER.debug(f"{err} not found; returning {default}")
return default


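The two logging fixes above (a warning string missing its f-prefix, and safe_get demoting its fallback message from WARNING to DEBUG while naming the default) show up in lookups like this sketch, assuming dotted keys are walked segment by segment as the safe_get loop suggests (key and default are hypothetical):

from cosmotech.coal.utils.configuration import Configuration

conf = Configuration()
# A missing segment no longer warns: the fallback is logged at DEBUG
# and the default is returned instead of an exception propagating.
timeout = conf.safe_get("cosmotech.some_missing_key", default=30)
print(timeout)  # 30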
18 changes: 10 additions & 8 deletions cosmotech/csm_data/commands/api/run_load_data.py
@@ -7,6 +7,7 @@

from cosmotech.orchestrator.utils.translate import T

from cosmotech.coal.utils.configuration import Configuration
from cosmotech.csm_data.utils.click import click
from cosmotech.csm_data.utils.decorators import require_env, translate_help, web_help

@@ -66,15 +67,16 @@ def run_load_data(
# Import the function at the start of the command
from cosmotech.coal.cosmotech_api.apis.runner import RunnerApi

_r = RunnerApi()
_configuration = Configuration()
_configuration.cosmotech.organization_id = organization_id
_configuration.cosmotech.workspace_id = workspace_id
_configuration.cosmotech.runner_id = runner_id
_configuration.cosmotech.parameters_absolute_path = parameters_absolute_path
_configuration.cosmotech.dataset_absolute_path = dataset_absolute_path

return _r.download_runner_data(
organization_id=organization_id,
workspace_id=workspace_id,
runner_id=runner_id,
parameter_folder=parameters_absolute_path,
dataset_folder=dataset_absolute_path,
)
_r = RunnerApi(_configuration)

return _r.download_runner_data(download_datasets=True)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion cosmotech/csm_data/commands/store/load_csv_folder.py
@@ -24,7 +24,7 @@
)
@click.option(
"--csv-folder",
envvar="CSM_DATASET_ABSOLUTE_PATH",
envvar="CSM_OUTPUT_ABSOLUTE_PATH",
help=T("csm_data.commands.store.load_csv_folder.parameters.csv_folder"),
metavar="PATH",
type=str,
2 changes: 1 addition & 1 deletion cosmotech/csm_data/utils/decorators.py
@@ -35,7 +35,7 @@ def wrap_function(func):
@wraps(func)
def f(*args, **kwargs):
if envvar not in os.environ:
raise EnvironmentError(T("coal.errors.environment.missing_var").format(envvar=envvar))
raise EnvironmentError(T("coal.common.errors.missing_var").format(envvar=envvar))
return func(*args, **kwargs)

f.__doc__ = "\n".join([f.__doc__ or "", f"Requires env var `{envvar:<15}` *{envvar_desc}* "])
56 changes: 50 additions & 6 deletions tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py
@@ -10,7 +10,6 @@
from pathlib import Path
from unittest.mock import MagicMock, mock_open, patch

import pytest
from cosmotech_api.models.dataset import Dataset
from cosmotech_api.models.dataset_part_type_enum import DatasetPartTypeEnum

@@ -251,6 +250,7 @@ def test_upload_dataset_with_tags(self, mock_cosmotech_config, mock_api_client):
mock_configuration_instance = MagicMock()
mock_cosmotech_config.return_value = mock_configuration_instance

# Mock created dataset
mock_dataset = MagicMock(spec=Dataset)
mock_dataset.id = "new-dataset-123"

@@ -359,7 +359,7 @@ def test_upload_dataset_parts_new_parts(self, mock_cosmotech_config, mock_api_client):
file1 = Path(tmpdir) / "file1.csv"
file1.write_text("data1")

result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)])
api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)])

assert api.create_dataset_part.called
assert api.get_dataset.call_count == 2 # Called at start and end
@@ -394,7 +394,7 @@ def test_upload_dataset_parts_skip_existing(self, mock_cosmotech_config, mock_api_client):
file1 = Path(tmpdir) / "file1.csv"
file1.write_text("data1")

result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)])
api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)])

# Part should be skipped, not created
api.create_dataset_part.assert_not_called()
@@ -430,7 +430,7 @@ def test_upload_dataset_parts_replace_existing(self, mock_cosmotech_config, mock_api_client):
file1 = Path(tmpdir) / "file1.csv"
file1.write_text("data1")

result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)], replace_existing=True)
api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)], replace_existing=True)

# Part should be deleted and then created
api.delete_dataset_part.assert_called_once()
@@ -470,7 +470,7 @@ def test_upload_dataset_parts_mixed(self, mock_cosmotech_config, mock_api_client):
file2 = Path(tmpdir) / "file2.csv"
file2.write_text("data2")

result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1), str(file2)])
api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1), str(file2)])

# Only the new file should be created
assert api.create_dataset_part.call_count == 1
@@ -502,10 +502,54 @@ def test_upload_dataset_parts_with_db_type(self, mock_cosmotech_config, mock_api_client):
db_file = Path(tmpdir) / "data.db"
db_file.write_text("database content")

result = api.upload_dataset_parts("existing-dataset-123", as_db=[str(db_file)])
api.upload_dataset_parts("existing-dataset-123", as_db=[str(db_file)])

assert api.create_dataset_part.called
# Verify the part request has DB type
call_args = api.create_dataset_part.call_args
part_request = call_args.kwargs.get("dataset_part_create_request")
assert part_request.type == DatasetPartTypeEnum.DB

@patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True)
@patch("cosmotech_api.ApiClient")
@patch("cosmotech_api.Configuration")
def test_update_dataset_mixed_files(self, mock_cosmotech_config, mock_api_client):
"""Test uploading a dataset with both file and database types."""
mock_config = MagicMock()
mock_config.cosmotech.organization_id = "org-123"
mock_config.cosmotech.workspace_id = "ws-456"

mock_client_instance = MagicMock()
mock_api_client.return_value = mock_client_instance
mock_configuration_instance = MagicMock()
mock_cosmotech_config.return_value = mock_configuration_instance

# Mock created dataset
mock_dataset = MagicMock(spec=Dataset)
mock_dataset.id = "new-dataset-123"

api = DatasetApi(configuration=mock_config)
api.create_dataset_part = MagicMock()

with tempfile.TemporaryDirectory() as tmpdir:
csv_file = Path(tmpdir) / "data.csv"
csv_file.write_text("csv content")
db_file = Path(tmpdir) / "data.db"
db_file.write_text("database content")

api.upload_dataset_parts("Test Dataset", as_files=[str(csv_file)], as_db=[str(db_file)])

args_list = api.create_dataset_part.call_args_list
assert len(args_list) == 2
# check first call used to create csv part
dpcr = args_list[0].kwargs.get("dataset_part_create_request")
assert dpcr.name == "data.csv"
assert dpcr.source_name == "data.csv"
assert dpcr.description == "data.csv"
assert dpcr.type == DatasetPartTypeEnum.FILE
# check second call used to create db part
dpcr = args_list[1].kwargs.get("dataset_part_create_request")
assert dpcr.name == "data.db"
assert dpcr.source_name == "data.db"
assert dpcr.description == "data.db"
assert dpcr.type == DatasetPartTypeEnum.DB