diff --git a/cosmotech/coal/__init__.py b/cosmotech/coal/__init__.py index ad7576a3..13062d36 100644 --- a/cosmotech/coal/__init__.py +++ b/cosmotech/coal/__init__.py @@ -5,4 +5,4 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -__version__ = "2.0.0" +__version__ = "2.1.0-rc1" diff --git a/cosmotech/coal/cosmotech_api/apis/dataset.py b/cosmotech/coal/cosmotech_api/apis/dataset.py index e0cade46..a137b2f4 100644 --- a/cosmotech/coal/cosmotech_api/apis/dataset.py +++ b/cosmotech/coal/cosmotech_api/apis/dataset.py @@ -35,30 +35,49 @@ def __init__( LOGGER.debug(T("coal.cosmotech_api.initialization.dataset_api_initialized")) def download_dataset(self, dataset_id) -> Dataset: + LOGGER.debug(f"Downloading dataset {dataset_id}") dataset = self.get_dataset( organization_id=self.configuration.cosmotech.organization_id, workspace_id=self.configuration.cosmotech.workspace_id, dataset_id=dataset_id, ) + # send dataset files under dataset id folder + destination = Path(self.configuration.cosmotech.dataset_absolute_path) / dataset_id + for part in dataset.parts: + self._download_part(dataset_id, part, destination) + return dataset - dataset_dir = self.configuration.cosmotech.dataset_absolute_path - dataset_dir_path = Path(dataset_dir) / dataset_id + def download_parameter(self, dataset_id) -> Dataset: + LOGGER.debug(f"Downloading dataset {dataset_id}") + dataset = self.get_dataset( + organization_id=self.configuration.cosmotech.organization_id, + workspace_id=self.configuration.cosmotech.workspace_id, + dataset_id=dataset_id, + ) + # send parameters file under parameters_name folder + destination = Path(self.configuration.cosmotech.parameters_absolute_path) / dataset_id for part in dataset.parts: - part_file_path = dataset_dir_path / part.source_name - part_file_path.parent.mkdir(parents=True, exist_ok=True) - data_part = self.download_dataset_part( - organization_id=self.configuration.cosmotech.organization_id, - workspace_id=self.configuration.cosmotech.workspace_id, - dataset_id=dataset_id, - dataset_part_id=part.id, - ) - with open(part_file_path, "wb") as binary_file: - binary_file.write(data_part) - LOGGER.debug( - T("coal.services.dataset.part_downloaded").format(part_name=part.source_name, file_path=part_file_path) - ) + part_dst = destination / part.name + self._download_part(dataset_id, part, part_dst) return dataset + def _download_part(self, dataset_id, dataset_part, destination): + part_file_path = destination / dataset_part.source_name + part_file_path.parent.mkdir(parents=True, exist_ok=True) + data_part = self.download_dataset_part( + organization_id=self.configuration.cosmotech.organization_id, + workspace_id=self.configuration.cosmotech.workspace_id, + dataset_id=dataset_id, + dataset_part_id=dataset_part.id, + ) + with open(part_file_path, "wb") as binary_file: + binary_file.write(data_part) + LOGGER.debug( + T("coal.services.dataset.part_downloaded").format( + part_name=dataset_part.source_name, file_path=part_file_path + ) + ) + @staticmethod def path_to_parts(_path, part_type) -> list[tuple[str, Path, DatasetPartTypeEnum]]: if (_path := Path(_path)).is_dir(): diff --git a/cosmotech/coal/cosmotech_api/apis/runner.py b/cosmotech/coal/cosmotech_api/apis/runner.py index 51c175a1..1ea62c7d 100644 --- a/cosmotech/coal/cosmotech_api/apis/runner.py +++ b/cosmotech/coal/cosmotech_api/apis/runner.py @@ -29,43 +29,47 @@ def __init__( def get_runner_metadata( self, - organization_id: str, - workspace_id: str, - runner_id: str, + runner_id: Optional[str] = None, include: Optional[list[str]] = None, exclude: Optional[list[str]] = None, ) -> dict[str, Any]: - runner = self.get_runner(organization_id, workspace_id, runner_id) + runner = self.get_runner( + self.configuration.cosmotech.organization_id, + self.configuration.cosmotech.workspace_id, + runner_id or self.configuration.cosmotech.runner_id, + ) return runner.model_dump(by_alias=True, exclude_none=True, include=include, exclude=exclude, mode="json") def download_runner_data( self, - organization_id: str, - workspace_id: str, - runner_id: str, - parameter_folder: str, - dataset_folder: Optional[str] = None, + download_datasets: Optional[str] = None, ): LOGGER.info(T("coal.cosmotech_api.runner.starting_download")) # Get runner data - runner_data = self.get_runner(organization_id, workspace_id, runner_id) + runner = self.get_runner( + self.configuration.cosmotech.organization_id, + self.configuration.cosmotech.workspace_id, + self.configuration.cosmotech.runner_id, + ) # Skip if no parameters found - if not runner_data.parameters_values: + if not runner.parameters_values: LOGGER.warning(T("coal.cosmotech_api.runner.no_parameters")) else: LOGGER.info(T("coal.cosmotech_api.runner.loaded_data")) - parameters = Parameters(runner_data) - parameters.write_parameters_to_json(parameter_folder) + parameters = Parameters(runner) + parameters.write_parameters_to_json(self.configuration.cosmotech.parameters_absolute_path) - # Download datasets if requested - if dataset_folder: - datasets_ids = runner_data.datasets.bases + if runner.datasets.parameter: + ds_api = DatasetApi(self.configuration) + ds_api.download_parameter(runner.datasets.parameter) - if datasets_ids: - LOGGER.info(T("coal.cosmotech_api.runner.downloading_datasets").format(count=len(datasets_ids))) + # Download datasets if requested + if download_datasets: + LOGGER.info(T("coal.cosmotech_api.runner.downloading_datasets").format(count=len(runner.datasets.bases))) + if runner.datasets.bases: ds_api = DatasetApi(self.configuration) - for dataset_id in datasets_ids: + for dataset_id in runner.datasets.bases: ds_api.download_dataset(dataset_id) diff --git a/cosmotech/coal/postgresql/runner.py b/cosmotech/coal/postgresql/runner.py index 2cd8f2ad..683628bb 100644 --- a/cosmotech/coal/postgresql/runner.py +++ b/cosmotech/coal/postgresql/runner.py @@ -38,8 +38,6 @@ def send_runner_metadata_to_postgresql( # Get runner metadata _runner_api = RunnerApi(configuration) runner = _runner_api.get_runner_metadata( - configuration.cosmotech.organization_id, - configuration.cosmotech.workspace_id, configuration.cosmotech.runner_id, ) @@ -66,6 +64,7 @@ def send_runner_metadata_to_postgresql( DO UPDATE SET name = EXCLUDED.name, last_csm_run_id = EXCLUDED.last_csm_run_id; """ + LOGGER.debug(runner) curs.execute( sql_upsert, ( @@ -97,8 +96,6 @@ def remove_runner_metadata_from_postgresql( # Get runner metadata _runner_api = RunnerApi(configuration) runner = _runner_api.get_runner_metadata( - configuration.cosmotech.organization_id, - configuration.cosmotech.workspace_id, configuration.cosmotech.runner_id, ) diff --git a/cosmotech/coal/postgresql/store.py b/cosmotech/coal/postgresql/store.py index 14167ec9..06425e96 100644 --- a/cosmotech/coal/postgresql/store.py +++ b/cosmotech/coal/postgresql/store.py @@ -101,18 +101,12 @@ def dump_store_to_postgresql_from_conf( _s_time = perf_counter() target_table_name = f"{_psql.table_prefix}{table_name}" LOGGER.info(T("coal.services.database.table_entry").format(table=target_table_name)) - if fk_id: - _s.execute_query( - f""" - ALTER TABLE {table_name} - ADD csm_run_id TEXT NOT NULL - DEFAULT ('{fk_id}') - """ - ) data = _s.get_table(table_name) if not len(data): LOGGER.info(T("coal.services.database.no_rows")) continue + if fk_id: + data = data.append_column("csm_run_id", [[fk_id] * data.num_rows]) _dl_time = perf_counter() rows = _psql.send_pyarrow_table_to_postgresql( data, diff --git a/cosmotech/coal/store/output/postgres_channel.py b/cosmotech/coal/store/output/postgres_channel.py index 4298f774..22894b80 100644 --- a/cosmotech/coal/store/output/postgres_channel.py +++ b/cosmotech/coal/store/output/postgres_channel.py @@ -6,7 +6,6 @@ ) from cosmotech.coal.postgresql.store import dump_store_to_postgresql_from_conf from cosmotech.coal.store.output.channel_interface import ChannelInterface -from cosmotech.coal.utils.configuration import Configuration, Dotdict class PostgresChannel(ChannelInterface): diff --git a/cosmotech/coal/utils/configuration.py b/cosmotech/coal/utils/configuration.py index 496be5ef..1b0e98d6 100644 --- a/cosmotech/coal/utils/configuration.py +++ b/cosmotech/coal/utils/configuration.py @@ -26,7 +26,7 @@ def __getattr__(self, key): _r = _r.__getattr__(_p) return _r except (KeyError, AttributeError): - LOGGER.warning("dotdict Ref {_v} doesn't exist") + LOGGER.warning(f"dotdict Ref {_v} doesn't exist") raise ReferenceKeyError(_v) return _v @@ -190,7 +190,7 @@ def safe_get(self, key, default=None): _r = _r.__getattr__(_k) return _r except (KeyError, AttributeError) as err: - LOGGER.warning(err) + LOGGER.debug(f"{err} not found; returning {default}") return default diff --git a/cosmotech/csm_data/commands/api/run_load_data.py b/cosmotech/csm_data/commands/api/run_load_data.py index 1fa48212..9a7cdb3a 100644 --- a/cosmotech/csm_data/commands/api/run_load_data.py +++ b/cosmotech/csm_data/commands/api/run_load_data.py @@ -7,6 +7,7 @@ from cosmotech.orchestrator.utils.translate import T +from cosmotech.coal.utils.configuration import Configuration from cosmotech.csm_data.utils.click import click from cosmotech.csm_data.utils.decorators import require_env, translate_help, web_help @@ -66,15 +67,16 @@ def run_load_data( # Import the function at the start of the command from cosmotech.coal.cosmotech_api.apis.runner import RunnerApi - _r = RunnerApi() + _configuration = Configuration() + _configuration.cosmotech.organization_id = organization_id + _configuration.cosmotech.workspace_id = workspace_id + _configuration.cosmotech.runner_id = runner_id + _configuration.cosmotech.parameters_absolute_path = parameters_absolute_path + _configuration.cosmotech.dataset_absolute_path = dataset_absolute_path - return _r.download_runner_data( - organization_id=organization_id, - workspace_id=workspace_id, - runner_id=runner_id, - parameter_folder=parameters_absolute_path, - dataset_folder=dataset_absolute_path, - ) + _r = RunnerApi(_configuration) + + return _r.download_runner_data(download_datasets=True) if __name__ == "__main__": diff --git a/cosmotech/csm_data/commands/store/load_csv_folder.py b/cosmotech/csm_data/commands/store/load_csv_folder.py index 593b9c73..93d2bf18 100644 --- a/cosmotech/csm_data/commands/store/load_csv_folder.py +++ b/cosmotech/csm_data/commands/store/load_csv_folder.py @@ -24,7 +24,7 @@ ) @click.option( "--csv-folder", - envvar="CSM_DATASET_ABSOLUTE_PATH", + envvar="CSM_OUTPUT_ABSOLUTE_PATH", help=T("csm_data.commands.store.load_csv_folder.parameters.csv_folder"), metavar="PATH", type=str, diff --git a/cosmotech/csm_data/utils/decorators.py b/cosmotech/csm_data/utils/decorators.py index b16c74b6..1c65c161 100644 --- a/cosmotech/csm_data/utils/decorators.py +++ b/cosmotech/csm_data/utils/decorators.py @@ -35,7 +35,7 @@ def wrap_function(func): @wraps(func) def f(*args, **kwargs): if envvar not in os.environ: - raise EnvironmentError(T("coal.errors.environment.missing_var").format(envvar=envvar)) + raise EnvironmentError(T("coal.common.errors.missing_var").format(envvar=envvar)) return func(*args, **kwargs) f.__doc__ = "\n".join([f.__doc__ or "", f"Requires env var `{envvar:<15}` *{envvar_desc}* "]) diff --git a/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py b/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py index 191adf0d..8ac14f48 100644 --- a/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py +++ b/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py @@ -10,7 +10,6 @@ from pathlib import Path from unittest.mock import MagicMock, mock_open, patch -import pytest from cosmotech_api.models.dataset import Dataset from cosmotech_api.models.dataset_part_type_enum import DatasetPartTypeEnum @@ -251,6 +250,7 @@ def test_upload_dataset_with_tags(self, mock_cosmotech_config, mock_api_client): mock_configuration_instance = MagicMock() mock_cosmotech_config.return_value = mock_configuration_instance + # Mock created dataset mock_dataset = MagicMock(spec=Dataset) mock_dataset.id = "new-dataset-123" @@ -359,7 +359,7 @@ def test_upload_dataset_parts_new_parts(self, mock_cosmotech_config, mock_api_cl file1 = Path(tmpdir) / "file1.csv" file1.write_text("data1") - result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)]) + api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)]) assert api.create_dataset_part.called assert api.get_dataset.call_count == 2 # Called at start and end @@ -394,7 +394,7 @@ def test_upload_dataset_parts_skip_existing(self, mock_cosmotech_config, mock_ap file1 = Path(tmpdir) / "file1.csv" file1.write_text("data1") - result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)]) + api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)]) # Part should be skipped, not created api.create_dataset_part.assert_not_called() @@ -430,7 +430,7 @@ def test_upload_dataset_parts_replace_existing(self, mock_cosmotech_config, mock file1 = Path(tmpdir) / "file1.csv" file1.write_text("data1") - result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)], replace_existing=True) + api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)], replace_existing=True) # Part should be deleted and then created api.delete_dataset_part.assert_called_once() @@ -470,7 +470,7 @@ def test_upload_dataset_parts_mixed(self, mock_cosmotech_config, mock_api_client file2 = Path(tmpdir) / "file2.csv" file2.write_text("data2") - result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1), str(file2)]) + api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1), str(file2)]) # Only the new file should be created assert api.create_dataset_part.call_count == 1 @@ -502,10 +502,54 @@ def test_upload_dataset_parts_with_db_type(self, mock_cosmotech_config, mock_api db_file = Path(tmpdir) / "data.db" db_file.write_text("database content") - result = api.upload_dataset_parts("existing-dataset-123", as_db=[str(db_file)]) + api.upload_dataset_parts("existing-dataset-123", as_db=[str(db_file)]) assert api.create_dataset_part.called # Verify the part request has DB type call_args = api.create_dataset_part.call_args part_request = call_args.kwargs.get("dataset_part_create_request") assert part_request.type == DatasetPartTypeEnum.DB + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_update_dataset_mixed_files(self, mock_cosmotech_config, mock_api_client): + """Test uploading a dataset with both file and database types.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + # Mock created dataset + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "new-dataset-123" + + api = DatasetApi(configuration=mock_config) + api.create_dataset_part = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + csv_file = Path(tmpdir) / "data.csv" + csv_file.write_text("csv content") + db_file = Path(tmpdir) / "data.db" + db_file.write_text("database content") + + api.upload_dataset_parts("Test Dataset", as_files=[str(csv_file)], as_db=[str(db_file)]) + + args_list = api.create_dataset_part.call_args_list + assert len(args_list) == 2 + # check first call used to create csv part + dpcr = args_list[0].kwargs.get("dataset_part_create_request") + assert dpcr.name == "data.csv" + assert dpcr.source_name == "data.csv" + assert dpcr.description == "data.csv" + assert dpcr.type == DatasetPartTypeEnum.FILE + # check second call used to create db part + dpcr = args_list[1].kwargs.get("dataset_part_create_request") + assert dpcr.name == "data.db" + assert dpcr.source_name == "data.db" + assert dpcr.description == "data.db" + assert dpcr.type == DatasetPartTypeEnum.DB diff --git a/tests/unit/coal/test_cosmotech_api/test_apis/test_runner.py b/tests/unit/coal/test_cosmotech_api/test_apis/test_runner.py index 314d5f39..6a5f26c1 100644 --- a/tests/unit/coal/test_cosmotech_api/test_apis/test_runner.py +++ b/tests/unit/coal/test_cosmotech_api/test_apis/test_runner.py @@ -14,35 +14,42 @@ from cosmotech.coal.utils.configuration import Configuration +@pytest.fixture +def base_runner_config(): + return Configuration( + { + "cosmotech": { + "organization_id": "org-123", + "workspace_id": "ws-456", + "runner_id": "runner-789", + "parameters_absolute_path": "/tmp/params", + "datasets_absolute_path": "/tmp/datasets", + } + } + ) + + class TestRunnerApi: """Tests for the RunnerApi class.""" @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") - def test_runner_api_initialization(self, mock_cosmotech_config, mock_api_client): + def test_runner_api_initialization(self, mock_api_client, base_runner_config): """Test RunnerApi initialization.""" - mock_config = MagicMock(spec=Configuration) mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) assert api.api_client == mock_client_instance - assert api.configuration == mock_config + assert api.configuration == base_runner_config @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") - def test_get_runner_metadata(self, mock_cosmotech_config, mock_api_client): + def test_get_runner_metadata_default_config(self, mock_api_client, base_runner_config): """Test getting runner metadata.""" - mock_config = MagicMock(spec=Configuration) mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance # Mock runner object mock_runner = MagicMock() @@ -52,10 +59,10 @@ def test_get_runner_metadata(self, mock_cosmotech_config, mock_api_client): "description": "Test Description", } - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) api.get_runner = MagicMock(return_value=mock_runner) - result = api.get_runner_metadata("org-123", "ws-456", "runner-789") + result = api.get_runner_metadata() assert result == {"id": "runner-123", "name": "Test Runner", "description": "Test Description"} api.get_runner.assert_called_once_with("org-123", "ws-456", "runner-789") @@ -65,23 +72,45 @@ def test_get_runner_metadata(self, mock_cosmotech_config, mock_api_client): @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") - def test_get_runner_metadata_with_include(self, mock_cosmotech_config, mock_api_client): + def test_get_runner_metadata(self, mock_api_client, base_runner_config): + """Test getting runner metadata.""" + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + + # Mock runner object + mock_runner = MagicMock() + mock_runner.model_dump.return_value = { + "id": "runner-123", + "name": "Test Runner", + "description": "Test Description", + } + + api = RunnerApi(configuration=base_runner_config) + api.get_runner = MagicMock(return_value=mock_runner) + + result = api.get_runner_metadata("runner-1000") + + assert result == {"id": "runner-123", "name": "Test Runner", "description": "Test Description"} + api.get_runner.assert_called_once_with("org-123", "ws-456", "runner-1000") + mock_runner.model_dump.assert_called_once_with( + by_alias=True, exclude_none=True, include=None, exclude=None, mode="json" + ) + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + def test_get_runner_metadata_with_include(self, mock_api_client, base_runner_config): """Test getting runner metadata with include filter.""" - mock_config = MagicMock(spec=Configuration) mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance # Mock runner object mock_runner = MagicMock() mock_runner.model_dump.return_value = {"id": "runner-123", "name": "Test Runner"} - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) api.get_runner = MagicMock(return_value=mock_runner) - result = api.get_runner_metadata("org-123", "ws-456", "runner-789", include=["id", "name"]) + result = api.get_runner_metadata(include=["id", "name"]) assert result == {"id": "runner-123", "name": "Test Runner"} mock_runner.model_dump.assert_called_once_with( @@ -90,23 +119,19 @@ def test_get_runner_metadata_with_include(self, mock_cosmotech_config, mock_api_ @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") - def test_get_runner_metadata_with_exclude(self, mock_cosmotech_config, mock_api_client): + def test_get_runner_metadata_with_exclude(self, mock_api_client, base_runner_config): """Test getting runner metadata with exclude filter.""" - mock_config = MagicMock(spec=Configuration) mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance # Mock runner object mock_runner = MagicMock() mock_runner.model_dump.return_value = {"id": "runner-123", "name": "Test Runner"} - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) api.get_runner = MagicMock(return_value=mock_runner) - result = api.get_runner_metadata("org-123", "ws-456", "runner-789", exclude=["description"]) + result = api.get_runner_metadata(exclude=["description"]) assert result == {"id": "runner-123", "name": "Test Runner"} mock_runner.model_dump.assert_called_once_with( @@ -115,15 +140,11 @@ def test_get_runner_metadata_with_exclude(self, mock_cosmotech_config, mock_api_ @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") @patch("cosmotech.coal.cosmotech_api.apis.runner.Parameters") - def test_download_runner_data_with_parameters(self, mock_parameters_class, mock_cosmotech_config, mock_api_client): + def test_download_runner_data_with_parameters(self, mock_parameters_class, mock_api_client, base_runner_config): """Test downloading runner data with parameters.""" - mock_config = MagicMock(spec=Configuration) mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance # Mock runner data with parameters mock_runner_data = MagicMock() @@ -131,16 +152,17 @@ def test_download_runner_data_with_parameters(self, mock_parameters_class, mock_ param1.parameter_id = "param1" param1.value = "value1" mock_runner_data.parameters_values = [param1] + mock_runner_data.datasets.parameter = "d-123" mock_runner_data.datasets.bases = [] # Mock Parameters instance mock_parameters_instance = MagicMock() mock_parameters_class.return_value = mock_parameters_instance - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) api.get_runner = MagicMock(return_value=mock_runner_data) - api.download_runner_data("org-123", "ws-456", "runner-789", "/tmp/params") + api.download_runner_data() api.get_runner.assert_called_once_with("org-123", "ws-456", "runner-789") mock_parameters_class.assert_called_once_with(mock_runner_data) @@ -148,39 +170,32 @@ def test_download_runner_data_with_parameters(self, mock_parameters_class, mock_ @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") - def test_download_runner_data_no_parameters(self, mock_cosmotech_config, mock_api_client): + def test_download_runner_data_no_parameters(self, mock_api_client, base_runner_config): """Test downloading runner data without parameters.""" - mock_config = MagicMock(spec=Configuration) mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance # Mock runner data without parameters mock_runner_data = MagicMock() mock_runner_data.parameters_values = None mock_runner_data.datasets.bases = [] + mock_runner_data.datasets.parameter = "d-123" - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) api.get_runner = MagicMock(return_value=mock_runner_data) # Should not raise an exception - api.download_runner_data("org-123", "ws-456", "runner-789", "/tmp/params") + api.download_runner_data() api.get_runner.assert_called_once_with("org-123", "ws-456", "runner-789") @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") @patch("cosmotech.coal.cosmotech_api.apis.runner.DatasetApi") - def test_download_runner_data_with_datasets(self, mock_dataset_api_class, mock_cosmotech_config, mock_api_client): + def test_download_runner_data_with_datasets(self, mock_dataset_api_class, mock_api_client, base_runner_config): """Test downloading runner data with datasets.""" - mock_config = MagicMock(spec=Configuration) mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance # Mock runner data with datasets mock_runner_data = MagicMock() @@ -191,37 +206,34 @@ def test_download_runner_data_with_datasets(self, mock_dataset_api_class, mock_c mock_dataset_api_instance = MagicMock() mock_dataset_api_class.return_value = mock_dataset_api_instance - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) api.get_runner = MagicMock(return_value=mock_runner_data) - api.download_runner_data("org-123", "ws-456", "runner-789", "/tmp/params", "/tmp/datasets") + api.download_runner_data("/tmp/datasets") api.get_runner.assert_called_once_with("org-123", "ws-456", "runner-789") - mock_dataset_api_class.assert_called_once_with(mock_config) + mock_dataset_api_class.assert_called_with(base_runner_config) assert mock_dataset_api_instance.download_dataset.call_count == 2 mock_dataset_api_instance.download_dataset.assert_any_call("dataset-1") mock_dataset_api_instance.download_dataset.assert_any_call("dataset-2") @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) @patch("cosmotech_api.ApiClient") - @patch("cosmotech_api.Configuration") - def test_download_runner_data_no_datasets_without_folder(self, mock_cosmotech_config, mock_api_client): - """Test downloading runner data without datasets when dataset_folder is None.""" - mock_config = MagicMock(spec=Configuration) + def test_download_runner_data_no_datasets_without_folder(self, mock_api_client, base_runner_config): + """Test downloading runner data without datasets when download_datasets is None.""" mock_client_instance = MagicMock() mock_api_client.return_value = mock_client_instance - mock_configuration_instance = MagicMock() - mock_cosmotech_config.return_value = mock_configuration_instance # Mock runner data with datasets mock_runner_data = MagicMock() mock_runner_data.parameters_values = None mock_runner_data.datasets.bases = ["dataset-1"] + mock_runner_data.datasets.parameter = "d-123" - api = RunnerApi(configuration=mock_config) + api = RunnerApi(configuration=base_runner_config) api.get_runner = MagicMock(return_value=mock_runner_data) - # Should not try to download datasets when dataset_folder is None - api.download_runner_data("org-123", "ws-456", "runner-789", "/tmp/params", dataset_folder=None) + # Should not try to download datasets when download_datasets is None + api.download_runner_data(download_datasets=None) api.get_runner.assert_called_once_with("org-123", "ws-456", "runner-789") diff --git a/tests/unit/coal/test_postgresql/test_postgresql_runner.py b/tests/unit/coal/test_postgresql/test_postgresql_runner.py index 051ef445..d50ff91e 100644 --- a/tests/unit/coal/test_postgresql/test_postgresql_runner.py +++ b/tests/unit/coal/test_postgresql/test_postgresql_runner.py @@ -66,9 +66,7 @@ def test_send_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ut mock_runner_api_class.assert_called_once_with(mock_configuration) # Verify get_runner_metadata was called with correct parameters - mock_runner_api_instance.get_runner_metadata.assert_called_once_with( - "test-org", "test-workspace", "test-runner-id" - ) + mock_runner_api_instance.get_runner_metadata.assert_called_once_with("test-runner-id") # Verify PostgreSQL connection was established mock_connect.assert_called_once_with("postgresql://user:password@localhost:5432/testdb", autocommit=True) @@ -147,9 +145,7 @@ def test_remove_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ mock_runner_api_class.assert_called_once_with(mock_configuration) # Verify get_runner_metadata was called with correct parameters - mock_runner_api_instance.get_runner_metadata.assert_called_once_with( - "test-org", "test-workspace", "test-runner-id" - ) + mock_runner_api_instance.get_runner_metadata.assert_called_once_with("test-runner-id") # Check that PostgreSQL connection was established mock_connect.assert_called_once_with("postgresql://user:password@localhost:5432/testdb", autocommit=True)