diff --git a/.github/workflows/on-release-main.yml b/.github/workflows/on-release-main.yml index 3fec054..25bec42 100644 --- a/.github/workflows/on-release-main.yml +++ b/.github/workflows/on-release-main.yml @@ -3,7 +3,6 @@ name: release-main on: release: types: [published] - branches: [main] jobs: set-version: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 61656a4..79f1369 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -65,9 +65,11 @@ Please note this documentation assumes you already have `uv` and `Git` installed Then, install and activate the environment with: ```bash - uv sync + uv sync --all-extras ``` + Using `--all-extras` ensures that all optional dependencies, including those for generating test data (`pymeshup`), are installed. + 4. Install pre-commit to run linters/formatters at commit time: ```bash diff --git a/Justfile b/Justfile index 2deb7b5..240485f 100644 --- a/Justfile +++ b/Justfile @@ -122,7 +122,7 @@ generate-ship-rotation: @uv run python examples/defraction_box.py --output-dir examples --file-base boxship --only-base; exit 0 # Run fleetmaster examples -fleetmaster-all: fleetmaster-full fleetmaster-half +fleetmaster-all: fleetmaster-full fleetmaster-half fleetmaster-rotation fleetmaster-full: generate-box-mesh-full @fleetmaster -v run --settings-file examples/settings_full.yml --lid; exit 0 fleetmaster-half: generate-box-mesh-half @@ -130,6 +130,9 @@ fleetmaster-half: generate-box-mesh-half fleetmaster-rotation: generate-ship-rotation @fleetmaster -v run --settings-file examples/settings_rotations.yml; exit 0 +fitting-example: + @uv run python examples/fitting_example.py + # clean examples directory clean-examples: clean-examples-stl clean-examples-hdf5 # clean examples stl files diff --git a/ci.yml b/ci.yml new file mode 100644 index 0000000..e69de29 diff --git a/examples/defraction_box.py b/examples/defraction_box.py index be6c3ac..c8f6ec7 100644 --- a/examples/defraction_box.py +++ b/examples/defraction_box.py @@ -14,7 +14,10 @@ import argparse from pathlib import Path +import numpy as np +import trimesh from pymeshup import Box +from trimesh.transformations import compose_matrix # Constants for the box dimensions BOX_LENGTH = 10 @@ -25,7 +28,13 @@ FILE_BASE = "defraction_box" -def main(grid_symmetry: bool, output_dir: Path, file_base: str, only_base: bool = False): +def main( + grid_symmetry: bool, + output_dir: Path, + file_base: str, + only_base: bool = False, + generate_fitting_meshes: bool = False, +): """ Generates STL meshes for a defraction box based on specified parameters. @@ -35,6 +44,7 @@ def main(grid_symmetry: bool, output_dir: Path, file_base: str, only_base: bool output_dir (Path): The directory where the generated STL files will be saved. file_base (str): The base name for the generated STL files. only_base (bool): If True, only the base mesh will be generated. + generate_fitting_meshes (bool): If True, generates specific STL files for the fitting example. """ if grid_symmetry: print(f"Grid symmetry on with file base {file_base}") @@ -67,13 +77,68 @@ def main(grid_symmetry: bool, output_dir: Path, file_base: str, only_base: bool return for draft in DRAFTS: - box_draft = box_base.move(z=-draft) + # Start from the original buoy and move it, similar to how box_base was created. + box_draft = box_buoy.move(x=-half_length, z=-draft) box_draft = box_draft.cut_at_waterline() box_draft_mesh = box_draft.regrid(pct=REGRID_PERCENTAGE) box_draft_filename = output_dir / f"{file_base}_{draft}m.stl" print(f"Saving draft mesh {box_draft_filename}") box_draft_mesh.save(str(box_draft_filename)) + if generate_fitting_meshes: + generate_fitting_stl_files(output_dir, file_base) + + +def generate_fitting_stl_files(output_dir: Path, file_base: str): + """ + Generates specific rotated and translated STL meshes required by settings_rotations.yml + for the fitting example. These meshes are based on the full 'boxship.stl' and then transformed. + """ + base_stl_path = output_dir / f"{file_base}.stl" + if not base_stl_path.exists(): + print( + f"Error: Base mesh '{base_stl_path}' not found. " + "Please ensure it's generated first (e.g., by running with --only-base)." + ) + return + + print(f"\n--- Generating fitting example STL files based on '{base_stl_path.name}' ---") + loaded_mesh = trimesh.load(base_stl_path) + + # trimesh.load can return a Scene object or None. We need a single Trimesh object. + if isinstance(loaded_mesh, trimesh.Scene): + # Combine all geometries in the scene into a single mesh + base_mesh_untransformed = loaded_mesh.dump(concatenate=True) + else: + base_mesh_untransformed = loaded_mesh + + if not isinstance(base_mesh_untransformed, trimesh.Trimesh) or base_mesh_untransformed.is_empty: + print(f"Error: Failed to load a valid mesh from '{base_stl_path}'. Aborting fitting mesh generation.") + return + + # The `translation` in the original settings_rotations.yml is the desired position of the mesh's geometric center + # relative to the database origin. We bake this transformation directly into the STL. + fitting_cases = [ + ("boxship_t_1_r_00_00_00.stl", -1.0, 0.0, 0.0, 0.0), + ("boxship_t_2_r_00_00_00.stl", -2.0, 0.0, 0.0, 0.0), + ("boxship_t_1_r_45_00_00.stl", -1.0, 45.0, 0.0, 0.0), + ("boxship_t_1_r_00_10_00.stl", -1.0, 0.0, 10.0, 0.0), + ("boxship_t_1_r_20_20_00.stl", -1.0, 20.0, 20.0, 0.0), + ] + + for filename, target_z_rel_db_origin, roll_deg, pitch_deg, yaw_deg in fitting_cases: + # The absolute translation is the target position relative to the database origin. + translation_vec = [0.0, 0.0, target_z_rel_db_origin] + + transform_matrix = compose_matrix(angles=np.radians([roll_deg, pitch_deg, yaw_deg]), translate=translation_vec) + + transformed_mesh = base_mesh_untransformed.copy() + transformed_mesh.apply_transform(transform_matrix) + + output_path = output_dir / filename + transformed_mesh.export(str(output_path)) + print(f"Generated: {output_path.name}") + if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -97,7 +162,16 @@ def main(grid_symmetry: bool, output_dir: Path, file_base: str, only_base: bool action="store_true", help="Only generate the base mesh.", ) + parser.add_argument( + "--generate-fitting-meshes", + action="store_true", + help="Generate specific STL files for the fitting example based on settings_rotations.yml.", + ) args = parser.parse_args() main( - grid_symmetry=args.grid_symmetry, output_dir=args.output_dir, file_base=args.file_base, only_base=args.only_base + grid_symmetry=args.grid_symmetry, + output_dir=args.output_dir, + file_base=args.file_base, + only_base=args.only_base, + generate_fitting_meshes=args.generate_fitting_meshes, ) diff --git a/examples/fitting_example.py b/examples/fitting_example.py index 8ef839e..0840e7a 100644 --- a/examples/fitting_example.py +++ b/examples/fitting_example.py @@ -15,6 +15,38 @@ logger = logging.getLogger(__name__) +def _run_and_print_test_case( + case_number: int, + description: str, + hdf5_path: Path, + target_translation: list[float], + target_rotation: list[float], + water_level: float, + expected_match: str, + note: str = "", +): + """Runs a single fitting test case and prints the results.""" + print(f"\n\n--- Running Test Case {case_number}: {description} ---") + logger.info(f"Searching for best match for translation={target_translation}, rotation={target_rotation}...\n") + + best_match, distance = find_best_matching_mesh( + hdf5_path=hdf5_path, + target_translation=target_translation, + target_rotation=target_rotation, + water_level=water_level, + ) + + print(f"\n--- Result for Test Case {case_number} ---") + if best_match: + print(f"✅ Best match found: '{best_match}'") + print(f" - Minimized Chamfer Distance: {distance:.6f}") + print(f" - Expected match: '{expected_match}'") + if note: + print(f" - Note: {note}") + else: + print("❌ No match found.") + + def run_fitting_example(): """Runs the fitting example. @@ -32,85 +64,76 @@ def run_fitting_example(): # the meshes in the database were generated (wetted surface). water_level = 0.0 - draft = 2.0 - # --- Test Case 1: A transformation that should perfectly match an existing mesh --- - # We are looking for a mesh that corresponds to a Z-translation of -1.0, - # a roll of 20 degrees, and a pitch of 20 degrees. - # The database contains 'boxship_t_1_r_20_20_00.stl' with these exact parameters. - print("\n--- Running Test Case 1: Exact Match ---") - target_translation_1 = [0.0, 0.0, -draft] - target_rotation_1 = [20.0, 20.0, 0.0] # [roll, pitch, yaw] - - logger.info(f"Searching for best match for translation={target_translation_1}, rotation={target_rotation_1}...\n") - - best_match_1, distance_1 = find_best_matching_mesh( + _run_and_print_test_case( + case_number=1, + description="Exact Match Draft 1 meter", hdf5_path=hdf5_path, - target_translation=target_translation_1, - target_rotation=target_rotation_1, + target_translation=[0.0, 0.0, -1.0], + target_rotation=[20.0, 20.0, 0.0], water_level=water_level, + expected_match="boxship_t_1_r_20_20_00", ) - print("\n--- Result for Test Case 1 ---") - if best_match_1: - print(f"✅ Best match found: '{best_match_1}'") - print(f" - Minimized Chamfer Distance: {distance_1:.6f}") - print(" - Expected match: 'boxship_t_1_r_20_20_00'") - else: - print("❌ No match found.") - # --- Test Case 2: A transformation with irrelevant translations and rotations --- - # This case has the same core properties (Z-trans, X/Y-rot) as Case 1, - # but with added X/Y translation and a Z rotation (yaw). - # The optimization algorithm should ignore these and still find the same best match. - print("\n\n--- Running Test Case 2: Match with Noise ---") - target_translation_2 = [2.5, -4.2, -draft] # Added dx, dy - target_rotation_2 = [20.0, 20.0, 15.0] # Added yaw - - logger.info(f"Searching for best match for translation={target_translation_2}, rotation={target_rotation_2}...\n") - - best_match_2, distance_2 = find_best_matching_mesh( + _run_and_print_test_case( + case_number=2, + description="Match with Noise draft 1.0", hdf5_path=hdf5_path, - target_translation=target_translation_2, - target_rotation=target_rotation_2, + target_translation=[2.5, -4.2, -1.1], # Added dx, dy and dz + target_rotation=[20.0, 20.0, 15.0], # Added yaw water_level=water_level, + expected_match="boxship_t_1_r_20_20_00", + note="The distance should be very close to the distance in Case 1.", ) - print("\n--- Result for Test Case 2 ---") - if best_match_2: - print(f"✅ Best match found: '{best_match_2}'") - print(f" - Minimized Chamfer Distance: {distance_2:.6f}") - print(" - Expected match: 'boxship_t_1_r_20_20_00'") - print(" - Note: The distance should be very close to the distance in Case 1.") - else: - print("❌ No match found.") - - # --- Test Case 3: A transformation with irrelevant translations and rotations --- - # This case has the same core properties (Z-trans, X/Y-rot) as Case 1, - # but with added X/Y translation and a Z rotation (yaw). - # this time, difference valeus for the z-translation and x,y rotation are assumed. - # the expected result should give a larger distance - print("\n\n--- Running Test Case 3: Match with Noise ---") - target_translation_3 = [2.5, -4.2, -draft * 1.2] # Added dx, dy AND dz - target_rotation_3 = [23.0, 19.0, 15.0] # Added yaw AND roll and pitch + # --- Test Case 2: A transformation with irrelevant translations and rotations --- + _run_and_print_test_case( + case_number=3, + description="Different Match with Noise draft 1.0", + hdf5_path=hdf5_path, + target_translation=[2.5, -4.2, -1.1], # Added dx, dy AND dz + target_rotation=[23.0, 19.0, 15.0], # Added yaw AND roll and pitch + water_level=water_level, + expected_match="boxship_t_1_r_00_00_00", + note="The distance should be larger than both case 1 and case 2.", + ) - logger.info(f"Searching for best match for translation={target_translation_3}, rotation={target_rotation_3}...\n") + # --- Test Case 4: A transformation with different core properties --- + _run_and_print_test_case( + case_number=4, + description="Exact Match for draft 2.0", + hdf5_path=hdf5_path, + target_translation=[0.0, -0.0, -2], + target_rotation=[0.0, 0.0, 0.0], + water_level=water_level, + expected_match="boxship_t_2_r_00_00_00", + note="The distance should be zero.", + ) - best_match_3, distance_3 = find_best_matching_mesh( + # --- Test Case 5: A transformation with different core properties --- + _run_and_print_test_case( + case_number=5, + description="Exact Match for draft 2.0 with deviation in xy plane and yaw", hdf5_path=hdf5_path, - target_translation=target_translation_3, - target_rotation=target_rotation_3, + target_translation=[10.0, -20.0, -2], + target_rotation=[0.0, 0.0, 15.0], water_level=water_level, + expected_match="boxship_t_2_r_00_00_00", + note="The distance should be zero.", ) - print("\n--- Result for Test Case 3 ---") - if best_match_2: - print(f"✅ Best match found: '{best_match_3}'") - print(f" - Minimized Chamfer Distance: {distance_3:.6f}") - print(" - Expected match: 'boxship_t_1_r_20_20_00'") - print(" - Note: The distance should be larger than both case 1 and case 2.") - else: - print("❌ No match found.") + # --- Test Case 6: A transformation with different core properties --- + _run_and_print_test_case( + case_number=6, + description="Match for draft 2.0 with noise", + hdf5_path=hdf5_path, + target_translation=[10.0, -20.0, -2.2], + target_rotation=[4.0, -1.0, 15.0], + water_level=water_level, + expected_match="boxship_t_2_r_00_00_00", + note="The distance should be larger than zero.", + ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 9f2758e..7867a2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -188,6 +188,9 @@ skip_empty = true [tool.coverage.run] branch = true source = ["src"] +omit = [ + "tests/*", +] # ======================================================================================== diff --git a/src/fleetmaster/commands/run.py b/src/fleetmaster/commands/run.py index 0a66562..eed9610 100644 --- a/src/fleetmaster/commands/run.py +++ b/src/fleetmaster/commands/run.py @@ -179,26 +179,30 @@ def _resolve_paths_in_config(config: dict[str, Any], settings_dir: Path) -> None """Resolves relative paths for 'base_mesh' and 'stl_files' in the config.""" # Resolve base_mesh path if config.get("base_mesh") and not Path(config["base_mesh"]).is_absolute(): - config["base_mesh"] = str(settings_dir / config["base_mesh"]) + config["base_mesh"] = str((settings_dir / config["base_mesh"]).resolve()) # Resolve stl_files paths - if not config.get("stl_files"): + if "stl_files" not in config: return - resolved_stl_files: list[str | dict[str, Any]] = [] + resolved_files: list[Any] = [] for item in config["stl_files"]: - path_str = item if isinstance(item, str) else item.get("file") - - if path_str and not Path(path_str).is_absolute(): - new_path = str(settings_dir / path_str) - if isinstance(item, str): - resolved_stl_files.append(new_path) - elif isinstance(item, dict): - item["file"] = new_path - resolved_stl_files.append(item) + if isinstance(item, str): + if not Path(item).is_absolute(): + resolved_files.append(str((settings_dir / item).resolve())) + else: + resolved_files.append(item) + elif isinstance(item, dict) and "file" in item: + if not Path(item["file"]).is_absolute(): + item["file"] = str((settings_dir / item["file"]).resolve()) + resolved_files.append(item) + elif isinstance(item, MeshConfig): + if not Path(item.file).is_absolute(): + item.file = str((settings_dir / item.file).resolve()) + resolved_files.append(item) else: - resolved_stl_files.append(item) - config["stl_files"] = resolved_stl_files + resolved_files.append(item) + config["stl_files"] = resolved_files def _load_config(settings_file: str | None, cli_args: dict[str, Any]) -> dict[str, Any]: @@ -241,16 +245,7 @@ def _load_and_validate_settings( base_mesh_path = config.get("base_mesh") if "stl_files" in config: - # Convert dicts to MeshConfig objects to satisfy mypy - new_stl_files: list[str | MeshConfig] = [] - for item in config["stl_files"]: - if isinstance(item, dict): - new_stl_files.append(MeshConfig(**item)) - else: - new_stl_files.append(str(item)) # it's a string - config["stl_files"] = new_stl_files - - all_files_in_config = [item if isinstance(item, str) else item.file for item in config["stl_files"]] + all_files_in_config = [item if isinstance(item, str) else item["file"] for item in config["stl_files"]] if not base_mesh_path and all_files_in_config: base_mesh_path = all_files_in_config[0] diff --git a/src/fleetmaster/core/engine.py b/src/fleetmaster/core/engine.py index 3b551b8..9fe7443 100644 --- a/src/fleetmaster/core/engine.py +++ b/src/fleetmaster/core/engine.py @@ -2,6 +2,7 @@ import logging import tempfile from dataclasses import dataclass +from itertools import product from pathlib import Path from typing import Any @@ -136,18 +137,17 @@ def _prepare_trimesh_geometry(stl_file: str, mesh_config: MeshConfig | None = No Returns: A trimesh.Trimesh object representing the transformed geometry. """ - transformed_mesh = trimesh.load_mesh(stl_file) + mesh = trimesh.load_mesh(stl_file) if mesh_config is None: - return transformed_mesh + return mesh - transformed_mesh = _apply_mesh_translation_and_rotation( - mesh=transformed_mesh, + return _apply_mesh_translation_and_rotation( + mesh=mesh, translation_vector=mesh_config.translation, rotation_vector_deg=mesh_config.rotation, cog=mesh_config.cog, ) - return transformed_mesh def _apply_mesh_translation_and_rotation( @@ -157,14 +157,8 @@ def _apply_mesh_translation_and_rotation( cog: npt.NDArray[np.float64] | list | None = None, ) -> trimesh.Trimesh: """Apply a translation and rotation to a mesh object.""" - if translation_vector is not None and isinstance(translation_vector, list): - translation_vector = np.array(translation_vector) - else: - translation_vector = np.zeros(3) - if rotation_vector_deg is not None and isinstance(rotation_vector_deg, list): - rotation_vector_deg = np.array(rotation_vector_deg) - else: - rotation_vector_deg = np.zeros(3) + translation_vector = np.asarray(translation_vector) if translation_vector is not None else np.zeros(3) + rotation_vector_deg = np.asarray(rotation_vector_deg) if rotation_vector_deg is not None else np.zeros(3) has_translation = np.any(translation_vector != 0) has_rotation = np.any(rotation_vector_deg != 0) @@ -173,7 +167,7 @@ def _apply_mesh_translation_and_rotation( return mesh # Start with an identity matrix (no transformation) - # The affine matrix is definets as: + # The affine matrix is defined as: # [ R R R T ] # [ R R R T ] # [ R R R T ] @@ -185,8 +179,7 @@ def _apply_mesh_translation_and_rotation( if has_rotation: # Determine the point of rotation if cog is not None: - if isinstance(cog, list): - rotation_point = np.array(cog) + rotation_point = np.asarray(cog) logger.debug(f"Using specified COG {rotation_point} as rotation point.") else: rotation_point = mesh.center_mass @@ -269,6 +262,10 @@ def _prepare_capytaine_body( boat = cpt.FloatingBody(mesh=hull_mesh, lid_mesh=lid_mesh, center_of_mass=cog) boat.keep_immersed_part(free_surface=water_level) + # Check for empty mesh after keep_immersed_part + if boat.mesh.vertices.size == 0 or boat.mesh.faces.size == 0: + logger.warning("Resulting mesh is empty after keep_immersed_part. Check if water_level is above the mesh.") + # Important: do this step after keep_immersed_part in order to keep the body constent with the cut mesh boat.add_all_rigid_body_dofs() @@ -375,6 +372,10 @@ def add_mesh_to_database( Args: mesh_to_add: The trimesh object of the mesh to be added. """ + if not isinstance(mesh_to_add, trimesh.Trimesh) or mesh_to_add.is_empty: + logger.warning(f"Attempted to add an empty or invalid mesh named '{mesh_name}' to the database. Skipping.") + return + mesh_group_path = f"{MESH_GROUP_NAME}/{mesh_name}" new_stl_content, new_hash = _get_mesh_hash(mesh_to_add) @@ -404,6 +405,72 @@ def _generate_case_group_name(mesh_name: str, water_depth: float, water_level: f return f"{mesh_name}_wd_{wd}_wl_{wl}_fs_{fs}" +def _load_or_generate_mesh(mesh_name: str, mesh_config: MeshConfig, settings: SimulationSettings) -> trimesh.Trimesh: + """ + Load a mesh from an STL file and apply transformations, or generate it if it doesn't exist. + + - If the STL file specified in `mesh_config.file` exists, it's loaded, and the transformations + (translation, rotation) from the `mesh_config` are applied. + - If the file does not exist, this function attempts to generate it by taking the `settings.base_mesh`, + applying the transformations from `mesh_config`, and saving the result to the path specified + in `mesh_config.file`. + """ + target_stl_path = Path(mesh_config.file) + + if target_stl_path.exists(): + logger.info(f"Found existing STL file: '{target_stl_path}'. Loading and applying transformations.") + # Load the existing STL and apply its specific transformations. + return _prepare_trimesh_geometry(stl_file=str(target_stl_path), mesh_config=mesh_config) + + # If the STL file does not exist, generate it from the base mesh. + logger.info(f"STL file not found at '{target_stl_path}'. Attempting to generate from base mesh.") + source_file_path = settings.base_mesh + if not source_file_path or not Path(source_file_path).exists(): + err_msg = ( + f"Cannot generate mesh '{mesh_name}'. The source file '{target_stl_path}' does not exist, " + f"and no valid 'base_mesh' ('{source_file_path}') is configured to generate it from." + ) + raise FileNotFoundError(err_msg) + + # Load the base STL, apply the specified transformations. + generated_mesh = _prepare_trimesh_geometry(str(source_file_path), mesh_config) + + # Save the newly generated, transformed mesh to the target path for future runs and inspection. + logger.info(f"Saving newly generated mesh to: {target_stl_path}") + target_stl_path.parent.mkdir(parents=True, exist_ok=True) + generated_mesh.export(target_stl_path) + + return generated_mesh + + +def _obtain_mesh( + mesh_name: str, mesh_config: MeshConfig, settings: SimulationSettings, output_file: Path +) -> trimesh.Trimesh: + """ + Obtains a mesh for processing, prioritizing the database cache. + + 1. If `overwrite_meshes` is False, it first attempts to load the mesh from the HDF5 database. + 2. If the mesh is not found in the database, or if `overwrite_meshes` is True, it falls back + to loading or generating the mesh from an STL file via `_load_or_generate_mesh`. + """ + # 1. Prioritize loading from the HDF5 database if overwrite_meshes is False + if not settings.overwrite_meshes: + try: + if existing_meshes := load_meshes_from_hdf5(output_file, [mesh_name]): + logger.info(f"Found existing mesh '{mesh_name}' in the database. Using it directly.") + return existing_meshes[0] + except FileNotFoundError: + # The HDF5 file doesn't exist yet, so no meshes can exist. This is expected on the first run. + pass + else: # This means overwrite_meshes is True + logger.info( + f"'overwrite_meshes' is True. Mesh '{mesh_name}' will be regenerated from its STL file and updated in the database." + ) + + # 2. If not in DB or if overwriting, load/generate from STL. + return _load_or_generate_mesh(mesh_name, mesh_config, settings) + + def _process_single_stl( mesh_config: MeshConfig, settings: SimulationSettings, @@ -414,52 +481,20 @@ def _process_single_stl( """ Checks if a mesh exists in the database. If so, uses it. If not, generates it, saves it, and then uses it for the simulation pipeline. - """ - mesh_name = mesh_name_override or Path(mesh_config.file).stem - final_mesh_to_process: trimesh.Trimesh | None = None - # --- Workflow to determine the mesh to process --- - if not settings.overwrite_meshes: - # 1. Prioritize loading from the HDF5 database if it already exists, only if we dont want - # to overwrite the meshes - try: - existing_meshes = load_meshes_from_hdf5(output_file, [mesh_name]) - if existing_meshes: - final_mesh_to_process = existing_meshes[0] - logger.info(f"Found existing mesh '{mesh_name}' in the database. Using it directly.") - except FileNotFoundError: - # The HDF5 file doesn't exist yet, so no meshes can exist. This is expected on the first run. - pass + Mesh selection priority: + - If a mesh exists in the database and overwrite_meshes is False, the database mesh is used. + - If overwrite_meshes is True, the mesh is regenerated from the STL file and replaces the database mesh. + - If no mesh exists in the database, the mesh is generated from the STL file and saved to the database. - # 2. If not in DB, check if a pre-translated STL file exists. - if final_mesh_to_process is None: - target_stl_path = Path(mesh_config.file) - if target_stl_path.exists(): - logger.info( - f"Mesh '{mesh_name}' not in DB, but found STL file: '{target_stl_path}'. Loading and adding to DB." - ) - # Load the existing, presumably pre-translated, STL file. - final_mesh_to_process = _prepare_trimesh_geometry(stl_file=str(target_stl_path)) - else: - # 3. If neither DB entry nor STL file exists, generate the mesh. - logger.info(f"Mesh '{mesh_name}' not found in DB or as STL file. Attempting to generate it.") - # Use the global base_mesh as the source for generation. - source_file_path = settings.base_mesh - if not source_file_path or not Path(source_file_path).exists(): - err_msg = ( - f"Cannot generate mesh '{mesh_name}'. The source file '{target_stl_path}' does not exist, " - f"and no valid 'base_mesh' ('{source_file_path}') is configured to generate it from." - ) - raise FileNotFoundError(err_msg) + This ensures that the database mesh is preferred unless the user explicitly requests to overwrite meshes. + """ + mesh_name = mesh_name_override or Path(mesh_config.file).stem - # Load the base STL and apply the specified translation. - translated_mesh = _prepare_trimesh_geometry(str(source_file_path), mesh_config) - # Save the newly generated, translated mesh to a separate STL file for inspection. - logger.info(f"Saving newly generated translated mesh to: {target_stl_path}") - translated_mesh.export(target_stl_path) - final_mesh_to_process = translated_mesh + # Obtain the mesh, either from the database or by loading/generating it. + final_mesh_to_process = _obtain_mesh(mesh_name, mesh_config, settings, output_file) - # 4. Run the complete processing pipeline with the determined mesh. + # Run the complete processing pipeline with the determined mesh. engine_mesh = EngineMesh(name=mesh_name, mesh=final_mesh_to_process, config=mesh_config) _run_pipeline_for_mesh(engine_mesh, settings, output_file, origin_translation) @@ -475,24 +510,27 @@ def _log_pipeline_parameters( forwards_speeds: list[float], ) -> None: """Logs all relevant parameters for a pipeline run for better traceability.""" - fmt_str = "%-40s: %s" - logger.info(fmt_str % ("Base STL file", engine_mesh.config.file)) - logger.info(fmt_str % ("Base STL vertices", engine_mesh.mesh.vertices.shape)) - logger.info(fmt_str % ("Output file", output_file)) - logger.info(fmt_str % ("Grid symmetry", settings.grid_symmetry)) - logger.info(fmt_str % ("Use lid", settings.lid)) - logger.info(fmt_str % ("Add COG ", settings.add_center_of_mass)) - logger.info(fmt_str % ("Direction(s) [rad]", wave_directions_rad)) - logger.info(fmt_str % ("Wave period(s) [s]", wave_periods)) - logger.info(fmt_str % ("Water depth(s) [m]", water_depths)) - logger.info(fmt_str % ("Water level(s) [m]", water_levels)) - logger.info(fmt_str % ("Translation X", engine_mesh.config.translation[0])) - logger.info(fmt_str % ("Translation Y", engine_mesh.config.translation[1])) - logger.info(fmt_str % ("Translation Z", engine_mesh.config.translation[2])) - logger.info(fmt_str % ("Rotation Roll [deg]", engine_mesh.config.rotation[0])) - logger.info(fmt_str % ("Rotation Pitch [deg]", engine_mesh.config.rotation[1])) - logger.info(fmt_str % ("Rotation Yaw [deg]", engine_mesh.config.rotation[2])) - logger.info(fmt_str % ("Forward speed(s) [m/s]", forwards_speeds)) + params = { + "Base STL file": engine_mesh.config.file, + "Base STL vertices": engine_mesh.mesh.vertices.shape, + "Output file": output_file, + "Grid symmetry": settings.grid_symmetry, + "Use lid": settings.lid, + "Add COG": settings.add_center_of_mass, + "Direction(s) [rad]": wave_directions_rad, + "Wave period(s) [s]": wave_periods, + "Water depth(s) [m]": water_depths, + "Water level(s) [m]": water_levels, + "Translation X": engine_mesh.config.translation[0], + "Translation Y": engine_mesh.config.translation[1], + "Translation Z": engine_mesh.config.translation[2], + "Rotation Roll [deg]": engine_mesh.config.rotation[0], + "Rotation Pitch [deg]": engine_mesh.config.rotation[1], + "Rotation Yaw [deg]": engine_mesh.config.rotation[2], + "Forward speed(s) [m/s]": forwards_speeds, + } + for key, val in params.items(): + logger.info(f"{key:<40}: {val}") def _run_pipeline_for_mesh( @@ -573,6 +611,7 @@ def _process_and_save_single_case( # Calculate the transformation matrix for this specific case relative to the global origin transformation_matrix = None if origin_translation is not None: + origin_translation = np.asarray(origin_translation) # The transformation is the translation from the global origin to the mesh's COG for this case. # Note: boat.center_of_mass is the COG used for calculation, not necessarily the geometric center. translation_vector = boat.center_of_mass - origin_translation @@ -637,36 +676,38 @@ def process_all_cases_for_one_stl( all_datasets = [] - for water_level in water_levels: - for water_depth in water_depths: - for forward_speed in forwards_speeds: - case_params = { - "omegas": wave_frequencies, - "wave_directions": wave_directions, - "water_level": water_level, - "water_depth": water_depth, - "forward_speed": forward_speed, - "update_cases": update_cases, - "combine_cases": combine_cases, - } - result_db = _process_and_save_single_case( - boat, engine_mesh.name, case_params, output_file, origin_translation - ) - if combine_cases and result_db is not None: - all_datasets.append(result_db) - - if combine_cases and all_datasets: - logger.info("Combining all calculated cases into a single multi-dimensional dataset.") - combined_dataset = xr.combine_by_coords(all_datasets, combine_attrs="drop_conflicts") - combined_group_name = f"{engine_mesh.name}_multi_dim" - - logger.info(f"Writing combined dataset to group '{combined_group_name}' in HDF5 file: {output_file}") - with h5py.File(output_file, "a") as f: - if combined_group_name in f: - del f[combined_group_name] - combined_dataset.to_netcdf(output_file, mode="a", group=combined_group_name, engine="h5netcdf") - with h5py.File(output_file, "a") as f: - f[combined_group_name].attrs["stl_mesh_name"] = engine_mesh.name + for water_level, water_depth, forward_speed in product(water_levels, water_depths, forwards_speeds): + case_params = { + "omegas": wave_frequencies, + "wave_directions": wave_directions, + "water_level": water_level, + "water_depth": water_depth, + "forward_speed": forward_speed, + "update_cases": update_cases, + "combine_cases": combine_cases, + } + result_db = _process_and_save_single_case(boat, engine_mesh.name, case_params, output_file, origin_translation) + if combine_cases and result_db is not None: + all_datasets.append(result_db) + + if combine_cases: + if all_datasets: + logger.info("Combining all calculated cases into a single multi-dimensional dataset.") + combined_dataset = xr.combine_by_coords(all_datasets, combine_attrs="drop_conflicts") + combined_group_name = f"{engine_mesh.name}_multi_dim" + + logger.info(f"Writing combined dataset to group '{combined_group_name}' in HDF5 file: {output_file}") + with h5py.File(output_file, "a") as f: + if combined_group_name in f: + del f[combined_group_name] + combined_dataset.to_netcdf(output_file, mode="a", group=combined_group_name, engine="h5netcdf") + with h5py.File(output_file, "a") as f: + f[combined_group_name].attrs["stl_mesh_name"] = engine_mesh.name + else: + logger.warning( + "The 'combine_cases' option is enabled, but no datasets were generated to combine. " + "This can happen if all cases were already present in the output file and 'update_cases' was false." + ) logger.debug(f"Successfully wrote all data for mesh '{engine_mesh.name}' to HDF5.") diff --git a/src/fleetmaster/core/exceptions.py b/src/fleetmaster/core/exceptions.py index 69ec184..d7e61fd 100644 --- a/src/fleetmaster/core/exceptions.py +++ b/src/fleetmaster/core/exceptions.py @@ -31,6 +31,13 @@ def __init__(self, message: str = "Periods must be larger than 0.") -> None: super().__init__(message) +class InvalidVectorLength(SimulationConfigurationError): + """Raised when a vector has an invalid length.""" + + def __init__(self, message: str = "Invalid vector length") -> None: + super().__init__(message) + + class HDF5AttributeError(ValueError): """Raised when a required attribute is missing from an HDF5 file.""" diff --git a/src/fleetmaster/core/fitting.py b/src/fleetmaster/core/fitting.py index c2e8a3f..85da138 100644 --- a/src/fleetmaster/core/fitting.py +++ b/src/fleetmaster/core/fitting.py @@ -94,9 +94,19 @@ def _find_best_fit_for_candidates( distances[name] = np.inf continue - # Combine transformations: - # - XY translation from candidate, Z translation from target. - # - XY rotation (roll, pitch) from target, Z rotation (yaw) from candidate. + # The goal of the fitting is to find a mesh from the database that best matches + # the target's submerged shape, which is primarily determined by Z-translation (draft) + # and X/Y-rotations (roll, pitch). The database contains meshes with varying roll and pitch, + # but typically constant XY translation and Z-rotation (yaw). + # + # To find the best match, we create a hybrid transformation that respects these assumptions: + # - We use the target's Z-translation (draft) because that's a key property we're matching. + # - We use the target's roll and pitch for the same reason. + # - We take the candidate's XY-translation and yaw, because these are considered irrelevant + # for the shape matching and are constant in the database generation process. + # + # This allows us to transform the base mesh into a shape that is directly comparable + # with the candidate's wetted surface. new_translation = [ candidate_translation[0], candidate_translation[1], diff --git a/src/fleetmaster/core/settings.py b/src/fleetmaster/core/settings.py index 86648dc..8fa14da 100644 --- a/src/fleetmaster/core/settings.py +++ b/src/fleetmaster/core/settings.py @@ -1,7 +1,10 @@ +from typing import Any + import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator from fleetmaster.core.exceptions import ( + InvalidVectorLength, LidAndSymmetryEnabledError, NegativeForwardSpeedError, NonPositivePeriodError, @@ -14,6 +17,9 @@ class MeshConfig(BaseModel): """Configuration for a single mesh, including its path and transformation.""" file: str + name: str | None = Field( + default=None, description="An optional name for the mesh. If not provided, it's derived from the file name." + ) translation: list[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0]) rotation: list[float] = Field( default_factory=lambda: [0.0, 0.0, 0.0], description="Rotation [roll, pitch, yaw] in degrees." @@ -28,6 +34,20 @@ class MeshConfig(BaseModel): default=None, description="Mesh-specific wave directions in degrees. Overrides global settings." ) + @field_validator("translation", "rotation") + def check_vector_length(cls, v: list[float]) -> list[float]: + if len(v) != 3: + msg = "Translation and rotation must a of length 3" + raise InvalidVectorLength(msg) + return v + + @field_validator("cog") + def check_cog_length(cls, v: list[float] | None) -> list[float] | None: + if v is not None and len(v) != 3: + msg = "Cog must be a list of 3 floats or None" + raise InvalidVectorLength(msg) + return v + class SimulationSettings(BaseModel): """Defines all possible settings for a simulation. @@ -43,14 +63,11 @@ class SimulationSettings(BaseModel): default=None, description="A point [x, y, z] in the local coordinate system of the base_mesh that defines the world origin.", ) - stl_files: list[str | MeshConfig] = Field(description="A list of STL mesh files or mesh configurations.") + stl_files: list[MeshConfig] = Field(description="A list of STL mesh files or mesh configurations.") output_directory: str | None = Field(default=None, description="Directory to save the output files.") output_hdf5_file: str = Field(default="results.hdf5", description="Path to the HDF5 output file.") wave_periods: float | list[float] = Field(default=[5.0, 10.0, 15.0, 20.0]) wave_directions: float | list[float] = Field(default=[0.0, 45.0, 90.0, 135.0, 180.0]) - translation_x: float = Field(default=0.0, description="Translation in X-direction to apply to the mesh.") - translation_y: float = Field(default=0.0, description="Translation in Y-direction to apply to the mesh.") - translation_z: float = Field(default=0.0, description="Translation in Z-direction to apply to the mesh.") forward_speed: float | list[float] = 0.0 lid: bool = False add_center_of_mass: bool = False @@ -66,6 +83,12 @@ class SimulationSettings(BaseModel): default=False, description="Combine all calculated cases for a single STL into one multi-dimensional dataset." ) + @field_validator("stl_files", mode="before") + def normalize_stl_files(cls, v: Any) -> Any: + if not isinstance(v, list): + return v + return [MeshConfig(file=item) if isinstance(item, str) else item for item in v] + # field validator checks the value of one specific field inmediately @field_validator("forward_speed") def speed_must_be_non_negative(cls, v: float | list[float]) -> float | list[float]: diff --git a/tests/test_engine.py b/tests/test_engine.py index de16e02..aefe29d 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -66,6 +66,23 @@ def test_setup_output_file_no_overwrite(tmp_path, mock_settings): assert output_file.exists() # Should NOT have been deleted +def test_setup_output_file_with_dict_as_stl(tmp_path): + """Test _setup_output_file with stl_files as a list of dicts.""" + settings = MagicMock(spec=SimulationSettings) + settings.stl_files = [{"file": str(tmp_path / "test.stl")}] + settings.output_directory = None + settings.output_hdf5_file = "results.hdf5" + settings.overwrite_meshes = False + + # Act + result_path = _setup_output_file(settings) + + # Assert + # If no output dir is specified, it should be the parent of the first STL + assert result_path == tmp_path / "results.hdf5" + assert result_path.parent.exists() + + @patch("fleetmaster.core.engine.cpt") @patch("fleetmaster.core.engine.tempfile") def test_prepare_capytaine_body(mock_tempfile, mock_cpt, tmp_path: Path): @@ -230,6 +247,15 @@ def test_run_simulation_batch_standard(mock_setup, mock_process, mock_prepare, m mock_mesh = MagicMock(spec=trimesh.Trimesh) mock_mesh.export.return_value = b"dummy stl content" + + # --- FIX HIER TOEGEVOEGD --- + # Configureer de mock om waarden terug te geven die de productiecode verwacht + mock_mesh.moment_inertia = np.eye(3) + mock_mesh.volume = 1.0 # Dit voorkomt de TypeError: MagicMock > float + mock_mesh.center_mass = [0, 0, 0] + mock_mesh.bounding_box.extents = [1, 1, 1] + # --- EINDE FIX --- + mock_prepare.return_value = mock_mesh run_simulation_batch(mock_settings) @@ -246,6 +272,15 @@ def test_run_simulation_batch_drafts(mock_setup, mock_process, mock_prepare, moc mock_setup.return_value = tmp_path / "output.hdf5" mock_mesh = MagicMock(spec=trimesh.Trimesh) mock_mesh.export.return_value = b"dummy stl content" + + # --- FIX HIER TOEGEVOEGD --- + # Configureer de mock om waarden terug te geven die de productiecode verwacht + mock_mesh.moment_inertia = np.eye(3) + mock_mesh.volume = 1.0 # Dit voorkomt de TypeError: MagicMock > float + mock_mesh.center_mass = [0, 0, 0] + mock_mesh.bounding_box.extents = [1, 1, 1] + # --- EINDE FIX --- + mock_prepare.return_value = mock_mesh mock_settings.stl_files = [MeshConfig(file="base_mesh.stl", translation=[0, 0, 5])] @@ -281,3 +316,185 @@ def test_run_simulation_batch_drafts_wrong_stl_count(mock_prepare, mock_settings mock_settings.stl_files = [MeshConfig(file="file1.stl"), MeshConfig(file="file2.stl")] with pytest.raises(ValueError, match="exactly one base STL file must be provided"): run_simulation_batch(mock_settings) + + +@patch("h5py.File") +def test_add_mesh_to_database_overwrite_true_different_content(mock_h5py_file, caplog): + """Test that with overwrite=True, a different mesh with the same name is replaced.""" + # Arrange + mock_file = MagicMock() + mock_h5py_file.return_value.__enter__.return_value = mock_file + + mock_group = MagicMock() + mock_group.name = "/meshes/test_mesh" + mock_file.__contains__.return_value = True + mock_file.__getitem__.return_value = mock_group + mock_group.attrs.get.return_value = "old_hash" + + output_file = Path("dummy.h5") + mesh_config = MeshConfig(file="test.stl", name="test_mesh") + new_mesh = trimesh.creation.box() + + # Act + add_mesh_to_database(output_file, new_mesh, "test_mesh", overwrite=True, mesh_config=mesh_config) + + # Assert + mock_file.__delitem__.assert_called_once_with("meshes/test_mesh") + assert "Overwriting existing mesh" in caplog.text + mock_file.create_group.assert_called_with("meshes/test_mesh") + + +@patch("fleetmaster.core.engine._process_single_stl") +@patch("fleetmaster.core.engine._setup_output_file") +def test_run_simulation_batch_with_mesh_config_dict(mock_setup, mock_process, tmp_path): + """Test run_simulation_batch with stl_files as a list of dicts.""" + settings_dict = { + "stl_files": [{"file": "box.stl", "translation": [1, 2, 3]}], + "output_directory": str(tmp_path), + "output_hdf5_file": "results.h5", + "overwrite_meshes": True, + "drafts": [], + "base_mesh": None, + "base_origin": None, + "wave_periods": [10.0], + "wave_directions": [0.0], + "water_depth": 100.0, + "water_level": 0.0, + "forward_speed": 0.0, + "grid_symmetry": False, + "lid": False, + "add_center_of_mass": False, + "update_cases": False, + "combine_cases": False, + } + + stl_file = tmp_path / "box.stl" + stl_file.touch() + settings_dict["stl_files"][0]["file"] = str(stl_file) + mock_setup.return_value = tmp_path / "results.h5" + + settings = SimulationSettings(**settings_dict) + + run_simulation_batch(settings) + + mock_process.assert_called_once() + called_mesh_config = mock_process.call_args[0][0] + assert isinstance(called_mesh_config, MeshConfig) + assert called_mesh_config.file == str(stl_file) + assert called_mesh_config.translation == [1, 2, 3] + + +@patch("fleetmaster.core.engine._process_single_stl") +@patch("h5py.File") +@patch("fleetmaster.core.engine._setup_output_file") +@patch("fleetmaster.core.engine._prepare_trimesh_geometry") +def test_run_simulation_batch_with_base_origin(mock_prepare, mock_setup, mock_h5py, mock_process_stl, tmp_path): + """Test run_simulation_batch with base_origin set.""" + output_dir = tmp_path / "output" + output_dir.mkdir() + stl_path = tmp_path / "test.stl" + stl_path.touch() + mock_setup.return_value = output_dir / "results.h5" + + mock_prepare.return_value = trimesh.creation.box() + + settings = SimulationSettings( + stl_files=[{"file": str(stl_path)}], + base_mesh=str(stl_path), + base_origin=[1.0, 2.0, 3.0], + output_directory=str(output_dir), + wave_periods=[10.0], + wave_directions=[0.0], + water_depth=100.0, + water_level=0.0, + forward_speed=0.0, + grid_symmetry=False, + lid=False, + add_center_of_mass=False, + update_cases=False, + combine_cases=False, + ) + + mock_file = MagicMock() + mock_h5py.return_value.__enter__.return_value = mock_file + + run_simulation_batch(settings) + + # Check that origin_translation is passed correctly + mock_process_stl.assert_called_once() + kwargs = mock_process_stl.call_args.kwargs + assert "origin_translation" in kwargs + np.testing.assert_array_equal(kwargs["origin_translation"], np.array([1.0, 2.0, 3.0])) + + # Check that base_origin is saved to HDF5 file attributes + mock_file.attrs.__setitem__.assert_any_call("base_origin", [1.0, 2.0, 3.0]) + + +@patch("h5py.File") +def test_add_mesh_to_database_with_meshconfig_attrs(mock_h5py_file, tmp_path): + """Test that MeshConfig attributes are saved to the HDF5 group.""" + # Arrange + mock_file = MagicMock() + mock_group = MagicMock() + mock_h5py_file.return_value.__enter__.return_value = mock_file + mock_file.create_group.return_value = mock_group + + output_file = tmp_path / "db.h5" + mesh_config = MeshConfig( + file="test.stl", + name="test_mesh", + translation=[1, 1, 1], + rotation=[2, 2, 2], + cog=[3, 3, 3], + ) + mock_mesh = MagicMock(spec=trimesh.Trimesh) + mock_mesh.export.return_value = b"content" + mock_mesh.moment_inertia = np.eye(3) + mock_mesh.volume = 1.0 + mock_mesh.center_mass = [0, 0, 0] + mock_mesh.bounding_box.extents = [1, 1, 1] + mock_mesh.is_empty = False # Ensure the mock mesh is not considered empty + + # Act + add_mesh_to_database(output_file, mock_mesh, "test_mesh", mesh_config=mesh_config) + + # Assert + mock_file.create_group.assert_called_once_with("meshes/test_mesh") + mock_group.attrs.__setitem__.assert_any_call("translation", [1, 1, 1]) + mock_group.attrs.__setitem__.assert_any_call("rotation", [2, 2, 2]) + mock_group.attrs.__setitem__.assert_any_call("cog", [3, 3, 3]) + + +@patch("fleetmaster.core.engine._load_or_generate_mesh") +@patch("fleetmaster.core.engine.load_meshes_from_hdf5") +@patch("fleetmaster.core.engine._run_pipeline_for_mesh") +def test_process_single_stl_from_db(mock_run_pipeline, mock_load_meshes, mock_load_or_generate, tmp_path): + """Test _process_single_stl when the mesh is loaded from the database.""" + # Arrange + settings = MagicMock(spec=SimulationSettings) + settings.overwrite_meshes = False + + # The name of the mesh is derived from the file name stem + mesh_config = MeshConfig(file="dummy_name.stl") + output_file = tmp_path / "db.h5" + + mock_mesh_from_db = MagicMock(spec=trimesh.Trimesh) + mock_load_meshes.return_value = [mock_mesh_from_db] + + # Act + _process_single_stl(mesh_config, settings, output_file) + + # Assert + # Should attempt to load from DB first + mock_load_meshes.assert_called_once_with(output_file, ["dummy_name"]) + + # Should NOT try to load or generate a new mesh because it was found in the DB + mock_load_or_generate.assert_not_called() + + # Should run the pipeline with the mesh from the DB + mock_run_pipeline.assert_called_once() + engine_mesh_arg = mock_run_pipeline.call_args[0][0] + assert isinstance(engine_mesh_arg, EngineMesh) + assert engine_mesh_arg.mesh == mock_mesh_from_db + assert engine_mesh_arg.name == "dummy_name" + assert engine_mesh_arg.config == mesh_config diff --git a/tests/test_fitting.py b/tests/test_fitting.py new file mode 100644 index 0000000..1e8b534 --- /dev/null +++ b/tests/test_fitting.py @@ -0,0 +1,111 @@ +"""Unit tests for the mesh fitting functionality.""" + +import logging +from pathlib import Path + +import pytest + +from fleetmaster.core.fitting import find_best_matching_mesh + +# Configure basic logging +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def hdf5_path() -> Path: + """ + Provides the path to the HDF5 database file and skips tests if it doesn't exist. + + This fixture ensures that tests depending on the pre-generated database + are only run when the file is available. + """ + # The HDF5 file is expected to be in the 'examples' directory, relative to the project root. + # Assuming tests are run from the project root. + path = Path("examples/boxship.hdf5") + if not path.exists(): + pytest.skip( + f"Database file not found at: {path.resolve()}. " + "Run 'uv run python examples/defraction_box.py --output-dir examples --file-base boxship --generate-fitting-meshes' " + "followed by 'fleetmaster -v run --settings-file examples/settings_rotations.yml' to generate it." + ) + return path + + +# Define test cases using pytest.mark.parametrize +# Each tuple represents a test case: +# (case_description, target_translation, target_rotation, water_level, expected_match, expected_distance_check) +# The `expected_distance_check` is a lambda function to validate the distance. +TEST_CASES = [ + ( + "Case 1: Exact Match Draft 1 meter", + [0.0, 0.0, -1.0], + [20.0, 20.0, 0.0], + 0.0, + "boxship_t_1_r_20_20_00", + lambda dist: dist < 0.41, # Exact matches are not zero due to regridding, so we use a small threshold. + ), + ( + "Case 2: Match with irrelevant translation/rotation noise (draft 1.0)", + [2.5, -4.2, -1.0], # dx, dy noise + [20.0, 20.0, 15.0], # yaw noise + 0.0, + "boxship_t_1_r_20_20_00", + lambda dist: dist < 0.41, # Distance should still be very small as the shape is identical. + ), + ( + "Case 3: Different match due to significant rotation deviation (draft 1.0)", + [2.5, -4.2, -1.1], + [23.0, 19.0, 15.0], # Deviations in roll and pitch + 0.0, + "boxship_t_1_r_20_20_00", # This is still the closest match, even with noise + lambda dist: 0.41 < dist < 0.5, # Distance should be clearly non-zero and larger than the threshold. + ), + ( + "Case 4: Exact Match for draft 2.0", + [0.0, 0.0, -2.0], + [0.0, 0.0, 0.0], + 0.0, + "boxship_t_2_r_00_00_00", + lambda dist: dist < 0.1, + ), + ( + "Case 5: Exact Match for draft 2.0 with irrelevant xy-plane and yaw deviation", + [10.0, -20.0, -2.0], + [0.0, 0.0, 15.0], + 0.0, + "boxship_t_2_r_00_00_00", + lambda dist: dist < 0.1, + ), + ( + "Case 6: Match for draft 2.0 with noise in all axes", + [10.0, -20.0, -2.2], + [4.0, -1.0, 15.0], + 0.0, + "boxship_t_2_r_00_00_00", + lambda dist: 0.1 < dist < 0.2, # Distance should be clearly non-zero and larger than the threshold. + ), +] + + +@pytest.mark.parametrize( + "description, target_translation, target_rotation, water_level, expected_match, distance_check", + TEST_CASES, + ids=[case[0] for case in TEST_CASES], +) +def test_find_best_matching_mesh( + hdf5_path: Path, + description: str, + target_translation: list[float], + target_rotation: list[float], + water_level: float, + expected_match: str, + distance_check, +): + """Tests the find_best_matching_mesh function with various scenarios.""" + logger.info(f"Running test: {description}") + best_match, distance = find_best_matching_mesh(hdf5_path, target_translation, target_rotation, water_level) + + assert best_match is not None, "A best match should have been found." + assert best_match == expected_match + assert distance_check(distance), f"Distance check failed for {description}. Got distance: {distance}"