Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 41 additions & 130 deletions beneuro_pose_estimation/anipose/aniposeTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import matplotlib as plt

import beneuro_pose_estimation.sleap.sleapTools as sleapTools
import beneuro_pose_estimation.tools as tools
from beneuro_pose_estimation import params

if not logging.getLogger().hasHandlers():
Expand Down Expand Up @@ -693,9 +694,7 @@ def run_pose_estimation(

# get calibration file - toml file saved in calibration/calibration.toml
if session.split("_")[1] == "2025": # TODO: set the condition so that sessions after 3rd of february 2025 use this
recent_calib_folder = config.REMOTE_PATH/ "raw"/ "pose-estimation"/ "calibration-videos"/ "camera_calibration_2025_03_12_11_45" / "Recording_2025-03-12T114830"
calib_file_path = config.calibration/"calibration_2025_03_12_11_45.toml"
# get_calib_file(calib_videos_dir=recent_calib_folder, calib_save_path=calib_file_path)
else:
calib_file_path = get_most_recent_calib(session)

Expand Down Expand Up @@ -726,83 +725,28 @@ def run_pose_estimation(
# Save the updated CSV
combined_data.to_csv(combined_csv, index=False)
logging.info(f"Angles computed and combined CSV saved at {combined_csv}.")
logging.info(f"Pose estimation completed for {session}.")

def create_test_videos(session, cameras=params.default_cameras, duration_seconds=10, fps=100,
force_new=False, start_frame=None):
"""
Creates short test videos and corresponding metadata for each camera.
"""
animal = session.split("_")[0]
n_frames = duration_seconds * fps

# Create output directory for test videos
test_dir = config.LOCAL_PATH /"raw" / animal / session / "pose-estimation"/ "tests"


cameras_dir = test_dir / f"{session}_cameras"
cameras_dir.mkdir(parents=True, exist_ok=True)
for camera in cameras:
try:
# Output video path
output_video = cameras_dir / f"{params.camera_name_mapping.get(camera, camera)}.avi"

# Skip if video exists and force_new is False
if output_video.exists() and not force_new:
logger.info(f"Test video already exists for {camera}, skipping: {output_video}")
continue

# Input video path
input_video = (
config.recordings
/ animal
/ session
/ f"{session}_cameras"
/ f"{params.camera_name_mapping.get(camera, camera)}.avi"
)

if not input_video.exists():
logger.warning(f"Input video not found: {input_video}")
continue


# Create video
cap = cv2.VideoCapture(str(input_video))
if not cap.isOpened():
logger.error(f"Could not open video: {input_video}")
continue

# Get video properties
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Create video writer
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(str(output_video), fourcc, fps, (width, height))

# Set start frame if specified
if start_frame is not None:
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

# Read and write frames
frame_count = 0
while frame_count < n_frames:
ret, frame = cap.read()
if not ret:
break
out.write(frame)
frame_count += 1

# Release resources
cap.release()
out.release()

logger.info(f"Created test video for {camera}: {output_video}")

if labels_fname.exists():
labels_fname.unlink()
logger.info(f"Deleted intermediate CSV: {labels_fname.name}")
if angles_csv.exists():
angles_csv.unlink()
logger.info(f"Deleted intermediate CSV: {angles_csv.name}")
except Exception as e:
logger.error(f"Error processing camera {camera}: {e}")

return test_dir
logger.error(f"Error deleting intermediate CSVs: {e}")


triangulation_files = list(predictions_dir.glob("**/*triangulation*.h5"))
if triangulation_files:
for tri_file in triangulation_files:
try:
tri_file.unlink()
logger.info(f"Deleted: {tri_file}")
except Exception as e:
logger.error(f"Error deleting {tri_file}: {e}")
logging.info(f"Pose estimation completed for {session}.")



def run_pose_test(session, test_name = None, cameras=params.default_cameras, force_new_videos=False,
start_frame=None, duration_seconds=10):
Expand All @@ -812,7 +756,7 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for
try:
# 1. Create test videos
logger.info("Creating test videos...")
tests_dir = create_test_videos(session, cameras, duration_seconds,
tests_dir = tools.create_test_videos(session, cameras, duration_seconds,
force_new=force_new_videos, start_frame=start_frame)
test_dir = tests_dir / test_name
if test_name is None:
Expand All @@ -827,10 +771,8 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for



if session.split("_")[1] == "2025": # TODO: set the condition so that sessions after 3rd of february 2025 use this
recent_calib_folder = config.REMOTE_PATH/ "raw"/ "pose-estimation"/ "calibration-videos"/ "camera_calibration_2025_03_12_11_45" / "Recording_2025-03-12T114830"
if session.split("_")[1] == "2025" and session.split("_")[2] != "01": # TODO: set the condition so that sessions after 3rd of february 2025 use this
calib_file_path = config.calibration/"calibration_2025_03_12_11_45.toml"
# get_calib_file(calib_videos_dir=recent_calib_folder, calib_save_path=calib_file_path)
else:
calib_file_path = get_most_recent_calib(session)

Expand All @@ -848,8 +790,8 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for
config_path = create_config_file(config_path)
config_angles = toml.load(config_path)
angles_csv = test_dir/f"{session}_angles.csv"
# labels_data = pd.read_csv(labels_xfname)
# logging.debug(labels_data.columns)


compute_angles(config_angles, labels_fname, angles_csv)

pose_data = pd.read_csv(labels_fname)
Expand All @@ -861,57 +803,26 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for
# Save the updated CSV
combined_data.to_csv(combined_csv, index=False)
logging.info(f"Angles computed and combined CSV saved at {combined_csv}.")
logging.info(f"Pose estimation completed for {session}.")



except Exception as e:
logger.error(f"Error in pose test for {session}: {e}")
raise


def cleanup_intermediate_files(session: str, include_slp: bool = False):
"""
Clean up intermediate files for a session with interactive prompts.

Args:
session: Session name to clean up
include_slp: Whether to include .slp files in cleanup
"""
animal = session.split("_")[0]
project_dir = config.predictions3D / animal / session / "pose-estimation"

if not project_dir.exists():
logger.error(f"Project directory not found: {project_dir}")
return
try:
if labels_fname.exists():
labels_fname.unlink()
logger.info(f"Deleted intermediate CSV: {labels_fname.name}")
if angles_csv.exists():
angles_csv.unlink()
logger.info(f"Deleted intermediate CSV: {angles_csv.name}")
except Exception as e:
logger.error(f"Error deleting intermediate CSVs: {e}")

# 1. Clean up 2D prediction .slp files (only if include_slp is True)
if include_slp:
slp_files = list(project_dir.glob("**/*.slp"))
if slp_files:
response = input(f"\nFound {len(slp_files)} .slp files. Do you want to delete them? (y/n): ").lower()
if response == 'y':
for slp_file in slp_files:
try:
slp_file.unlink()
logger.info(f"Deleted: {slp_file}")
except Exception as e:
logger.error(f"Error deleting {slp_file}: {e}")
else:
logger.info("Skipping .slp files deletion")

# 2. Clean up triangulation files
triangulation_files = list(project_dir.glob("**/*triangulation*.h5"))
if triangulation_files:
response = input(f"\nFound {len(triangulation_files)} triangulation files. Do you want to delete them? (y/n): ").lower()
if response == 'y':
triangulation_files = list(tests_dir.glob("**/*triangulation*.h5"))
if triangulation_files:
for tri_file in triangulation_files:
try:
tri_file.unlink()
logger.info(f"Deleted: {tri_file}")
except Exception as e:
logger.error(f"Error deleting {tri_file}: {e}")
else:
logger.info("Skipping triangulation files deletion")

logger.info("Cleanup completed")
logging.info(f"Pose estimation completed for {session}.")

except Exception as e:
logger.error(f"Error in pose test for {session}: {e}")
raise
28 changes: 19 additions & 9 deletions beneuro_pose_estimation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,35 @@ def create_training_project(
@app.command()
def cleanup(
session: str = typer.Argument(..., help="Session name to clean up intermediate files for."),
slp: bool = typer.Option(
False,
"--slp", "-s",
help="Also clean up .slp files (2D prediction files)."
)

):
"""
Clean up intermediate files for a session.
By default, only asks about cleaning up triangulation files.
Use --slp flag to also clean up 2D prediction .slp files.
"""
from beneuro_pose_estimation.anipose.aniposeTools import cleanup_intermediate_files
from beneuro_pose_estimation.tools import cleanup_intermediate_files

cleanup_intermediate_files(session, include_slp=slp)
cleanup_intermediate_files(session)

return



@app.command()
def model_up(
test_folder_name: str = typer.Argument(
...,
help="Name of the test folder under config.models/<camera>/ to copy to remote_models"
)
):
"""
Recursively copy a model test folder from your local models path to the remote_models path,
prompting if it already exists.
"""
from beneuro_pose_estimation.tools import copy_model_to_remote
copy_model_to_remote(test_folder_name)

@app.command()
def eval_report(
session_name: str = typer.Argument(..., help="Session name to evaluate"),
Expand Down Expand Up @@ -175,7 +185,7 @@ def pose_test(
start_frame: Optional[int] = typer.Option(
None,
"--start-frame", "-s",
help="Frame number to start from. If not specified, uses first 100 frames."
help="Frame number to start from. If not specified, uses frame 0."
),
duration: Optional[int] = typer.Option(
10,
Expand Down Expand Up @@ -220,7 +230,7 @@ def train(
from beneuro_pose_estimation.sleap.sleapTools import train_models
cams = cameras or params.default_cameras
# train_models is the function you already wrote
train_models(cams)
train_models(cam, custom_labels = custom_labels)

# =================================== Updating ==========================================

Expand Down
13 changes: 6 additions & 7 deletions beneuro_pose_estimation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,23 @@ def assign_paths(self):
self.recordings_remote = self.REMOTE_PATH / "raw"
self.annotation_party = self.REMOTE_PATH / "processed" / "AnnotationParty"
self.annotations = self.annotation_party / "annotations"
# self.models = (
# self.REMOTE_PATH / "raw" / "pose-estimation" / "models" / "h1_new_setup"
# )

self.models_local = self.LOCAL_PATH / "raw" / "pose-estimation" / "models"
self.models_remote = self.annotation_party / "models"
self.custom_models = self.models_local

self.models = self.models_remote
self.skeleton_path = (
self.REPO_PATH / "beneuro_pose_estimation" / "sleap" / "skeleton.json"
)
self.recordings = self.annotation_party # self.LOCAL_PATH / "raw"
self.recordings = self.LOCAL_PATH / "raw"
# self.recordings = self.annotation_party
self.predictions2D = self.LOCAL_PATH / "raw" # change to 'processed'?
self.training = self.annotation_party / "training"
self.training_config = self.REPO_PATH / "beneuro_pose_estimation"/ "training_config.json"
self.predictions3D = self.LOCAL_PATH / "raw" # change to 'processed'
# self.predictions3D = self.LOCAL_PATH / "raw"
self.predictions3D = self.LOCAL_PATH / "raw" # change to 'processed'?
self.calibration_videos = self.REMOTE_PATH / "raw" / "calibration_videos"
self.calibration = self.LOCAL_PATH / "raw"/ "pose_estimation"/ "calibration_config"
self.calibration = self.LOCAL_PATH / "raw"/ "pose-estimation"/ "calibration_config"
self.angles_config = self.REPO_PATH / "beneuro_pose_estimation"
return

Expand Down
Loading
Loading