Training and inference on an example dataset¶
In this notebook we'll install sleap-nn, download a sample dataset, run training and inference on that dataset using sleap-nn, and then download the predictions.
Note: If you only need to perform training and inference (and not use the full SLEAP GUI or labeling tools), you don't need to install the entire sleap package. You can just install sleap-nn, which is a lighter-weight package focused on model training and inference.
Install sleap-nn¶
!pip install -qqq "sleap-nn[torch-cpu]"
# if you have GPU (in colab, enable GPU runtime)
# !pip install -qqq "sleap-nn[torch-cuda-128]"
Download sample training data into Colab¶
Let's download a sample dataset from the SLEAP sample datasets repository into Colab.
!apt-get install tree
!wget -O dataset.zip https://github.com/talmolab/sleap-datasets/releases/download/dm-courtship-v1/drosophila-melanogaster-courtship.zip
!mkdir dataset
!unzip dataset.zip -d dataset
!rm dataset.zip
!tree dataset
--2025-09-23 22:49:59--  https://github.com/talmolab/sleap-datasets/releases/download/dm-courtship-v1/drosophila-melanogaster-courtship.zip
Resolving github.com (github.com)... 140.82.116.4
Connecting to github.com (github.com)|140.82.116.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
...
HTTP request sent, awaiting response... 200 OK
Length: 111973079 (107M) [application/octet-stream]
Saving to: ‘dataset.zip’
2025-09-23 22:50:02 (39.0 MB/s) - ‘dataset.zip’ saved [111973079/111973079]
Archive:  dataset.zip
   creating: dataset/drosophila-melanogaster-courtship
  inflating: dataset/drosophila-melanogaster-courtship/20190128_113421.mp4
  inflating: dataset/drosophila-melanogaster-courtship/courtship_labels.slp
  inflating: dataset/drosophila-melanogaster-courtship/example.jpg
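To confirm the extraction worked before moving on, you can run a quick sanity check in Python. This is a hypothetical helper (not part of sleap-nn); the filenames come from the archive listing above:

```python
from pathlib import Path

def missing_dataset_files(root: str) -> list[str]:
    """Return the expected dataset files that are missing under `root`."""
    expected = [
        "20190128_113421.mp4",   # raw video
        "courtship_labels.slp",  # SLEAP labels file
        "example.jpg",           # example frame
    ]
    base = Path(root) / "drosophila-melanogaster-courtship"
    return [name for name in expected if not (base / name).exists()]

# After unzipping into `dataset/`, this should return an empty list:
# missing_dataset_files("dataset")
```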
Train models¶
For the top-down pipeline, we'll need to train two models: a centroid model and a centered-instance model.
We'll first train a model for centroids using the default training profile. The training profile determines the model architecture, the learning rate, and other parameters.
When you start training, you'll first see the training parameters and then the training and validation loss for each training epoch.
As soon as you're satisfied with the validation loss you see for an epoch during training, you're welcome to stop training by clicking the stop button. The version of the model with the lowest validation loss is saved during training, and that's what will be used for inference.
If you don't stop training, it will run for 200 epochs or until validation loss fails to improve for some number of epochs (controlled by the early_stopping fields in the training profile).
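The knobs that control this live in the `trainer_config` section of the training profile; with the defaults used in this notebook (see the config dump printed at the start of training), they look like this:

```yaml
trainer_config:
  max_epochs: 200
  early_stopping:
    min_delta: 1.0e-08            # smallest val-loss improvement that counts
    patience: 10                  # epochs without improvement before stopping
    stop_training_on_plateau: true
```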
# Let's get the default config files
!wget -O baseline.centroid.yaml https://raw.githubusercontent.com/talmolab/sleap-nn/refs/heads/main/docs/sample_configs/config_centroid_unet.yaml
!wget -O baseline.centered_instance.yaml https://raw.githubusercontent.com/talmolab/sleap-nn/refs/heads/main/docs/sample_configs/config_topdown_centered_instance_unet_medium_rf.yaml
--2025-09-23 22:50:13--  https://raw.githubusercontent.com/talmolab/sleap-nn/refs/heads/main/docs/sample_configs/config_centroid_unet.yaml
2025-09-23 22:50:13 (33.0 MB/s) - ‘baseline.centroid.yaml’ saved [3286/3286]
--2025-09-23 22:50:14--  https://raw.githubusercontent.com/talmolab/sleap-nn/refs/heads/main/docs/sample_configs/config_topdown_centered_instance_unet.yaml
2025-09-23 22:50:14 (32.3 MB/s) - ‘baseline.centered_instance.yaml’ saved [3114/3114]
!sleap-nn train --config-name baseline.centroid --config-dir . "data_config.train_labels_path=[dataset/drosophila-melanogaster-courtship/courtship_labels.slp]" trainer_config.ckpt_dir="models" trainer_config.run_name="courtship.centroid" trainer_config.max_epochs=200
2025-09-23 22:51:31 | INFO | sleap_nn.cli:train:98 | Input config:
2025-09-23 22:51:31 | INFO | sleap_nn.cli:train:99 |
data_config:
  train_labels_path:
  - dataset/drosophila-melanogaster-courtship/courtship_labels.slp
  val_labels_path: null
  validation_fraction: 0.1
  test_file_path: null
  provider: LabelsReader
  user_instances_only: true
  data_pipeline_fw: torch_dataset
  cache_img_path: null
  use_existing_imgs: false
  delete_cache_imgs_after_training: true
  preprocessing:
    ensure_rgb: false
    ensure_grayscale: false
    max_height: null
    max_width: null
    scale: 0.5
    crop_size: null
    min_crop_size: 100
  use_augmentations_train: true
  augmentation_config:
    intensity:
      uniform_noise_min: 0.0
      uniform_noise_max: 1.0
      uniform_noise_p: 0.0
      gaussian_noise_mean: 5.0
      gaussian_noise_std: 1.0
      gaussian_noise_p: 0.0
      contrast_min: 0.9
      contrast_max: 1.1
      contrast_p: 0.0
      brightness_min: 0.9
      brightness_max: 1.1
      brightness_p: 0.0
    geometric:
      rotation_min: -15.0
      rotation_max: 15.0
      scale_min: 1.0
      scale_max: 1.0
      translate_width: 0.0
      translate_height: 0.0
      affine_p: 1.0
      erase_scale_min: 0.0001
      erase_scale_max: 0.01
      erase_ratio_min: 1.0
      erase_ratio_max: 1.0
      erase_p: 0.0
      mixup_lambda_min: 0.01
      mixup_lambda_max: 0.05
      mixup_p: 0.0
  skeletons: null
model_config:
  init_weights: xavier
  pretrained_backbone_weights: null
  pretrained_head_weights: null
  backbone_config:
    unet:
      in_channels: 1
      kernel_size: 3
      filters: 16
      filters_rate: 2.0
      max_stride: 16
      stem_stride: null
      middle_block: true
      up_interpolate: true
      stacks: 1
      convs_per_block: 2
      output_stride: 2
    convnext: null
    swint: null
  head_configs:
    single_instance: null
    centroid:
      confmaps:
        sigma: 1.5
        output_stride: 2
        anchor_part: null
    centered_instance: null
    bottomup: null
  total_params: null
trainer_config:
  train_data_loader:
    batch_size: 4
    shuffle: true
    num_workers: 0
  val_data_loader:
    batch_size: 4
    shuffle: false
    num_workers: 0
  model_ckpt:
    save_top_k: 1
    save_last: false
  trainer_devices: null
  trainer_device_indices: null
  trainer_accelerator: auto
  profiler: null
  trainer_strategy: auto
  enable_progress_bar: true
  min_train_steps_per_epoch: 200
  train_steps_per_epoch: null
  visualize_preds_during_training: true
  keep_viz: false
  max_epochs: 200
  seed: null
  use_wandb: false
  save_ckpt: true
  ckpt_dir: models
  run_name: courtship.centroid
  resume_ckpt_path: null
  wandb:
    entity: null
    project: project_name
    name: run_name
    save_viz_imgs_wandb: false
    wandb_mode: ''
    api_key: ''
    prv_runid: null
    group: null
  optimizer_name: Adam
  optimizer:
    lr: 0.0001
    amsgrad: false
  lr_scheduler:
    step_lr: null
    reduce_lr_on_plateau:
      threshold: 1.0e-06
      threshold_mode: rel
      cooldown: 3
      patience: 5
      factor: 0.5
      min_lr: 1.0e-08
  early_stopping:
    min_delta: 1.0e-08
    patience: 10
    stop_training_on_plateau: true
  online_hard_keypoint_mining:
    online_mining: false
    hard_to_easy_ratio: 2.0
    min_hard_keypoints: 2
    max_hard_keypoints: null
    loss_scale: 5.0
  zmq:
    publish_port: null
    controller_port: null
    controller_polling_timeout: 10
name: ''
description: ''
sleap_nn_version: 0.0.1
filename: ''
2025-09-23 22:51:31 | INFO | sleap_nn.train:run_training:27 | Started training at: 2025-09-23 22:51:31.617057
2025-09-23 22:51:31 | INFO | sleap_nn.training.model_trainer:_setup_train_val_labels:216 | Creating train-val split...
2025-09-23 22:51:31 | INFO | sleap_nn.training.model_trainer:_setup_train_val_labels:261 | # Train Labeled frames: 134
2025-09-23 22:51:31 | INFO | sleap_nn.training.model_trainer:_setup_train_val_labels:262 | # Val Labeled frames: 14
2025-09-23 22:51:31 | INFO | sleap_nn.training.model_trainer:setup_config:512 | Setting up config...
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:_verify_model_input_channels:417 | Updating backbone in_channels to 3 based on the input image channels.
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:849 | Setting up for training...
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:_setup_model_ckpt_dir:575 | Setting up model ckpt dir: `models/courtship.centroid`...
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:864 | Setting up visualization train and val datasets...
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:868 | Setting up Trainer...
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:_setup_loggers_callbacks:647 | Setting up callbacks and loggers...
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:897 | Trainer devices: auto
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:950 | Training on 1 device(s)
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:951 | Training on mps:0 accelerator
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:955 | Setting up lightning module for centroid model...
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:959 | Backbone model: UNet(
  (encoders): ModuleList(
    (0): Encoder(
      (encoder_stack): ModuleList(
        (0): SimpleConvBlock(
          (blocks): Sequential(
            (stack0_enc0_conv0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc0_act0_relu): ReLU()
            (stack0_enc0_conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc0_act1_relu): ReLU()
          )
        )
        (1): SimpleConvBlock(
          (blocks): Sequential(
            (stack0_enc1_pool): MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding=same, dilation=1, ceil_mode=False)
            (stack0_enc1_conv0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc1_act0_relu): ReLU()
            (stack0_enc1_conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc1_act1_relu): ReLU()
          )
        )
        (2): SimpleConvBlock(
          (blocks): Sequential(
            (stack0_enc2_pool): MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding=same, dilation=1, ceil_mode=False)
            (stack0_enc2_conv0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc2_act0_relu): ReLU()
            (stack0_enc2_conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc2_act1_relu): ReLU()
          )
        )
        (3): SimpleConvBlock(
          (blocks): Sequential(
            (stack0_enc3_pool): MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding=same, dilation=1, ceil_mode=False)
            (stack0_enc3_conv0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc3_act0_relu): ReLU()
            (stack0_enc3_conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc3_act1_relu): ReLU()
          )
        )
        (4): Sequential(
          (stack0_enc4_last_pool): MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding=same, dilation=1, ceil_mode=False)
        )
      )
    )
  )
  (decoders): ModuleList(
    (0): Decoder(
      (decoder_stack): ModuleList(
        (0): SimpleUpsamplingBlock(
          (blocks): Sequential(
            (stack0_dec0_s16_to_s8_interp_bilinear): Upsample(scale_factor=2.0, mode='bilinear')
            (stack0_dec0_s16_to_s8_refine_conv0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec0_s16_to_s8_refine_conv0_act_relu): ReLU()
            (stack0_dec0_s16_to_s8_refine_conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec0_s16_to_s8_refine_conv1_act_relu): ReLU()
          )
        )
        (1): SimpleUpsamplingBlock(
          (blocks): Sequential(
            (stack0_dec1_s8_to_s4_interp_bilinear): Upsample(scale_factor=2.0, mode='bilinear')
            (stack0_dec1_s8_to_s4_refine_conv0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec1_s8_to_s4_refine_conv0_act_relu): ReLU()
            (stack0_dec1_s8_to_s4_refine_conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec1_s8_to_s4_refine_conv1_act_relu): ReLU()
          )
        )
        (2): SimpleUpsamplingBlock(
          (blocks): Sequential(
            (stack0_dec2_s4_to_s2_interp_bilinear): Upsample(scale_factor=2.0, mode='bilinear')
            (stack0_dec2_s4_to_s2_refine_conv0): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec2_s4_to_s2_refine_conv0_act_relu): ReLU()
            (stack0_dec2_s4_to_s2_refine_conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec2_s4_to_s2_refine_conv1_act_relu): ReLU()
          )
        )
      )
    )
  )
  (middle_blocks): ModuleList(
    (0): SimpleConvBlock(
      (blocks): Sequential(
        (stack0_enc5_middle_expand_conv0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (stack0_enc5_middle_expand_act0_relu): ReLU()
      )
    )
    (1): SimpleConvBlock(
      (blocks): Sequential(
        (stack0_enc6_middle_contract_conv0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (stack0_enc6_middle_contract_act0_relu): ReLU()
      )
    )
  )
)
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:960 | Head model: ModuleList(
  (0): Sequential(
    (CentroidConfmapsHead): Sequential(
      (0): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (1): Identity()
    )
  )
)
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:962 | Total model parameters: 1953393
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:967 | Input image shape: torch.Size([1, 3, 512, 512])
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py:684: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.
warnings.warn(warn_msg)
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:1021 | Finished trainer set up. [0.3s]
2025-09-23 22:51:32 | INFO | sleap_nn.training.model_trainer:train:1024 | Starting training loop...
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/divyasesh/Desktop/talmolab/sleap-core-docs/docs/notebooks/models/courtship.centroid exists and is not empty.
| Name | Type | Params | Mode
------------------------------------------------------------
0 | model | Model | 2.0 M | train
1 | centroid_inf_layer | CentroidCrop | 0 | train
------------------------------------------------------------
2.0 M Trainable params
0 Non-trainable params
2.0 M Total params
7.814 Total estimated model params size (MB)
73 Modules in train mode
0 Modules in eval mode
Sanity Checking: | | 0/? [00:00<?, ?it/s]
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=7` in the `DataLoader` to improve performance.
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('learning_rate', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('val_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
Sanity Checking DataLoader 0: 100%|███████████████| 2/2 [00:00<00:00, 2.16it/s]
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=7` in the `DataLoader` to improve performance.
Epoch 0: 0%| | 0/200 [00:00<?, ?it/s]
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('train_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
Epoch 0: 100%|███████| 200/200 [01:48<00:00, 1.84it/s, train_loss_step=8.05e-5]
Epoch 0: 100%|█| 200/200 [01:51<00:00, 1.79it/s, train_loss_step=8.05e-5, learn
Epoch 1: 100%|█| 200/200 [01:52<00:00, 1.78it/s, train_loss_step=4.32e-5, learn
Epoch 2: 100%|█| 200/200 [02:00<00:00, 1.66it/s, train_loss_step=5.64e-5, learn
Epoch 3: 100%|█| 200/200 [02:14<00:00, 1.49it/s, train_loss_step=5.27e-5, learn
Epoch 4: 100%|█| 200/200 [02:24<00:00, 1.39it/s, train_loss_step=3.65e-5, learn
Epoch 5: 100%|█| 200/200 [02:39<00:00, 1.25it/s, train_loss_step=3.99e-5, learn
Epoch 6: 100%|█| 200/200 [02:17<00:00, 1.46it/s, train_loss_step=1.21e-5, learn
Epoch 7: 100%|█| 200/200 [02:40<00:00, 1.25it/s, train_loss_step=1.24e-5, learn
Epoch 8: 100%|█| 200/200 [02:44<00:00, 1.22it/s, train_loss_step=1.4e-5, learni
Epoch 9: 100%|█| 200/200 [02:45<00:00, 1.21it/s, train_loss_step=1.51e-5, learn
Epoch 10: 100%|█| 200/200 [02:31<00:00, 1.32it/s, train_loss_step=7.77e-6, lear
Epoch 11: 100%|█| 200/200 [02:21<00:00, 1.41it/s, train_loss_step=8.04e-6, lear
Epoch 12: 100%|█| 200/200 [02:52<00:00, 1.16it/s, train_loss_step=6.16e-6, lear
Epoch 13: 100%|█| 200/200 [02:57<00:00, 1.13it/s, train_loss_step=5.98e-6, lear
Epoch 14: 100%|█| 200/200 [03:15<00:00, 1.02it/s, train_loss_step=7.85e-6, lear
Epoch 15: 100%|█| 200/200 [03:29<00:00, 0.96it/s, train_loss_step=3.18e-6, lear
Epoch 16: 100%|█| 200/200 [03:00<00:00, 1.11it/s, train_loss_step=2.3e-6, learn
Epoch 17: 100%|█| 200/200 [03:12<00:00, 1.04it/s, train_loss_step=3.18e-6, lear
Epoch 18: 100%|█| 200/200 [03:32<00:00, 0.94it/s, train_loss_step=2.43e-6, lear
Epoch 19: 100%|█| 200/200 [03:18<00:00, 1.01it/s, train_loss_step=2.23e-6, lear
Epoch 20: 100%|█| 200/200 [03:36<00:00, 0.92it/s, train_loss_step=2.11e-6, lear
Epoch 21: 100%|█| 200/200 [04:07<00:00, 0.81it/s, train_loss_step=3.97e-6, lear
Epoch 22: 1%| | 2/200 [00:01<03:15, 1.01it/s, train_loss_step=1.23e-6, learni^C
2025-09-23 23:54:24 | INFO | sleap_nn.training.model_trainer:train:1037 | Finished training loop. [62.9 min]
2025-09-23 23:54:24 | INFO | sleap_nn.training.model_trainer:train:1064 | Deleting viz folder at models/courtship.centroid/viz...
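As a side note, the split logged at the start of training (134 train / 14 val labeled frames) follows from `data_config.validation_fraction: 0.1` applied to the 148 user-labeled frames in this dataset. A minimal sketch of that arithmetic, assuming the validation count is floored:

```python
import math

n_labeled = 148     # user-labeled frames in courtship_labels.slp
val_fraction = 0.1  # data_config.validation_fraction

n_val = math.floor(n_labeled * val_fraction)  # 14
n_train = n_labeled - n_val                   # 134
print(n_train, n_val)
```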
Let's now train a centered-instance model.
!sleap-nn train --config-name baseline.centered_instance --config-dir . "data_config.train_labels_path=[dataset/drosophila-melanogaster-courtship/courtship_labels.slp]" trainer_config.ckpt_dir="models" trainer_config.run_name="courtship.topdown_confmaps" trainer_config.max_epochs=200
2025-09-23 23:54:35 | INFO | sleap_nn.cli:train:98 | Input config:
2025-09-23 23:54:35 | INFO | sleap_nn.cli:train:99 |
data_config:
  train_labels_path:
  - dataset/drosophila-melanogaster-courtship/courtship_labels.slp
  val_labels_path: null
  validation_fraction: 0.1
  test_file_path: null
  provider: LabelsReader
  user_instances_only: true
  data_pipeline_fw: torch_dataset
  cache_img_path: null
  use_existing_imgs: false
  delete_cache_imgs_after_training: true
  preprocessing:
    ensure_rgb: false
    ensure_grayscale: false
    max_height: null
    max_width: null
    scale: 1.0
    crop_size: null
    min_crop_size: 100
  use_augmentations_train: true
  augmentation_config:
    intensity:
      uniform_noise_min: 0.0
      uniform_noise_max: 1.0
      uniform_noise_p: 0.0
      gaussian_noise_mean: 5.0
      gaussian_noise_std: 1.0
      gaussian_noise_p: 0.0
      contrast_min: 0.9
      contrast_max: 1.1
      contrast_p: 0.0
      brightness_min: 0.9
      brightness_max: 1.1
      brightness_p: 0.0
    geometric:
      rotation_min: -15.0
      rotation_max: 15.0
      scale_min: 1.0
      scale_max: 1.0
      translate_width: 0.0
      translate_height: 0.0
      affine_p: 1.0
      erase_scale_min: 0.0001
      erase_scale_max: 0.01
      erase_ratio_min: 1.0
      erase_ratio_max: 1.0
      erase_p: 0.0
      mixup_lambda_min: 0.01
      mixup_lambda_max: 0.05
      mixup_p: 0.0
  skeletons: null
model_config:
  init_weights: xavier
  pretrained_backbone_weights: null
  pretrained_head_weights: null
  backbone_config:
    unet:
      in_channels: 1
      kernel_size: 3
      filters: 16
      filters_rate: 1.5
      max_stride: 8
      stem_stride: null
      middle_block: true
      up_interpolate: true
      stacks: 1
      convs_per_block: 2
      output_stride: 2
    convnext: null
    swint: null
  head_configs:
    single_instance: null
    centroid: null
    centered_instance:
      confmaps:
        sigma: 1.5
        output_stride: 2
        anchor_part: null
        part_names: null
    bottomup: null
  total_params: null
trainer_config:
  train_data_loader:
    batch_size: 4
    shuffle: true
    num_workers: 0
  val_data_loader:
    batch_size: 4
    shuffle: false
    num_workers: 0
  model_ckpt:
    save_top_k: 1
    save_last: false
  trainer_devices: null
  trainer_device_indices: null
  trainer_accelerator: auto
  profiler: null
  trainer_strategy: auto
  enable_progress_bar: true
  min_train_steps_per_epoch: 200
  train_steps_per_epoch: null
  visualize_preds_during_training: true
  keep_viz: false
  max_epochs: 200
  seed: null
  use_wandb: false
  save_ckpt: true
  ckpt_dir: models
  run_name: courtship.topdown_confmaps
  resume_ckpt_path: null
  optimizer_name: Adam
  optimizer:
    lr: 0.0001
    amsgrad: false
  lr_scheduler:
    step_lr: null
    reduce_lr_on_plateau:
      threshold: 1.0e-06
      threshold_mode: rel
      cooldown: 3
      patience: 5
      factor: 0.5
      min_lr: 1.0e-08
  early_stopping:
    min_delta: 1.0e-08
    patience: 10
    stop_training_on_plateau: true
  online_hard_keypoint_mining:
    online_mining: false
    hard_to_easy_ratio: 2.0
    min_hard_keypoints: 2
    max_hard_keypoints: null
    loss_scale: 5.0
  zmq:
    publish_port: null
    controller_port: null
    controller_polling_timeout: 10
name: ''
description: ''
sleap_nn_version: 0.0.1
filename: ''
2025-09-23 23:54:35 | INFO | sleap_nn.train:run_training:27 | Started training at: 2025-09-23 23:54:35.542013
2025-09-23 23:54:35 | INFO | sleap_nn.training.model_trainer:_setup_train_val_labels:216 | Creating train-val split...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:_setup_train_val_labels:261 | # Train Labeled frames: 134
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:_setup_train_val_labels:262 | # Val Labeled frames: 14
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:setup_config:512 | Setting up config...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:_verify_model_input_channels:417 | Updating backbone in_channels to 3 based on the input image channels.
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:849 | Setting up for training...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:_setup_model_ckpt_dir:575 | Setting up model ckpt dir: `models/courtship.topdown_confmaps`...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:864 | Setting up visualization train and val datasets...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:868 | Setting up Trainer...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:_setup_loggers_callbacks:647 | Setting up callbacks and loggers...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:897 | Trainer devices: auto
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:950 | Training on 1 device(s)
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:951 | Training on mps:0 accelerator
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:955 | Setting up lightning module for centered_instance model...
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:959 | Backbone model: UNet(
  (encoders): ModuleList(
    (0): Encoder(
      (encoder_stack): ModuleList(
        (0): SimpleConvBlock(
          (blocks): Sequential(
            (stack0_enc0_conv0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc0_act0_relu): ReLU()
            (stack0_enc0_conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc0_act1_relu): ReLU()
          )
        )
        (1): SimpleConvBlock(
          (blocks): Sequential(
            (stack0_enc1_pool): MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding=same, dilation=1, ceil_mode=False)
            (stack0_enc1_conv0): Conv2d(16, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc1_act0_relu): ReLU()
            (stack0_enc1_conv1): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc1_act1_relu): ReLU()
          )
        )
        (2): SimpleConvBlock(
          (blocks): Sequential(
            (stack0_enc2_pool): MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding=same, dilation=1, ceil_mode=False)
            (stack0_enc2_conv0): Conv2d(24, 36, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc2_act0_relu): ReLU()
            (stack0_enc2_conv1): Conv2d(36, 36, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_enc2_act1_relu): ReLU()
          )
        )
        (3): Sequential(
          (stack0_enc3_last_pool): MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding=same, dilation=1, ceil_mode=False)
        )
      )
    )
  )
  (decoders): ModuleList(
    (0): Decoder(
      (decoder_stack): ModuleList(
        (0): SimpleUpsamplingBlock(
          (blocks): Sequential(
            (stack0_dec0_s8_to_s4_interp_bilinear): Upsample(scale_factor=2.0, mode='bilinear')
            (stack0_dec0_s8_to_s4_refine_conv0): Conv2d(90, 36, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec0_s8_to_s4_refine_conv0_act_relu): ReLU()
            (stack0_dec0_s8_to_s4_refine_conv1): Conv2d(36, 36, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec0_s8_to_s4_refine_conv1_act_relu): ReLU()
          )
        )
        (1): SimpleUpsamplingBlock(
          (blocks): Sequential(
            (stack0_dec1_s4_to_s2_interp_bilinear): Upsample(scale_factor=2.0, mode='bilinear')
            (stack0_dec1_s4_to_s2_refine_conv0): Conv2d(60, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec1_s4_to_s2_refine_conv0_act_relu): ReLU()
            (stack0_dec1_s4_to_s2_refine_conv1): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
            (stack0_dec1_s4_to_s2_refine_conv1_act_relu): ReLU()
          )
        )
      )
    )
  )
  (middle_blocks): ModuleList(
    (0): SimpleConvBlock(
      (blocks): Sequential(
        (stack0_enc4_middle_expand_conv0): Conv2d(36, 54, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (stack0_enc4_middle_expand_act0_relu): ReLU()
      )
    )
    (1): SimpleConvBlock(
      (blocks): Sequential(
        (stack0_enc5_middle_contract_conv0): Conv2d(54, 54, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (stack0_enc5_middle_contract_act0_relu): ReLU()
      )
    )
  )
)
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:960 | Head model: ModuleList(
  (0): Sequential(
    (CenteredInstanceConfmapsHead): Sequential(
      (0): Conv2d(24, 13, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (1): Identity()
    )
  )
)
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:962 | Total model parameters: 134229
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:967 | Input image shape: torch.Size([1, 3, 128, 128])
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py:684: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.
warnings.warn(warn_msg)
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:1021 | Finished trainer set up. [0.2s]
2025-09-23 23:54:36 | INFO | sleap_nn.training.model_trainer:train:1024 | Starting training loop...
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/divyasesh/Desktop/talmolab/sleap-core-docs/docs/notebooks/models/courtship.topdown_confmaps exists and is not empty.
| Name | Type | Params | Mode
-----------------------------------------------------------------------
0 | model | Model | 134 K | train
1 | instance_peaks_inf_layer | FindInstancePeaks | 0 | train
-----------------------------------------------------------------------
134 K Trainable params
0 Non-trainable params
134 K Total params
0.537 Total estimated model params size (MB)
59 Modules in train mode
0 Modules in eval mode
Sanity Checking: | | 0/? [00:00<?, ?it/s]/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('learning_rate', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('val_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
Sanity Checking DataLoader 0: 100%|███████████████| 2/2 [00:01<00:00, 1.63it/s]/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('val_time', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
Epoch 0: 0%| | 0/200 [00:00<?, ?it/s]/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('head', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
[... the same warning repeats for each remaining logged metric: thorax, abdomen, wingL, wingR, forelegL4, forelegR4, midlegL4, midlegR4, hindlegL4, hindlegR4, eyeL, eyeR, train_loss ...]
Epoch 0: 100%|█| 200/200 [00:59<00:00, 3.38it/s, head_step=0.00135, thorax_step
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.45it/s]
Epoch 0: 100%|█| 200/200 [01:01<00:00, 3.23it/s, head_step=0.00135, thorax_step/Users/divyasesh/Desktop/talmolab/sleap-core-docs/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:520: You called `self.log('train_time', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
Epoch 1: 100%|█| 200/200 [01:01<00:00, 3.24it/s, head_step=0.00111, thorax_step
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.84it/s]
Epoch 2: 100%|█| 200/200 [00:56<00:00, 3.51it/s, head_step=0.000842, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.28it/s]
Epoch 3: 100%|█| 200/200 [00:52<00:00, 3.78it/s, head_step=0.000714, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 9.84it/s]
Epoch 4: 100%|█| 200/200 [00:54<00:00, 3.66it/s, head_step=0.000639, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.71it/s]
Epoch 5: 100%|█| 200/200 [00:59<00:00, 3.35it/s, head_step=0.000662, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.05it/s]
Epoch 6: 100%|█| 200/200 [01:02<00:00, 3.21it/s, head_step=0.000567, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.98it/s]
Epoch 7: 100%|█| 200/200 [00:58<00:00, 3.43it/s, head_step=0.000665, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 9.93it/s]
Epoch 8: 100%|█| 200/200 [00:59<00:00, 3.35it/s, head_step=0.000428, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:01<00:00, 4.73it/s]
Epoch 9: 100%|█| 200/200 [01:06<00:00, 3.01it/s, head_step=0.00031, thorax_step
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 8.32it/s]
Epoch 10: 100%|█| 200/200 [00:55<00:00, 3.59it/s, head_step=0.000533, thorax_st
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 9.24it/s]
Epoch 11: 100%|█| 200/200 [00:57<00:00, 3.50it/s, head_step=0.000352, thorax_st
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 9.17it/s]
Epoch 12: 100%|█| 200/200 [01:05<00:00, 3.05it/s, head_step=0.00033, thorax_ste
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.42it/s]
Epoch 13: 100%|█| 200/200 [01:10<00:00, 2.84it/s, head_step=0.000241, thorax_st
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.08it/s]
Epoch 14: 100%|█| 200/200 [01:04<00:00, 3.09it/s, head_step=0.000332, thorax_st
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 8.87it/s]
Epoch 15: 100%|█| 200/200 [01:04<00:00, 3.11it/s, head_step=0.000249, thorax_st
Validation DataLoader 0: 100%|████████████████████| 7/7 [00:00<00:00, 7.29it/s]
Epoch 16: 6%| | 11/200 [00:03<01:03, 2.97it/s, head_step=0.000618, thorax_ste^C
2025-09-24 00:11:14 | INFO | sleap_nn.training.model_trainer:train:1037 | Finished training loop. [16.6 min]
2025-09-24 00:11:14 | INFO | sleap_nn.training.model_trainer:train:1064 | Deleting viz folder at models/courtship.topdown_confmaps/viz...
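The trainer summary above reports 134,229 total parameters. That number can be recovered by hand from the printed layer shapes: each `Conv2d(in, out, kernel_size=(k, k))` contributes `in * out * k * k` weights plus `out` biases, and ReLU, pooling, and upsampling layers have none. A quick sanity check:

```python
# (in_channels, out_channels, kernel) for every Conv2d in the printed UNet,
# in order: encoder blocks, middle blocks, decoder blocks, confmap head.
convs = [
    (3, 16, 3), (16, 16, 3),   # stack0_enc0
    (16, 24, 3), (24, 24, 3),  # stack0_enc1
    (24, 36, 3), (36, 36, 3),  # stack0_enc2
    (36, 54, 3), (54, 54, 3),  # middle expand / contract
    (90, 36, 3), (36, 36, 3),  # dec0: upsampled 54 + skip 36 = 90 input channels
    (60, 24, 3), (24, 24, 3),  # dec1: 36 + skip 24 = 60 input channels
    (24, 13, 1),               # CenteredInstanceConfmapsHead: 13 body parts
]
total = sum(cin * cout * k * k + cout for cin, cout, k in convs)
print(total)  # 134229, matching the trainer's report
```

The concatenated skip connections explain the decoder input widths (90 and 60), which is worth checking whenever you change `filters` or `filters_rate` in the config.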
The trained models (along with the training configurations and the ground truth labels used for training and validation) are saved in the models/ directory:
!tree models/
models/
├── courtship.centroid
│   ├── best.ckpt
│   ├── initial_config.yaml
│   ├── labels_train_gt_0.slp
│   ├── labels_val_gt_0.slp
│   ├── training_config.yaml
│   └── training_log.csv
└── courtship.topdown_confmaps
    ├── best.ckpt
    ├── initial_config.yaml
    ├── labels_train_gt_0.slp
    ├── labels_val_gt_0.slp
    ├── training_config.yaml
    └── training_log.csv

3 directories, 12 files
Inference¶
Let's run inference with our trained models for centroids and centered instances.
!sleap-nn track -i "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4" --frames 0-100 -m "models/courtship.centroid" -m "models/courtship.topdown_confmaps"
2025-09-24 00:12:47 | INFO | sleap_nn.predict:run_inference:319 | Started inference at: 2025-09-24 00:12:47.871203
2025-09-24 00:12:47 | INFO | sleap_nn.predict:run_inference:330 | Integral refinement is not supported with MPS device. Using CPU.
2025-09-24 00:12:47 | INFO | sleap_nn.predict:run_inference:335 | Using device: cpu
Predicting... ━━━━━━━━━━━━━━━ 100% 101/101 ETA: 0:00:00 Elapsed: 0:00:17 6.0 FPS
2025-09-24 00:13:05 | INFO | sleap_nn.predict:run_inference:453 | Finished inference at: 2025-09-24 00:13:05.531684
2025-09-24 00:13:05 | INFO | sleap_nn.predict:run_inference:454 | Total runtime: 17.660489082336426 secs
2025-09-24 00:13:05 | INFO | sleap_nn.predict:run_inference:465 | Predictions output path: dataset/drosophila-melanogaster-courtship/20190128_113421.predictions.slp
2025-09-24 00:13:05 | INFO | sleap_nn.predict:run_inference:466 | Saved file at: 2025-09-24 00:13:05.594857
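The `FindInstancePeaks` layer listed in the training summary converts each predicted confidence map back into coordinates; conceptually this is a peak search followed by scaling back to image coordinates. A simplified sketch of the idea (sleap-nn's actual implementation additionally does subpixel integral refinement, which the log above notes falls back to CPU on MPS devices):

```python
import numpy as np

def find_peak(confmap, output_stride=2):
    """Global peak of one confidence map -> (x, y, score) in image coords.

    The map is sampled every `output_stride` pixels, so peak indices
    are multiplied back up to recover image-space coordinates.
    """
    row, col = np.unravel_index(confmap.argmax(), confmap.shape)
    return col * output_stride, row * output_stride, confmap[row, col]

# Toy map with a single peak at map cell (row=12, col=20):
cm = np.zeros((64, 64))
cm[12, 20] = 1.0
print(find_peak(cm))  # (40, 24, 1.0) in image coordinates
```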
When inference finishes, the predictions are saved to a file. Since we didn't specify an output path, it is saved as <video filename>.predictions.slp in the same directory as the video:
!tree dataset/drosophila-melanogaster-courtship
dataset/drosophila-melanogaster-courtship
├── 20190128_113421.mp4
├── 20190128_113421.mp4.predictions.slp
├── 20190128_113421.predictions.slp
├── courtship_labels.slp
└── example.jpg

1 directory, 5 files
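The default output name seen in the inference log follows a simple convention: replace the video's extension with `.predictions.slp` and keep it alongside the source video. A sketch of that convention with `pathlib` (an illustration of the naming rule, not sleap-nn code):

```python
from pathlib import Path

def default_predictions_path(video_path):
    """<video stem>.predictions.slp alongside the source video."""
    return Path(video_path).with_suffix(".predictions.slp")

print(default_predictions_path(
    "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4"
))  # dataset/drosophila-melanogaster-courtship/20190128_113421.predictions.slp
```

Pass an explicit output path to `sleap-nn track` if you want the file somewhere else.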
If you're running in Colab with Chrome, you can download your trained models like so:
# Zip up the models directory
!zip -r trained_models.zip models/
# Download.
from google.colab import files
files.download("/content/trained_models.zip")
  adding: models/ (stored 0%)
  adding: models/courtship.topdown_confmaps/ (stored 0%)
  adding: models/courtship.topdown_confmaps/labels_pr.val.slp (deflated 74%)
  adding: models/courtship.topdown_confmaps/metrics.val.npz (deflated 0%)
  adding: models/courtship.topdown_confmaps/labels_pr.train.slp (deflated 67%)
  adding: models/courtship.topdown_confmaps/labels_gt.val.slp (deflated 72%)
  adding: models/courtship.topdown_confmaps/initial_config.json (deflated 73%)
  adding: models/courtship.topdown_confmaps/training_log.csv (deflated 55%)
  adding: models/courtship.topdown_confmaps/metrics.train.npz (deflated 0%)
  adding: models/courtship.topdown_confmaps/labels_gt.train.slp (deflated 61%)
  adding: models/courtship.topdown_confmaps/best_model.h5 (deflated 8%)
  adding: models/courtship.topdown_confmaps/training_config.json (deflated 88%)
  adding: models/courtship.centroid/ (stored 0%)
  adding: models/courtship.centroid/labels_pr.val.slp (deflated 82%)
  adding: models/courtship.centroid/metrics.val.npz (deflated 1%)
  adding: models/courtship.centroid/labels_pr.train.slp (deflated 79%)
  adding: models/courtship.centroid/labels_gt.val.slp (deflated 73%)
  adding: models/courtship.centroid/initial_config.json (deflated 74%)
  adding: models/courtship.centroid/training_log.csv (deflated 57%)
  adding: models/courtship.centroid/metrics.train.npz (deflated 0%)
  adding: models/courtship.centroid/labels_gt.train.slp (deflated 61%)
  adding: models/courtship.centroid/best_model.h5 (deflated 7%)
  adding: models/courtship.centroid/training_config.json (deflated 88%)
And you can likewise download your predictions:
from google.colab import files
files.download('dataset/drosophila-melanogaster-courtship/20190128_113421.mp4.predictions.slp')
In other browsers (e.g., Safari) you might get an error; in that case, download via the "Files" tab in the side panel (the folder icon). If you don't see the side panel, select "Show table of contents" in the "View" menu.