diff --git a/beneuro_pose_estimation/anipose/aniposeTools.py b/beneuro_pose_estimation/anipose/aniposeTools.py index 6df6404..819f451 100644 --- a/beneuro_pose_estimation/anipose/aniposeTools.py +++ b/beneuro_pose_estimation/anipose/aniposeTools.py @@ -16,6 +16,7 @@ import matplotlib as plt import beneuro_pose_estimation.sleap.sleapTools as sleapTools +import beneuro_pose_estimation.tools as tools from beneuro_pose_estimation import params if not logging.getLogger().hasHandlers(): @@ -693,9 +694,7 @@ def run_pose_estimation( # get calibration file - toml file saved in calibration/calibration.toml if session.split("_")[1] == "2025": # TODO: set the condition so that sessions after 3rd of february 2025 use this - recent_calib_folder = config.REMOTE_PATH/ "raw"/ "pose-estimation"/ "calibration-videos"/ "camera_calibration_2025_03_12_11_45" / "Recording_2025-03-12T114830" calib_file_path = config.calibration/"calibration_2025_03_12_11_45.toml" - # get_calib_file(calib_videos_dir=recent_calib_folder, calib_save_path=calib_file_path) else: calib_file_path = get_most_recent_calib(session) @@ -726,83 +725,28 @@ def run_pose_estimation( # Save the updated CSV combined_data.to_csv(combined_csv, index=False) logging.info(f"Angles computed and combined CSV saved at {combined_csv}.") - logging.info(f"Pose estimation completed for {session}.") - -def create_test_videos(session, cameras=params.default_cameras, duration_seconds=10, fps=100, - force_new=False, start_frame=None): - """ - Creates short test videos and corresponding metadata for each camera. - """ - animal = session.split("_")[0] - n_frames = duration_seconds * fps - - # Create output directory for test videos - test_dir = config.LOCAL_PATH /"raw" / animal / session / "pose-estimation"/ "tests" - - - cameras_dir = test_dir / f"{session}_cameras" - cameras_dir.mkdir(parents=True, exist_ok=True) - for camera in cameras: try: - # Output video path - output_video = cameras_dir / f"{params.camera_name_mapping.get(camera, camera)}.avi" - - # Skip if video exists and force_new is False - if output_video.exists() and not force_new: - logger.info(f"Test video already exists for {camera}, skipping: {output_video}") - continue - - # Input video path - input_video = ( - config.recordings - / animal - / session - / f"{session}_cameras" - / f"{params.camera_name_mapping.get(camera, camera)}.avi" - ) - - if not input_video.exists(): - logger.warning(f"Input video not found: {input_video}") - continue - - - # Create video - cap = cv2.VideoCapture(str(input_video)) - if not cap.isOpened(): - logger.error(f"Could not open video: {input_video}") - continue - - # Get video properties - width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - - # Create video writer - fourcc = cv2.VideoWriter_fourcc(*'XVID') - out = cv2.VideoWriter(str(output_video), fourcc, fps, (width, height)) - - # Set start frame if specified - if start_frame is not None: - cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) - - # Read and write frames - frame_count = 0 - while frame_count < n_frames: - ret, frame = cap.read() - if not ret: - break - out.write(frame) - frame_count += 1 - - # Release resources - cap.release() - out.release() - - logger.info(f"Created test video for {camera}: {output_video}") - + if labels_fname.exists(): + labels_fname.unlink() + logger.info(f"Deleted intermediate CSV: {labels_fname.name}") + if angles_csv.exists(): + angles_csv.unlink() + logger.info(f"Deleted intermediate CSV: {angles_csv.name}") except Exception as e: - logger.error(f"Error processing camera {camera}: {e}") - - return test_dir + logger.error(f"Error deleting intermediate CSVs: {e}") + + + triangulation_files = list(predictions_dir.glob("**/*triangulation*.h5")) + if triangulation_files: + for tri_file in triangulation_files: + try: + tri_file.unlink() + logger.info(f"Deleted: {tri_file}") + except Exception as e: + logger.error(f"Error deleting {tri_file}: {e}") + logging.info(f"Pose estimation completed for {session}.") + + def run_pose_test(session, test_name = None, cameras=params.default_cameras, force_new_videos=False, start_frame=None, duration_seconds=10): @@ -812,7 +756,7 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for try: # 1. Create test videos logger.info("Creating test videos...") - tests_dir = create_test_videos(session, cameras, duration_seconds, + tests_dir = tools.create_test_videos(session, cameras, duration_seconds, force_new=force_new_videos, start_frame=start_frame) test_dir = tests_dir / test_name if test_name is None: @@ -827,10 +771,8 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for - if session.split("_")[1] == "2025": # TODO: set the condition so that sessions after 3rd of february 2025 use this - recent_calib_folder = config.REMOTE_PATH/ "raw"/ "pose-estimation"/ "calibration-videos"/ "camera_calibration_2025_03_12_11_45" / "Recording_2025-03-12T114830" + if session.split("_")[1] == "2025" and session.split("_")[2] != "01": # TODO: set the condition so that sessions after 3rd of february 2025 use this calib_file_path = config.calibration/"calibration_2025_03_12_11_45.toml" - # get_calib_file(calib_videos_dir=recent_calib_folder, calib_save_path=calib_file_path) else: calib_file_path = get_most_recent_calib(session) @@ -848,8 +790,8 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for config_path = create_config_file(config_path) config_angles = toml.load(config_path) angles_csv = test_dir/f"{session}_angles.csv" - # labels_data = pd.read_csv(labels_xfname) - # logging.debug(labels_data.columns) + + compute_angles(config_angles, labels_fname, angles_csv) pose_data = pd.read_csv(labels_fname) @@ -861,57 +803,26 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for # Save the updated CSV combined_data.to_csv(combined_csv, index=False) logging.info(f"Angles computed and combined CSV saved at {combined_csv}.") - logging.info(f"Pose estimation completed for {session}.") - - - - except Exception as e: - logger.error(f"Error in pose test for {session}: {e}") - raise - - -def cleanup_intermediate_files(session: str, include_slp: bool = False): - """ - Clean up intermediate files for a session with interactive prompts. - - Args: - session: Session name to clean up - include_slp: Whether to include .slp files in cleanup - """ - animal = session.split("_")[0] - project_dir = config.predictions3D / animal / session / "pose-estimation" - - if not project_dir.exists(): - logger.error(f"Project directory not found: {project_dir}") - return + try: + if labels_fname.exists(): + labels_fname.unlink() + logger.info(f"Deleted intermediate CSV: {labels_fname.name}") + if angles_csv.exists(): + angles_csv.unlink() + logger.info(f"Deleted intermediate CSV: {angles_csv.name}") + except Exception as e: + logger.error(f"Error deleting intermediate CSVs: {e}") - # 1. Clean up 2D prediction .slp files (only if include_slp is True) - if include_slp: - slp_files = list(project_dir.glob("**/*.slp")) - if slp_files: - response = input(f"\nFound {len(slp_files)} .slp files. Do you want to delete them? (y/n): ").lower() - if response == 'y': - for slp_file in slp_files: - try: - slp_file.unlink() - logger.info(f"Deleted: {slp_file}") - except Exception as e: - logger.error(f"Error deleting {slp_file}: {e}") - else: - logger.info("Skipping .slp files deletion") - - # 2. Clean up triangulation files - triangulation_files = list(project_dir.glob("**/*triangulation*.h5")) - if triangulation_files: - response = input(f"\nFound {len(triangulation_files)} triangulation files. Do you want to delete them? (y/n): ").lower() - if response == 'y': + triangulation_files = list(tests_dir.glob("**/*triangulation*.h5")) + if triangulation_files: for tri_file in triangulation_files: try: tri_file.unlink() logger.info(f"Deleted: {tri_file}") except Exception as e: logger.error(f"Error deleting {tri_file}: {e}") - else: - logger.info("Skipping triangulation files deletion") - - logger.info("Cleanup completed") \ No newline at end of file + logging.info(f"Pose estimation completed for {session}.") + + except Exception as e: + logger.error(f"Error in pose test for {session}: {e}") + raise diff --git a/beneuro_pose_estimation/cli.py b/beneuro_pose_estimation/cli.py index 210ba23..3a5880d 100644 --- a/beneuro_pose_estimation/cli.py +++ b/beneuro_pose_estimation/cli.py @@ -113,25 +113,35 @@ def create_training_project( @app.command() def cleanup( session: str = typer.Argument(..., help="Session name to clean up intermediate files for."), - slp: bool = typer.Option( - False, - "--slp", "-s", - help="Also clean up .slp files (2D prediction files)." - ) + ): """ Clean up intermediate files for a session. By default, only asks about cleaning up triangulation files. Use --slp flag to also clean up 2D prediction .slp files. """ - from beneuro_pose_estimation.anipose.aniposeTools import cleanup_intermediate_files + from beneuro_pose_estimation.tools import cleanup_intermediate_files - cleanup_intermediate_files(session, include_slp=slp) + cleanup_intermediate_files(session) return +@app.command() +def model_up( + test_folder_name: str = typer.Argument( + ..., + help="Name of the test folder under config.models// to copy to remote_models" + ) +): + """ + Recursively copy a model test folder from your local models path to the remote_models path, + prompting if it already exists. + """ + from beneuro_pose_estimation.tools import copy_model_to_remote + copy_model_to_remote(test_folder_name) + @app.command() def eval_report( session_name: str = typer.Argument(..., help="Session name to evaluate"), @@ -175,7 +185,7 @@ def pose_test( start_frame: Optional[int] = typer.Option( None, "--start-frame", "-s", - help="Frame number to start from. If not specified, uses first 100 frames." + help="Frame number to start from. If not specified, uses frame 0." ), duration: Optional[int] = typer.Option( 10, @@ -220,7 +230,7 @@ def train( from beneuro_pose_estimation.sleap.sleapTools import train_models cams = cameras or params.default_cameras # train_models is the function you already wrote - train_models(cams) + train_models(cam, custom_labels = custom_labels) # =================================== Updating ========================================== diff --git a/beneuro_pose_estimation/config.py b/beneuro_pose_estimation/config.py index fbc791e..07b6a59 100644 --- a/beneuro_pose_estimation/config.py +++ b/beneuro_pose_estimation/config.py @@ -79,24 +79,23 @@ def assign_paths(self): self.recordings_remote = self.REMOTE_PATH / "raw" self.annotation_party = self.REMOTE_PATH / "processed" / "AnnotationParty" self.annotations = self.annotation_party / "annotations" - # self.models = ( - # self.REMOTE_PATH / "raw" / "pose-estimation" / "models" / "h1_new_setup" - # ) + self.models_local = self.LOCAL_PATH / "raw" / "pose-estimation" / "models" self.models_remote = self.annotation_party / "models" self.custom_models = self.models_local + self.models = self.models_remote self.skeleton_path = ( self.REPO_PATH / "beneuro_pose_estimation" / "sleap" / "skeleton.json" ) - self.recordings = self.annotation_party # self.LOCAL_PATH / "raw" + self.recordings = self.LOCAL_PATH / "raw" + # self.recordings = self.annotation_party self.predictions2D = self.LOCAL_PATH / "raw" # change to 'processed'? self.training = self.annotation_party / "training" self.training_config = self.REPO_PATH / "beneuro_pose_estimation"/ "training_config.json" - self.predictions3D = self.LOCAL_PATH / "raw" # change to 'processed' - # self.predictions3D = self.LOCAL_PATH / "raw" + self.predictions3D = self.LOCAL_PATH / "raw" # change to 'processed'? self.calibration_videos = self.REMOTE_PATH / "raw" / "calibration_videos" - self.calibration = self.LOCAL_PATH / "raw"/ "pose_estimation"/ "calibration_config" + self.calibration = self.LOCAL_PATH / "raw"/ "pose-estimation"/ "calibration_config" self.angles_config = self.REPO_PATH / "beneuro_pose_estimation" return diff --git a/beneuro_pose_estimation/evaluation.py b/beneuro_pose_estimation/evaluation.py index eeb35f7..cb33987 100644 --- a/beneuro_pose_estimation/evaluation.py +++ b/beneuro_pose_estimation/evaluation.py @@ -113,111 +113,32 @@ def load_2d_predictions_old(session_dir, cameras=params.default_cameras): logger.info(f"Final stacked shape: {stacked.shape}") return stacked -def calculate_reprojection_errors(session_name, test_dir, cameras=params.default_cameras): +def _print_error_stats(errors: np.ndarray, cameras: list): """ - Calculate reprojection errors between 3D predictions and 2D data. - Returns dict with per_camera, per_keypoint, and overall stats. + Helper to print mean per keypoint, per camera, and overall stats. """ - logger.info(f"Analyzing session: {session_name}") + # Stack: (n_cams, n_frames, n_kp) + mean_per_kp = np.nanmean(errors, axis=(0, 1)) # (n_kp,) + mean_per_cam = np.nanmean(errors, axis=(1, 2)) # (n_cams,) + flat = errors.flatten() + mean_all = float(np.nanmean(flat)) + median_all = float(np.nanmedian(flat)) + std_all = float(np.nanstd(flat)) - # 1) Load 2D predictions - preds2d = load_2d_predictions(session_name, test_dir, cameras) - if preds2d is None: - return None - - # 2) Load raw 3D predictions - points3d_raw = load_3d_predictions(session_name, test_dir) - if points3d_raw is None: - return None - - # Handle both 3-D and 4-D shapes - if points3d_raw.ndim == 3: - # (n_frames, n_kp, 3) → add track dim - n_frames, n_kp, _ = points3d_raw.shape - points3d_4d = points3d_raw[:, None, :, :] # → (n_frames,1,n_kp,3) - elif points3d_raw.ndim == 4: - n_frames, n_tracks, n_kp, _ = points3d_raw.shape - points3d_4d = points3d_raw - else: - raise ValueError(f"Unexpected points3d shape {points3d_raw.shape}") - # logger.info(f"points3d_4d shape: {points3d_4d.shape}") - - # We’ll compare using the first track: - points3d = points3d_4d[:, 0, :, :] # (n_frames, n_kp, 3) - - # 3) Load calibration - if session_name.split("_")[1] == "2025": - calib_file = config.calibration / "calibration_2025_03_12_11_45.toml" - else: - calib_file = aniposeTools.get_most_recent_calib(session_name) - if not calib_file or not calib_file.exists(): - logger.error(f"Missing calibration for {session_name}") - return None - - # 4) Reproject all tracks back into each camera - reproj_file = test_dir / f"{session_name}_reprojections.h5" - slap.reproject( - p3d=points3d_4d, # 4-D array as required - calib=str(calib_file), - frames=(0, n_frames), - fname=str(reproj_file), - ) - - # 5) Load per-camera reprojections - reproj = {} - with h5py.File(reproj_file, "r") as f: - for cam in cameras: - arr = f[cam][:] - # Squeeze any extra singleton dims beyond the last two - while arr.ndim > 3 and 1 in arr.shape[:-2]: - arr = arr.squeeze(arr.shape.index(1)) - # If shape is (n_kp, n_frames, 2), transpose - if arr.shape[0] == n_kp and arr.shape[1] == n_frames: - arr = arr.transpose(1, 0, 2) - assert arr.shape == (n_frames, n_kp, 2), f"Bad reproj shape for {cam}: {arr.shape}" - reproj[cam] = arr - # logger.info(f"Reproj[{cam}] shape: {arr.shape}") - - # 6) Compute per-camera errors - errors = {} - for i, cam in enumerate(cameras): - # preds2d[i]: (n_tracks,2,n_kp,n_frames) - pred2d = preds2d[i, 0] # (2,n_kp,n_frames) - pred_xy = np.transpose(pred2d, (2, 1, 0)) # → (n_frames,n_kp,2) - err = np.linalg.norm(reproj[cam] - pred_xy, axis=-1) # (n_frames,n_kp) - errors[cam] = err - # logger.info(f"errors[{cam}] shape: {err.shape}") - - # 7) Stack and summarize - all_err = np.stack(list(errors.values()), axis=0) # (n_cams,n_frames,n_kp) - mean_per_kp = np.nanmean(all_err, axis=(0, 1)) # (n_kp,) - mean_per_cam= np.nanmean(all_err, axis=(1, 2)) # (n_cams,) - flat = all_err.flatten() - mean_all = float(np.nanmean(flat)) - median_all = float(np.nanmedian(flat)) - std_all = float(np.nanstd(flat)) - - # 8) Log results print("Mean reprojection error per keypoint (px):") for idx, kp in enumerate(params.body_parts): print(f" {kp}: {mean_per_kp[idx]:.3f}") - print("Mean reprojection error per camera (px):") + print("\nMean reprojection error per camera (px):") for idx, cam in enumerate(cameras): print(f" {cam}: {mean_per_cam[idx]:.3f}") - print("Overall error stats (px):") + print("\nOverall error stats (px):") print(f" Mean : {mean_all:.3f}") print(f" Median : {median_all:.3f}") print(f" Std : {std_all:.3f}") - return { - "per_camera": mean_per_cam, - "per_keypoint": mean_per_kp, - "overall": {"mean": mean_all, "median": median_all, "std": std_all}, - } - -def get_reprojection_errors_array(session_name, test_dir, cameras=params.default_cameras): +def get_reprojection_errors(session_name, test_dir, cameras=params.default_cameras, print_stats = False): """ Returns a NumPy array of shape (n_cameras, n_frames, n_keypoints) containing reprojection errors for each camera, frame, and keypoint. @@ -226,50 +147,66 @@ def get_reprojection_errors_array(session_name, test_dir, cameras=params.default preds2d = load_2d_predictions(session_name, test_dir, cameras) points3d = load_3d_predictions(session_name, test_dir) - # Ensure a 4D array for triangulation: (n_frames, n_tracks, n_kp, 3) - if points3d.ndim == 3: - points3d_4d = points3d[:, None, :, :] - else: - points3d_4d = points3d - n_frames = points3d_4d.shape[0] + reproj_file = test_dir / f"{session_name}_reproj.h5" + errors_file = test_dir / f"{session_name}_reproj_errors.h5" - # Load calibration - if session_name.split("_")[1] == "2025": - calib_file = config.calibration / "calibration_2025_03_12_11_45.toml" + # 1) If errors already saved, load & return + if errors_file.exists(): + with h5py.File(errors_file, "r") as f_err: + # assume dataset "errors" of shape (n_cameras, n_frames, n_kp) + errors = f_err["errors"][:] else: - calib_file = aniposeTools.get_most_recent_calib(session_name) - - # Triangulate (reproject) - reproj_file = Path(test_dir) / f"{session_name}_reproj_temp.h5" - slap.reproject( - p3d=points3d_4d, - calib=str(calib_file), - frames=(0, n_frames), - fname=str(reproj_file), - ) - - # Load reprojections - reproj = [] - with h5py.File(reproj_file, "r") as f: - for cam in cameras: - arr = f[cam][:] - # Squeeze any extra singleton dims before last two - while arr.ndim > 3 and 1 in arr.shape[:-2]: - arr = arr.squeeze(axis=arr.shape.index(1)) - # If dims are (n_kp, n_frames, 2), transpose - if arr.shape[0] == len(params.body_parts) and arr.shape[1] == n_frames: - arr = arr.transpose(1, 0, 2) - reproj.append(arr) - - # Compute errors - errors = [] - for i, cam in enumerate(cameras): - pred2d = preds2d[i, 0] # (2, n_kp, n_frames) - pred_xy = np.transpose(pred2d, (2, 1, 0)) # (n_frames, n_kp, 2) - err = np.linalg.norm(reproj[i] - pred_xy, axis=-1) # (n_frames, n_kp) - errors.append(err) - - return np.stack(errors, axis=0) # (n_cameras, n_frames, n_kp) + # 2) Load data + preds2d = load_2d_predictions(session_name, test_dir, cameras) + points3d = load_3d_predictions(session_name, test_dir) + if points3d.ndim == 3: + points3d = points3d[:, None, :, :] + n_frames = points3d.shape[0] + + # 3) Select calibration + parts = session_name.split("_") + if parts[1] == "2025" and parts[2] != "01": + calib_file = config.calibration / "calibration_2025_03_12_11_45.toml" + else: + calib_file = aniposeTools.get_most_recent_calib(session_name) + + # 4) Ensure reprojections exist + if not reproj_file.exists(): + slap.reproject( + p3d=points3d, + calib=str(calib_file), + frames=(0, n_frames), + fname=str(reproj_file), + ) + + # 5) Load reprojections + reproj = [] + with h5py.File(reproj_file, "r") as f_reproj: + for cam in cameras: + arr = f_reproj[cam][:] + # squeeze extra singleton dims + while arr.ndim > 3 and 1 in arr.shape[:-2]: + arr = arr.squeeze(axis=arr.shape.index(1)) + # transpose if needed + if arr.shape[0] == len(params.body_parts) and arr.shape[1] == n_frames: + arr = arr.transpose(1, 0, 2) + reproj.append(arr) + + # 6) Compute reprojection errors + errors = [] + for i, cam in enumerate(cameras): + pred2d = preds2d[i, 0] # (2, n_kp, n_frames) + pred_xy = np.transpose(pred2d, (2, 1, 0)) # (n_frames, n_kp, 2) + err = np.linalg.norm(reproj[i] - pred_xy, axis=-1) + errors.append(err) + errors = np.stack(errors, axis=0) # (n_cameras, n_frames, n_kp) + + # 7) Save errors for next time + with h5py.File(errors_file, "w") as f_err: + f_err.create_dataset("errors", data=errors, compression="gzip") + if print_stats: + _print_error_stats(errors, cameras) + return errors import math @@ -288,7 +225,7 @@ def plot_reprojection_error_histograms(session_name, test_dir, bins=50): tuple: (fig_cam, fig_kp) Matplotlib Figure objects. """ # Load errors - all_errors = get_reprojection_errors_array(session_name, test_dir) + all_errors = get_reprojection_errors(session_name, test_dir) cameras = params.default_cameras keypoints = params.body_parts @@ -352,7 +289,7 @@ def plot_reprojection_error_per_camera(session_name, test_dir, bins=50): Overlay reprojection‐error histograms for each camera on one plot. """ # Get the errors array: shape (n_cameras, n_frames, n_keypoints) - errors = get_reprojection_errors_array(session_name, test_dir) + errors = get_reprojection_errors(session_name, test_dir) plt.figure(figsize=(8, 6)) for i, cam in enumerate(params.default_cameras): @@ -384,7 +321,7 @@ def plot_reprojection_error_per_keypoint(session_name, test_dir, bins=50): with a distinct color per keypoint. """ # Get the errors array: shape (n_cameras, n_frames, n_keypoints) - errors = get_reprojection_errors_array(session_name, test_dir) # (cams, frames, kps) + errors = get_reprojection_errors(session_name, test_dir) # (cams, frames, kps) n_kp = len(params.body_parts) # Choose a qualitative colormap with enough distinct colors @@ -437,7 +374,7 @@ def plot_keypoint_errors_by_camera(session_name, test_dir, camera, bins=50): matplotlib.figure.Figure: Figure with subplots per keypoint. """ # Get full error array: (n_cameras, n_frames, n_keypoints) - all_errors = get_reprojection_errors_array(session_name, test_dir) + all_errors = get_reprojection_errors(session_name, test_dir) cameras = params.default_cameras if camera not in cameras: @@ -935,8 +872,9 @@ def update(frame): anim.save(str(video_path), writer="ffmpeg", fps=fps) return anim + def plot_reprojection_errors(session_name, test_dir, bins=50): - all_errors = get_reprojection_errors_array(session_name, test_dir) + all_errors = get_reprojection_errors(session_name, test_dir) flat_errors = all_errors.flatten() flat_errors = flat_errors[~np.isnan(flat_errors)] @@ -947,6 +885,7 @@ def plot_reprojection_errors(session_name, test_dir, bins=50): plt.xlabel("Error (pixels)") plt.ylabel("Count") plt.show() + def create_3d_animation(points3d, session_dir, output_dir=None, fps=30, start_frame=None, end_frame=None): """ Create a simplified 3D animation of pose estimation results. @@ -1399,7 +1338,8 @@ def compute_keypoint_missing_frame_stats(csv_filepath, body_parts=None, sep=",") csv_filepath = Path(csv_filepath) df = pd.read_csv(csv_filepath, sep=sep) df.columns = df.columns.str.strip() - + n_frames = len(df) + print(f"Total frames: {n_frames}") if body_parts is None: body_parts = params.body_parts @@ -1904,24 +1844,6 @@ def plot_angle_histograms(csv_path, frame_start=0, frame_end=None, bins=50): plt.tight_layout() plt.show() -def plot_angles_timeseries(csv_path, frame_start=0, frame_end=None): - """ - Overlay all `_angle` time series over fnum (frame number). - """ - df = _load_angle_df(csv_path, frame_start, frame_end) - angle_cols = [c for c in df.columns if c.endswith("_angle")] - if "fnum" not in df.columns: - raise KeyError("CSV must contain 'fnum' column for frame numbers.") - plt.figure(figsize=(10,6)) - for col in angle_cols: - plt.plot(df["fnum"], df[col], label=col) - plt.title("Joint Angles over Time") - plt.xlabel("Frame number") - plt.ylabel("Angle (°)") - plt.legend(bbox_to_anchor=(1.02,1), loc="upper left") - plt.tight_layout() - plt.show() - def plot_bodypart_autocorr_spectrum( csv_path, @@ -2246,4 +2168,69 @@ def print_angle_stats_table(csv_path, frame_start=0, frame_end=None): """ stats_df = build_angle_stats_table(csv_path, frame_start, frame_end) # Print as Markdown for clearer console display - print(stats_df.to_markdown(index=False)) \ No newline at end of file + print(stats_df.to_markdown(index=False)) + + +def plot_angles( + csv_path, + fields=None, + frame_start=0, + frame_end=None, + figsize=(20, 5), + save_path=None +): + """ + Plot specified joint angles (or all '*_angle' columns) over a frame range from a CSV file. + + Args: + csv_path: Path to the CSV containing angle columns. + fields: List of column names to plot. If None, auto-selects all columns ending in '_angle'. + frame_start: First frame index (inclusive). + frame_end: One-past-last frame index. Defaults to end of CSV. + figsize: Figure size tuple for matplotlib. + save_path: If provided, save the figure to this path. + """ + path = Path(csv_path) + df = pd.read_csv(path) + + # Auto-select angle fields if not provided + if fields is None: + fields = [c for c in df.columns if c.endswith("_angle")] + if not fields: + raise ValueError("No columns ending with '_angle' found in CSV.") + + # Validate requested fields + missing = [f for f in fields if f not in df.columns] + if missing: + raise ValueError(f"Requested fields not in CSV: {missing}") + + # Determine frame_end + total = len(df) + if frame_end is None or frame_end > total: + frame_end = total + + df_range = df.iloc[frame_start:frame_end] + + # Prepare colors (at least as many as fields) + cmap = plt.get_cmap("tab20") + colors = [cmap(i % 20) for i in range(len(fields))] + + # Plot + plt.figure(figsize=figsize) + for idx, field in enumerate(fields): + plt.plot( + df_range.index, + df_range[field], + color=colors[idx], + label=field + ) + + plt.xlabel("Frame") + plt.ylabel("Angle (degrees)") + plt.legend(fontsize="x-small", ncol=2) + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + + plt.show() \ No newline at end of file diff --git a/beneuro_pose_estimation/params.py b/beneuro_pose_estimation/params.py index c9d7052..98bf684 100644 --- a/beneuro_pose_estimation/params.py +++ b/beneuro_pose_estimation/params.py @@ -134,13 +134,13 @@ "scale_smooth": 5, "scale_length": 4, "scale_length_weak": 1, - "reproj_error_threshold": 15, + "reproj_error_threshold": 5, "reproj_loss": "soft_l1", "n_deriv_smooth": 2, - "ransac": True + "ransac": False } -# "soft_l1" "l2" -# ransac + + frame_window = 1000 diff --git a/beneuro_pose_estimation/sleap/sleapTools.py b/beneuro_pose_estimation/sleap/sleapTools.py index 1505ad9..288aa7d 100644 --- a/beneuro_pose_estimation/sleap/sleapTools.py +++ b/beneuro_pose_estimation/sleap/sleapTools.py @@ -587,7 +587,7 @@ def train_models_old(cameras=params.default_cameras, sessions=None): logging.info("All training has been executed.") -def train_models(cameras=params.default_cameras): +def train_models(cameras=params.default_cameras, custom_labels = False): """ TBD - create config file with training parameters; check if config file exists, if not create it using the parameters in params @@ -618,7 +618,10 @@ def train_models(cameras=params.default_cameras): # Run sleap-train command logging.info(f"Training model for {camera}...") - command = ["sleap-train", config_file, labels_file] + if custom_labels: + command = ["sleap-train", str(config_file)] + else: + command = ["sleap-train", str(config_file), str(labels_file)] result = subprocess.run(command, cwd=str(config_file.parent)) if result.returncode == 0: @@ -753,24 +756,22 @@ def get_2Dpredictions( / f"{params.camera_name_mapping.get(camera, camera)}.avi" ) if custom_model_name is not None: - model_dir = config.custom_models / camera / f"{camera}_{test_name}" + + model_dir = config.custom_models / camera / f"{camera}_{custom_model_name}" + if not model_dir.exists(): logging.info( f"Custom Model directory for {camera} does not exist, looking for general model." ) - model_dir = config.models / camera - if not model_dir.exists(): - logging.info( - f"Model directory for {camera} does not exist, skipping." - ) - continue + model_dir = config.models / camera + else: model_dir = config.models / camera - if not model_dir.exists(): - logging.info( - f"Model directory for {camera} does not exist, skipping." - ) - continue + if not model_dir.exists(): + logging.info( + f"Model directory for {camera} does not exist, skipping." + ) + continue model_path = model_dir / "training_config.json" logging.info( diff --git a/beneuro_pose_estimation/tools.py b/beneuro_pose_estimation/tools.py new file mode 100644 index 0000000..d7221e0 --- /dev/null +++ b/beneuro_pose_estimation/tools.py @@ -0,0 +1,180 @@ +from pathlib import Path +import logging +import shutil +if not logging.getLogger().hasHandlers(): + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) + +from beneuro_pose_estimation import params, set_logging +from beneuro_pose_estimation.config import _load_config +import cv2 +config = _load_config() + +logger = set_logging(__name__) + +def copy_model_to_remote(test_folder_name: str): + """ + Copy a model test folder from local config.custom_models to remote config.training + + """ + src_root = config.custom_models + + # Find the folder under one of the camera subdirectories + src_dir = None + for cam in params.default_cameras: + candidate = src_root / cam / test_folder_name + if candidate.is_dir(): + src_dir = candidate + camera = cam + break + if src_dir is None: + raise FileNotFoundError(f"'{test_folder_name}' not found under any camera in {src_root}") + + # Build destination path: //models/ + dest_dir = config.training / camera / "models" / test_folder_name + + if dest_dir.exists(): + resp = input(f"Remote folder '{dest_dir}' already exists. Overwrite? (y/N): ").strip().lower() + if resp != "y": + logging.info("Aborted. Existing remote model not overwritten.") + return + shutil.rmtree(dest_dir) + logger.info(f"Deleted existing remote folder: {dest_dir}") + + # Perform recursive copy + shutil.copytree(src_dir, dest_dir) + logging.info(f"Copied '{src_dir}' → '{dest_dir}'.") + +def cleanup_intermediate_files(session: str): + """ + Clean up intermediate files for a session with interactive prompts. + + 1) Delete any '*triangulation*.h5' files under project_dir (including subfolders). + Uses shutil.os.remove for files. + 2) Prompt to delete the entire 'tests' directory under project_dir, uses shutil.rmtree. + """ + animal = session.split("_")[0] + project_dir = config.predictions3D / animal / session / "pose-estimation" + + if not project_dir.exists(): + logger.error(f"Project directory not found: {project_dir}") + return + + # 1) Clean up triangulation files (searching recursively) + triangulation_files = list(project_dir.glob("**/*triangulation*.h5")) + if triangulation_files: + resp = input(f"\nFound {len(triangulation_files)} triangulation file(s). Delete them? (y/N): ").strip().lower() + if resp == "y": + for fpath in triangulation_files: + try: + # shutil.os.remove is the same as os.remove + shutil.os.remove(fpath) + logger.info(f"Deleted triangulation file: {fpath}") + except Exception as e: + logger.error(f"Error deleting {fpath}: {e}") + else: + logger.info("Skipped deleting triangulation files.") + else: + logger.info("No triangulation files found.") + + # 2) Clean up 'tests' directory + tests_dir = project_dir / "tests" + if tests_dir.is_dir(): + subdirs = [p for p in tests_dir.iterdir() if p.is_dir()] + if subdirs: + print(f"\nFound {len(subdirs)} test folder(s) under '{tests_dir}':") + for sd in subdirs: + print(f" • {sd.name}") + resp = input("Delete the entire 'tests' directory and its contents? (y/N): ").strip().lower() + if resp == "y": + try: + shutil.rmtree(tests_dir) + logger.info(f"Deleted 'tests' directory: {tests_dir}") + except Exception as e: + logger.error(f"Failed to delete tests directory {tests_dir}: {e}") + else: + logger.info("Skipped deleting 'tests' directory.") + else: + logger.info(f"'tests' directory exists but has no subfolders: {tests_dir}") + else: + logger.info("No 'tests' directory found.") + + logger.info("Cleanup completed.") + + +def create_test_videos(session, cameras=params.default_cameras, duration_seconds=10, fps=100, + force_new=False, start_frame=None): + """ + Creates short test videos and corresponding metadata for each camera. + """ + animal = session.split("_")[0] + n_frames = duration_seconds * fps + + # Create output directory for test videos + test_dir = config.LOCAL_PATH /"raw" / animal / session / "pose-estimation"/ "tests" + + + cameras_dir = test_dir / f"{session}_cameras" + cameras_dir.mkdir(parents=True, exist_ok=True) + for camera in cameras: + try: + # Output video path + output_video = cameras_dir / f"{params.camera_name_mapping.get(camera, camera)}.avi" + + # Skip if video exists and force_new is False + if output_video.exists() and not force_new: + logger.info(f"Test video already exists for {camera}, skipping: {output_video}") + continue + + # Input video path + input_video = ( + config.recordings + / animal + / session + / f"{session}_cameras" + / f"{params.camera_name_mapping.get(camera, camera)}.avi" + ) + + if not input_video.exists(): + logger.warning(f"Input video not found: {input_video}") + continue + + + # Create video + cap = cv2.VideoCapture(str(input_video)) + if not cap.isOpened(): + logger.error(f"Could not open video: {input_video}") + continue + + # Get video properties + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # Create video writer + fourcc = cv2.VideoWriter_fourcc(*'XVID') + out = cv2.VideoWriter(str(output_video), fourcc, fps, (width, height)) + + # Set start frame if specified + if start_frame is not None: + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + # Read and write frames + frame_count = 0 + while frame_count < n_frames: + ret, frame = cap.read() + if not ret: + break + out.write(frame) + frame_count += 1 + + # Release resources + cap.release() + out.release() + + logger.info(f"Created test video for {camera}: {output_video}") + + except Exception as e: + logger.error(f"Error processing camera {camera}: {e}") + + return test_dir \ No newline at end of file