Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions src/aiida_shell/calculations/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aiida.common.datastructures import CalcInfo, CodeInfo, FileCopyOperation
from aiida.common.folders import Folder
from aiida.engine import CalcJob, CalcJobProcessSpec
from aiida.orm import Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type
from aiida.orm import Computer, Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type
from aiida.parsers import Parser

from aiida_shell.data import EntryPointData, PickledData
Expand Down Expand Up @@ -281,9 +281,11 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
inputs = {}

nodes = inputs.get('nodes', {})
computer = inputs['code'].computer
filenames = (inputs.get('filenames', None) or Dict()).get_dict()
arguments = (inputs.get('arguments', None) or List()).get_list()
outputs = (inputs.get('outputs', None) or List()).get_list()
use_symlinks = inputs['metadata']['options']['use_symlinks']
filename_stdin = inputs['metadata']['options'].get('filename_stdin', None)
filename_stdout = inputs['metadata']['options'].get('output_filename', None)
default_retrieved_temporary = list(self.DEFAULT_RETRIEVED_TEMPORARY)
Expand All @@ -300,7 +302,10 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
if filename_stdin and filename_stdin in processed_arguments:
processed_arguments.remove(filename_stdin)

remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes(inputs)
remote_data_nodes = {key: node for key, node in nodes.items() if isinstance(node, RemoteData)}
remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes(
remote_data_nodes, filenames, computer, use_symlinks
)

code_info = CodeInfo()
code_info.code_uuid = inputs['code'].uuid
Expand Down Expand Up @@ -329,16 +334,22 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
return calc_info

@staticmethod
def handle_remote_data_nodes(inputs: dict[str, Data]) -> tuple[list[t.Any], list[t.Any]]:
"""Handle a ``RemoteData`` that was passed in the ``nodes`` input.
def handle_remote_data_nodes(
remote_data_nodes: dict[str, RemoteData], filenames: dict[str, str], computer: Computer, use_symlinks: bool
) -> tuple[list[t.Any], list[t.Any]]:
"""Handle all ``RemoteData`` nodes that were passed in the ``nodes`` input.

:param inputs: The inputs dictionary.
:param remote_data_nodes: The ``RemoteData`` input nodes.
:param filenames: A dictionary of explicit filenames to use for the ``nodes`` to be written to ``dirpath``.
:returns: A tuple of two lists, the ``remote_copy_list`` and the ``remote_symlink_list``.
"""
use_symlinks: bool = inputs['metadata']['options']['use_symlinks'] # type: ignore[index]
computer_uuid = inputs['code'].computer.uuid # type: ignore[union-attr]
remote_nodes = [node for node in inputs.get('nodes', {}).values() if isinstance(node, RemoteData)]
instructions = [(computer_uuid, f'{node.get_remote_path()}/*', '.') for node in remote_nodes]
instructions = []

for key, node in remote_data_nodes.items():
if key in filenames:
instructions.append((computer.uuid, node.get_remote_path(), filenames[key]))
else:
instructions.append((computer.uuid, f'{node.get_remote_path()}/*', '.'))

if use_symlinks:
return [], instructions
Expand Down Expand Up @@ -407,7 +418,10 @@ def process_arguments_and_nodes(
self.write_folder_data(node, dirpath, filename)
argument_interpolated = argument.format(**{placeholder: filename or placeholder})
elif isinstance(node, RemoteData):
self.handle_remote_data(node)
# Only the placeholder needs to be formatted. The content of the remote data itself is handled by the
# engine through the instructions created in ``handle_remote_data_nodes``.
filename = prepared_filenames[placeholder]
argument_interpolated = argument.format(**{placeholder: filename or placeholder})
else:
argument_interpolated = argument.format(**{placeholder: str(node.value)})

Expand Down Expand Up @@ -465,6 +479,8 @@ def prepare_filenames(self, nodes: dict[str, SinglefileData], filenames: dict[st
raise RuntimeError(
f'node `{key}` contains the file `{f}` which overlaps with a reserved output filename.'
)
elif isinstance(node, RemoteData):
filename = filenames.get(key, None)
else:
continue

Expand Down
35 changes: 34 additions & 1 deletion tests/calculations/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_nodes_folder_data(generate_calc_job, generate_code, tmp_path):

@pytest.mark.parametrize('use_symlinks', (True, False))
def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_localhost, use_symlinks):
"""Test the ``nodes`` input with ``RemoteData`` nodes ."""
"""Test the ``nodes`` input with ``RemoteData`` nodes."""
inputs = {
'code': generate_code(),
'arguments': [],
Expand All @@ -107,6 +107,39 @@ def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_loc
assert sorted(calc_info.remote_copy_list) == [(aiida_localhost.uuid, str(tmp_path / '*'), '.')]


def test_nodes_remote_data_filename(generate_calc_job, generate_code, tmp_path, aiida_localhost):
    """Check the copy instructions for ``RemoteData`` nodes when only some have an explicit filename.

    Two remote directories are created. ``remote_a`` is given an entry in ``filenames`` while ``remote_b`` is not.
    The assertions expect that:

    * the ``{remote_a}`` placeholder in ``arguments`` is replaced by the explicit filename,
    * ``remote_a`` is copied as a whole to that filename, whereas for ``remote_b`` the contents (``/*``) are copied
      into the working directory ``.``,
    * nothing is symlinked (the inputs do not set ``use_symlinks``) and
    * no files are written to the folder returned by ``generate_calc_job``.
    """
    # (node key, name of a file inside the remote directory, content of that file)
    layout = (
        ('remote_a', 'file_a.txt', 'content a'),
        ('remote_b', 'file_b.txt', 'content b'),
    )
    remote_paths = {}
    remote_nodes = {}

    for key, filename, content in layout:
        remote_path = tmp_path / key
        remote_path.mkdir()
        (remote_path / filename).write_text(content)
        remote_paths[key] = remote_path
        remote_nodes[key] = RemoteData(remote_path=str(remote_path.absolute()), computer=aiida_localhost)

    dirpath, calc_info = generate_calc_job(
        'core.shell',
        {
            'code': generate_code(),
            'arguments': ['{remote_a}'],
            'nodes': remote_nodes,
            'filenames': {'remote_a': 'target_remote'},
        },
    )

    # The placeholder should resolve to the explicit filename, not to the remote path.
    assert calc_info.codes_info[0].cmdline_params == ['target_remote']

    expected_copy_list = [
        (aiida_localhost.uuid, str(remote_paths['remote_a']), 'target_remote'),
        (aiida_localhost.uuid, str(remote_paths['remote_b'] / '*'), '.'),
    ]
    assert calc_info.remote_symlink_list == []
    assert sorted(calc_info.remote_copy_list) == expected_copy_list

    written = sorted(entry.name for entry in dirpath.iterdir())
    assert written == []


def test_nodes_base_types(generate_calc_job, generate_code):
"""Test the ``nodes`` input with ``BaseType`` nodes ."""
inputs = {
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def factory(entry_point_name='core.shell', store_provenance=False, filepath_retr


@pytest.fixture
def generate_calc_job(tmp_path):
def generate_calc_job(tmp_path_factory):
"""Create a :class:`aiida.engine.CalcJob` instance with the given inputs.

The fixture will call ``prepare_for_submission`` and return a tuple of the temporary folder that was passed to it,
Expand All @@ -81,6 +81,7 @@ def factory(
which ensures that all input files are written, including those by the scheduler plugin, such as the
submission script.
"""
tmp_path = tmp_path_factory.mktemp('calc_job_submit_dir')
manager = get_manager()
runner = manager.get_runner()

Expand Down
22 changes: 22 additions & 0 deletions tests/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,28 @@ def test_nodes_remote_data(tmp_path, aiida_localhost, use_symlinks):
assert (dirpath_working / 'filled' / 'file_b.txt').read_text() == 'content b'


def test_nodes_remote_data_filename(tmp_path_factory, aiida_localhost):
    """Test copying the contents of a ``RemoteData`` to a specific subdirectory.

    The remote directory contains a ``source`` folder with a single empty file. The ``RemoteData`` is passed with an
    explicit entry in ``filenames``. The job is expected to finish successfully, to echo that filename for the
    ``{remote}`` placeholder, and to leave the nested ``source/file.txt`` structure under a directory of that name
    in the job's remote working directory.
    """
    remote_root = tmp_path_factory.mktemp('remote')
    nested_dir = remote_root / 'source'
    nested_dir.mkdir()
    (nested_dir / 'file.txt').touch()

    results, node = launch_shell_job(
        'echo',
        arguments=['{remote}'],
        nodes={'remote': RemoteData(remote_path=str(remote_root), computer=aiida_localhost)},
        filenames={'remote': 'sub_directory'},
    )
    assert node.is_finished_ok
    assert results['stdout'].get_content().strip() == 'sub_directory'

    # Walk down from the working directory one level at a time so a failure pinpoints the missing level.
    target = pathlib.Path(node.outputs.remote_folder.get_remote_path()) / 'sub_directory'
    assert target.is_dir()
    assert (target / 'source').is_dir()
    assert (target / 'source' / 'file.txt').is_file()


def test_nodes_base_types():
"""Test a shellfunction that specifies positional CLI arguments that are interpolated by the ``kwargs``."""
nodes = {
Expand Down
Loading