In order to label new videos with the discovered prototypes, I am trying to follow the recommended approach: train a LISBET classifier on the prototypes.

The first step is to prepare a labeled dataset with prototype annotations. I have tried to adapt the suggested example Python snippet so that it patches my new DLC-labeled dataset (directory `A1Suppression_0-100`) with the prototype labels (directory `prototypes`):
```python
import os

import numpy as np
import pandas as pd
import xarray as xr
from lisbet.datasets import dump_records, load_records


def extract_labels(csv_path):
    df = pd.read_csv(csv_path, index_col=0)

    # Rows that already have at least one positive label
    covered = df.eq(1).any(axis=1)

    # Create / update the fallback class
    df["Other"] = (~covered).astype(int)

    # Keep only the first 1 in every row
    first_mask = df.eq(1).cumsum(axis=1).eq(1)

    # Apply the mask: everything that isn't the first 1 becomes 0
    df &= first_mask

    return df.values


def patch_dataset():
    records = load_records(
        data_format="DLC",
        data_path="A1suppression_0-100",
        data_scale="None",
        data_filter="train",
    )["main_records"]

    patched_records = []
    for key, data in records:
        posetracks = data["posetracks"].unstack("features")

        # os.path.join avoids backslashes inside a normal string literal,
        # where "\{" and "\m" are not valid escape sequences
        csv_path = os.path.join("prototypes", key, "machineAnnotation_hmmbest_6_32.csv")
        labels = extract_labels(csv_path)
        assert labels.shape[0] == posetracks.sizes["time"]

        # Convert to xarray Dataset
        annotations = xr.Dataset(
            data_vars=dict(
                label=(
                    ["time", "behaviors", "annotators"],
                    labels[..., np.newaxis],
                )
            ),
            coords=dict(
                time=posetracks.time,
                # Note: the last column is the "Other" fallback class
                behaviors=[f"motif_{motif_id}" for motif_id in range(labels.shape[1])],
                annotators=["LISBET"],
            ),
            attrs=dict(
                source_software=posetracks.source_software,
                ds_type="annotations",
                fps=posetracks.fps,
                time_unit=posetracks.time_unit,
            ),
        )

        patched_record = (
            key,
            {"posetracks": posetracks, "annotations": annotations},
        )
        patched_records.append(patched_record)

    dump_records(os.path.join("datasets", "proto_A1suppression_0-100"), patched_records)


if __name__ == "__main__":
    patch_dataset()
```
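To see what `extract_labels` does independently of LISBET, the sketch below runs the same pandas steps on a made-up three-frame CSV. The helper name `extract_labels_demo`, the column names `motif_a`/`motif_b` and the values are hypothetical, and the masking step is written as `df.where(first_mask, 0)` instead of `df &= first_mask` so the result stays integer.

```python
import io

import pandas as pd


def extract_labels_demo(csv_file):
    """Mirrors the logic of extract_labels on any CSV path or text buffer."""
    df = pd.read_csv(csv_file, index_col=0)
    # Frames with at least one positive motif label
    covered = df.eq(1).any(axis=1)
    # Fallback class for frames without any motif
    df["Other"] = (~covered).astype(int)
    # True from the first 1 in each row onwards, until a second 1 appears
    first_mask = df.eq(1).cumsum(axis=1).eq(1)
    # Zero out everything except the first positive label per frame
    return df.where(first_mask, 0).values


# Hypothetical prototype annotations: frame index plus two motif columns
toy_csv = io.StringIO(
    "frame,motif_a,motif_b\n"
    "0,1,0\n"
    "1,1,1\n"
    "2,0,0\n"
)
labels = extract_labels_demo(toy_csv)
print(labels)
```

Frame 1 has both motifs set, so only `motif_a` survives; frame 2 has none, so it is assigned to `Other`. The printed matrix should therefore be `[[1 0 0] [1 0 0] [0 0 1]]`, with exactly one positive per row.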
The first error I get is: `cannot import name 'dump_records' from 'lisbet.datasets'`

I expect more errors to follow, because I am not confident I have adapted the code correctly.
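One way to narrow down the `dump_records` import error is to check which LISBET version is installed and which names `lisbet.datasets` actually exports in that version, since the example snippet may have been written against a different release. This is a minimal sketch using only the standard library; the helper names `public_names` and `installed_version` are hypothetical, and it assumes the package is installed under the distribution name `lisbet`.

```python
import importlib
from importlib import metadata


def public_names(module_name):
    """Return the sorted public attribute names of an importable module."""
    module = importlib.import_module(module_name)
    return sorted(name for name in dir(module) if not name.startswith("_"))


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version("lisbet")
    print("lisbet version:", version)
    if version is not None:
        # Show the dump/load-style helpers this release actually exports
        names = public_names("lisbet.datasets")
        print([n for n in names if "dump" in n or "load" in n])
```

If `dump_records` is missing from the list, the installed release probably names or places that function differently, and the LISBET documentation or changelog for that version would be the place to check.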
Once this works, I think the next steps are to train the classifier:
```
betman train_model ^
    --run_id=proto_classifier ^
    --data_format=DLC ^
    --data_scale="1x1" ^
    --data_filter=train ^
    --learning_rate=1e-4 ^
    --epochs=10 ^
    --load_backbone_weights=models\A1supp01-embedder\weights\weights_last.pt ^
    --freeze_backbone_weights ^
    --save_history ^
    -v ^
    datasets\proto_A1suppression_0-100
```
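Before launching `train_model`, it may help to sanity-check each patched label array. The patch script is designed so that every frame carries exactly one positive behavior (the first motif, or the `Other` fallback), and the number of label frames should match the pose tracks. The helper `check_labels` below is hypothetical and assumes the `(time, behaviors, annotators)` layout used in the patch script.

```python
import numpy as np


def check_labels(labels, n_frames):
    """Return a list of problems found in a (time, behaviors, annotators) label array."""
    problems = []
    if labels.ndim != 3:
        problems.append(f"expected 3 dimensions, got {labels.ndim}")
        return problems
    if labels.shape[0] != n_frames:
        problems.append(f"{labels.shape[0]} label frames vs {n_frames} pose frames")
    if not np.isin(labels, (0, 1)).all():
        problems.append("labels contain values other than 0 and 1")
    # Number of positive behaviors per frame and annotator
    per_frame = labels.sum(axis=1)
    bad = np.flatnonzero((per_frame != 1).any(axis=-1))
    if bad.size:
        problems.append(f"{bad.size} frames without exactly one label, e.g. frame {bad[0]}")
    return problems


good = np.array([[1, 0, 0], [1, 0, 0], [0, 0, 1]])[..., np.newaxis]
print(check_labels(good, n_frames=3))  # → []
```

Inside the loop of `patch_dataset`, a call such as `check_labels(labels[..., np.newaxis], posetracks.sizes["time"])` would flag a bad record before it is written.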
Finally, annotate the new data:
```
betman annotate_behavior ^
    --data_format=DLC ^
    --data_scale="None" ^
    --data_filter=test ^
    -v ^
    A1suppression\task1_classic_classification ^
    models\proto_classifier\model_config.yml ^
    models\proto_classifier\weights\weights_last.pt
```
There is a lot here I don't understand. Any help would be appreciated.