diff --git a/src/aiida_shell/calculations/shell.py b/src/aiida_shell/calculations/shell.py index 078d200..da62e7d 100644 --- a/src/aiida_shell/calculations/shell.py +++ b/src/aiida_shell/calculations/shell.py @@ -10,7 +10,7 @@ from aiida.common.datastructures import CalcInfo, CodeInfo, FileCopyOperation from aiida.common.folders import Folder from aiida.engine import CalcJob, CalcJobProcessSpec -from aiida.orm import Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type +from aiida.orm import Computer, Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type from aiida.parsers import Parser from aiida_shell.data import EntryPointData, PickledData @@ -281,9 +281,11 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: inputs = {} nodes = inputs.get('nodes', {}) + computer = inputs['code'].computer filenames = (inputs.get('filenames', None) or Dict()).get_dict() arguments = (inputs.get('arguments', None) or List()).get_list() outputs = (inputs.get('outputs', None) or List()).get_list() + use_symlinks = inputs['metadata']['options']['use_symlinks'] filename_stdin = inputs['metadata']['options'].get('filename_stdin', None) filename_stdout = inputs['metadata']['options'].get('output_filename', None) default_retrieved_temporary = list(self.DEFAULT_RETRIEVED_TEMPORARY) @@ -300,7 +302,10 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: if filename_stdin and filename_stdin in processed_arguments: processed_arguments.remove(filename_stdin) - remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes(inputs) + remote_data_nodes = {key: node for key, node in nodes.items() if isinstance(node, RemoteData)} + remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes( + remote_data_nodes, filenames, computer, use_symlinks + ) code_info = CodeInfo() code_info.code_uuid = inputs['code'].uuid @@ -329,16 +334,22 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: return calc_info @staticmethod - def handle_remote_data_nodes(inputs: dict[str, Data]) -> tuple[list[t.Any], list[t.Any]]: - """Handle a ``RemoteData`` that was passed in the ``nodes`` input. + def handle_remote_data_nodes( + remote_data_nodes: dict[str, RemoteData], filenames: dict[str, str], computer: Computer, use_symlinks: bool + ) -> tuple[list[t.Any], list[t.Any]]: + """Handle all ``RemoteData`` nodes that were passed in the ``nodes`` input. - :param inputs: The inputs dictionary. + :param remote_data_nodes: The ``RemoteData`` input nodes. + :param filenames: A dictionary of explicit filenames to use for the ``nodes`` to be written to ``dirpath``. :returns: A tuple of two lists, the ``remote_copy_list`` and the ``remote_symlink_list``. """ - use_symlinks: bool = inputs['metadata']['options']['use_symlinks'] # type: ignore[index] - computer_uuid = inputs['code'].computer.uuid # type: ignore[union-attr] - remote_nodes = [node for node in inputs.get('nodes', {}).values() if isinstance(node, RemoteData)] - instructions = [(computer_uuid, f'{node.get_remote_path()}/*', '.') for node in remote_nodes] + instructions = [] + + for key, node in remote_data_nodes.items(): + if key in filenames: + instructions.append((computer.uuid, node.get_remote_path(), filenames[key])) + else: + instructions.append((computer.uuid, f'{node.get_remote_path()}/*', '.')) if use_symlinks: return [], instructions @@ -407,7 +418,10 @@ def process_arguments_and_nodes( self.write_folder_data(node, dirpath, filename) argument_interpolated = argument.format(**{placeholder: filename or placeholder}) elif isinstance(node, RemoteData): - self.handle_remote_data(node) + # Only the placeholder needs to be formatted. The content of the remote data itself is handled by the + # engine through the instructions created in ``handle_remote_data_nodes``. + filename = prepared_filenames[placeholder] + argument_interpolated = argument.format(**{placeholder: filename or placeholder}) else: argument_interpolated = argument.format(**{placeholder: str(node.value)}) @@ -465,6 +479,8 @@ def prepare_filenames(self, nodes: dict[str, SinglefileData], filenames: dict[st raise RuntimeError( f'node `{key}` contains the file `{f}` which overlaps with a reserved output filename.' ) + elif isinstance(node, RemoteData): + filename = filenames.get(key, None) else: continue diff --git a/tests/calculations/test_shell.py b/tests/calculations/test_shell.py index c5c88f1..c1b6531 100644 --- a/tests/calculations/test_shell.py +++ b/tests/calculations/test_shell.py @@ -88,7 +88,7 @@ def test_nodes_folder_data(generate_calc_job, generate_code, tmp_path): @pytest.mark.parametrize('use_symlinks', (True, False)) def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_localhost, use_symlinks): - """Test the ``nodes`` input with ``RemoteData`` nodes .""" + """Test the ``nodes`` input with ``RemoteData`` nodes.""" inputs = { 'code': generate_code(), 'arguments': [], @@ -107,6 +107,39 @@ def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_loc assert sorted(calc_info.remote_copy_list) == [(aiida_localhost.uuid, str(tmp_path / '*'), '.')] +def test_nodes_remote_data_filename(generate_calc_job, generate_code, tmp_path, aiida_localhost): + """Test the ``nodes`` and ``filenames`` inputs with ``RemoteData`` nodes.""" + remote_path_a = tmp_path / 'remote_a' + remote_path_b = tmp_path / 'remote_b' + remote_path_a.mkdir() + remote_path_b.mkdir() + (remote_path_a / 'file_a.txt').write_text('content a') + (remote_path_b / 'file_b.txt').write_text('content b') + remote_data_a = RemoteData(remote_path=str(remote_path_a.absolute()), computer=aiida_localhost) + remote_data_b = RemoteData(remote_path=str(remote_path_b.absolute()), computer=aiida_localhost) + + inputs = { + 'code': generate_code(), + 'arguments': ['{remote_a}'], + 'nodes': { + 'remote_a': remote_data_a, + 'remote_b': remote_data_b, + }, + 'filenames': {'remote_a': 'target_remote'}, + } + dirpath, calc_info = generate_calc_job('core.shell', inputs) + + code_info = calc_info.codes_info[0] + assert code_info.cmdline_params == ['target_remote'] + + assert calc_info.remote_symlink_list == [] + assert sorted(calc_info.remote_copy_list) == [ + (aiida_localhost.uuid, str(remote_path_a), 'target_remote'), + (aiida_localhost.uuid, str(remote_path_b / '*'), '.'), + ] + assert sorted(p.name for p in dirpath.iterdir()) == [] + + def test_nodes_base_types(generate_calc_job, generate_code): """Test the ``nodes`` input with ``BaseType`` nodes .""" inputs = { diff --git a/tests/conftest.py b/tests/conftest.py index 76190db..46a8ac0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,7 +58,7 @@ def factory(entry_point_name='core.shell', store_provenance=False, filepath_retr @pytest.fixture -def generate_calc_job(tmp_path): +def generate_calc_job(tmp_path_factory): """Create a :class:`aiida.engine.CalcJob` instance with the given inputs. The fixture will call ``prepare_for_submission`` and return a tuple of the temporary folder that was passed to it, @@ -81,6 +81,7 @@ def factory( which ensures that all input files are written, including those by the scheduler plugin, such as the submission script. """ + tmp_path = tmp_path_factory.mktemp('calc_job_submit_dir') manager = get_manager() runner = manager.get_runner() diff --git a/tests/test_launch.py b/tests/test_launch.py index 5a1edb2..dfbe236 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -164,6 +164,28 @@ def test_nodes_remote_data(tmp_path, aiida_localhost, use_symlinks): assert (dirpath_working / 'filled' / 'file_b.txt').read_text() == 'content b' +def test_nodes_remote_data_filename(tmp_path_factory, aiida_localhost): + """Test copying contents of a ``RemoteData`` to specific subdirectory.""" + dirpath_remote = tmp_path_factory.mktemp('remote') + dirpath_source = dirpath_remote / 'source' + dirpath_source.mkdir() + (dirpath_source / 'file.txt').touch() + remote_data = RemoteData(remote_path=str(dirpath_remote), computer=aiida_localhost) + + results, node = launch_shell_job( + 'echo', + arguments=['{remote}'], + nodes={'remote': remote_data}, + filenames={'remote': 'sub_directory'}, + ) + assert node.is_finished_ok + assert results['stdout'].get_content().strip() == 'sub_directory' + dirpath_working = pathlib.Path(node.outputs.remote_folder.get_remote_path()) + assert (dirpath_working / 'sub_directory').is_dir() + assert (dirpath_working / 'sub_directory' / 'source').is_dir() + assert (dirpath_working / 'sub_directory' / 'source' / 'file.txt').is_file() + + def test_nodes_base_types(): """Test a shellfunction that specifies positional CLI arguments that are interpolated by the ``kwargs``.""" nodes = {