
chore: add audio m5 speechcommands #3092

Closed
jaiakash wants to merge 1 commit into kubeflow:master from jaiakash:add-audio-examples

Conversation

@jaiakash
Member

@jaiakash jaiakash commented Jan 13, 2026

What this PR does / why we need it:
This PR adds an audio classification example to the trainer repo. It trains an audio classification model with the M5 network architecture on the Google Speech Commands dataset using PyTorch and Kubeflow Trainer (a rough sketch of the architecture is shown below).

On the system specs below, it took around 3 minutes to run, so we could also include it in our E2E test coverage.

VM.GPU.A10.1 (GPU: 1×A10), 15 OCPU, 24 GB memory, 240 GB storage
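For readers skimming the PR, a minimal sketch of the M5 network this example is built around, following the standard torchaudio Speech Commands tutorial. The hyperparameters (`n_channel=32`, `kernel_size=80`, `stride=16`) and the final layer name (`fc1` rather than the notebook's `classifier`) are assumptions, not lines copied from the notebook.

```python
import torch.nn as nn
import torch.nn.functional as F


class M5(nn.Module):
    """1-D CNN over raw waveforms (Dai et al., 2017), as in the torchaudio tutorial."""

    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        x = F.avg_pool1d(x, x.shape[-1])          # global average pool over time -> (batch, 2*n_channel, 1)
        x = x.permute(0, 2, 1)                    # -> (batch, 1, 2*n_channel)
        return F.log_softmax(self.fc1(x), dim=2)  # log-probabilities, shape (batch, 1, n_output)
```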

Which issue(s) this PR fixes
Related #2040

Checklist:

  • Docs included if any changes are user facing

Signed-off-by: Akash Jaiswal <akashjaiswal3846@gmail.com>
Copilot AI review requested due to automatic review settings January 13, 2026 04:50
@review-notebook-app

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@google-oss-prow

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign tenzen-y for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@jaiakash jaiakash changed the title from add(examples): audio m5 speechcommands to chore): add audio m5 speechcommands Jan 13, 2026
@jaiakash jaiakash changed the title from chore): add audio m5 speechcommands to chore: add audio m5 speechcommands Jan 13, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR adds an audio classification example demonstrating how to train an M5 Network model on the Google Speech Commands dataset using PyTorch and Kubeflow Trainer. The example shows distributed training capabilities and scaling from local execution to Kubernetes clusters.

Changes:

  • Added a comprehensive Jupyter notebook example for audio classification with M5 Network architecture
  • Implemented distributed training setup with PyTorch DDP on Speech Commands dataset
  • Integrated Kubeflow Trainer SDK for job submission, monitoring, and management (a rough submission sketch follows below)
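For orientation, a rough sketch of how a training function like this is typically handed to the Kubeflow Trainer SDK. The import path, class names (`TrainerClient`, `CustomTrainer`), and parameter and method names below are assumptions about the SDK surface, not lines taken from this notebook.

```python
# Sketch only: names below are assumptions about the Kubeflow Trainer SDK,
# not lines copied from the notebook under review.
from kubeflow.trainer import CustomTrainer, TrainerClient


def train_m5_speechcommands():
    """Placeholder for the notebook's training function (defined in the PR)."""
    ...


client = TrainerClient()

# Submit the training function as a single-node TrainJob that requests one GPU,
# mirroring the configuration discussed in the review comments.
job_name = client.train(
    trainer=CustomTrainer(
        func=train_m5_speechcommands,
        num_nodes=1,
        resources_per_node={"nvidia.com/gpu": 1},
    ),
)

# Follow the job's logs until it finishes (method name is likewise an assumption).
for line in client.get_job_logs(job_name, follow=True):
    print(line)
```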


" batch_size=batch_size,\n",
" sampler=sampler,\n",
" collate_fn=collate_fn,\n",
" num_workers=1 if device_type == \"cuda\" else 0\n",

Copilot AI Jan 13, 2026


The num_workers is set to 1 for cuda and 0 for cpu on line 227. However, using num_workers=1 with distributed training can cause issues with data loading performance and may not fully utilize multi-core CPUs. Consider using a higher value like num_workers=4 or making it configurable as a parameter for better performance, especially when using multiple workers per node.

Suggested change
-        num_workers=1 if device_type == "cuda" else 0
+        num_workers=4 if device_type == "cuda" else 2

"\n",
" # Initialize model with DDP\n",
" model = M5(n_input=waveform.shape[0], n_output=len(labels)).to(device)\n",
" model = nn.parallel.DistributedDataParallel(model)\n",

Copilot AI Jan 13, 2026


The model is wrapped with DistributedDataParallel without specifying the device_ids parameter. While this works, for GPU training it's recommended to explicitly pass device_ids=[local_rank] and output_device=local_rank to ensure proper device placement and avoid potential issues with multi-GPU setups.

Suggested change
-    model = nn.parallel.DistributedDataParallel(model)
+    if device.type == "cuda":
+        # Determine local_rank for this process and bind model to the correct GPU
+        if torch.distributed.is_initialized() and torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
+        else:
+            local_rank = torch.cuda.current_device()
+        model = nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[local_rank],
+            output_device=local_rank,
+        )
+    else:
+        model = nn.parallel.DistributedDataParallel(model)

Comment on lines +249 to +250
" output = model(data)\n",
" loss = F.nll_loss(output.squeeze(), target)\n",

Copilot AI Jan 13, 2026


The output shape from the model forward pass appears inconsistent with the loss calculation. The forward method returns F.log_softmax(self.classifier(x), dim=2) where x has been permuted to shape (batch_size, 1, 2*n_channel). This means the output has shape (batch_size, 1, n_output). The loss function then calls output.squeeze() to get shape (batch_size, n_output), but this adds unnecessary complexity. The dim parameter in log_softmax should be dim=-1 or the architecture should be simplified to avoid the extra dimension.

Suggested change
-        output = model(data)
-        loss = F.nll_loss(output.squeeze(), target)
+        output = model(data).squeeze(1)
+        loss = F.nll_loss(output, target)

Comment on lines +193 to +194
" device_type, backend = (\n",
" (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n",

Copilot AI Jan 13, 2026


The output logs show "Using Device: cpu, Backend: gloo" (lines 440-469) even though the job configuration requests a GPU with "nvidia.com/gpu": 1 on line 361. This suggests the GPU is not being properly detected or utilized. Verify that the Docker image includes proper CUDA support and that torch.cuda.is_available() returns True in the GPU environment, or update the example to clarify CPU-only behavior.
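A minimal sketch (not taken from the notebook) of the kind of sanity logging that makes a silent CPU fallback visible in the worker logs; only standard torch calls are used.

```python
import torch

# Log what the worker actually sees before choosing a backend, so a CPU-only
# image or a missing GPU request shows up immediately in the job logs.
print(f"torch {torch.__version__}, built with CUDA: {torch.version.cuda}")
print(f"cuda available: {torch.cuda.is_available()}, device count: {torch.cuda.device_count()}")

device_type, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")
print(f"Using Device: {device_type}, Backend: {backend}")
```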

Comment on lines +177 to +178
" x = F.avg_pool1d(x, x.shape[-1]).permute(0, 2, 1)\n",
" return F.log_softmax(self.classifier(x), dim=2)\n",

Copilot AI Jan 13, 2026


The forward method returns log probabilities with shape (batch_size, 1, n_output) after permute, but the training loop applies nll_loss after squeeze on line 250. This architecture design means x.shape[-1] is always 1 after avg_pool1d since the pooled dimension becomes 1. The permute(0, 2, 1) operation and squeeze() work, but this architecture seems unnecessarily complex. Consider simplifying to return shape (batch_size, n_output) directly by using x = F.avg_pool1d(x, x.shape[-1]).squeeze(-1) followed by self.classifier(x), and removing the permute and the need for squeeze in the loss calculation.

Suggested change
-        x = F.avg_pool1d(x, x.shape[-1]).permute(0, 2, 1)
-        return F.log_softmax(self.classifier(x), dim=2)
+        x = F.avg_pool1d(x, x.shape[-1]).squeeze(-1)
+        return F.log_softmax(self.classifier(x), dim=1)

Comment on lines +181 to +190
" def collate_fn(batch):\n",
" tensors, targets = [], []\n",
" for waveform, _, label, *_ in batch:\n",
" tensors += [waveform.t()]\n",
" targets += [torch.tensor(labels.index(label))]\n",
" # Pad to same length\n",
" tensors = torch.nn.utils.rnn.pad_sequence(\n",
" tensors, batch_first=True, padding_value=0.\n",
" ).permute(0, 2, 1)\n",
" return tensors, torch.stack(targets)\n",

Copilot AI Jan 13, 2026


The collate function is defined inside train_m5_speechcommands but it needs access to the 'labels' variable which is only created later on line 213. This will cause the DataLoader to fail when it tries to use collate_fn. The function definition order should be reconsidered - either move the collate_fn definition after labels is created, or restructure to pass labels as a closure variable or use a class-based approach.
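One possible restructuring, sketched under the assumption that `train_set` is the SPEECHCOMMANDS subset used in the notebook; variable names here are illustrative. Building `labels` first lets `collate_fn` close over it safely.

```python
import torch

# Build the label vocabulary before defining collate_fn, so the closure over
# `labels` is valid by the time the DataLoader calls it. Note: scanning the
# whole dataset for labels is simple but slow; caching the list is an option.
labels = sorted(set(sample[2] for sample in train_set))  # label is the 3rd field of each sample

def collate_fn(batch):
    tensors, targets = [], []
    for waveform, _, label, *_ in batch:
        tensors.append(waveform.t())
        targets.append(torch.tensor(labels.index(label)))
    # Pad waveforms to a common length, then restore (batch, channel, time) layout.
    tensors = torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=0.0
    ).permute(0, 2, 1)
    return tensors, torch.stack(targets)
```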

Comment on lines +199 to +203
" dist.init_process_group(backend=backend)\n",
" print(\n",
" f\"Distributed Training - WORLD_SIZE: {dist.get_world_size()}, \"\n",
" f\"RANK: {dist.get_rank()}, LOCAL_RANK: {local_rank}\"\n",
" )\n",

Copilot AI Jan 13, 2026


The logs show WORLD_SIZE of 30 processes being spawned (lines 470-499), which seems excessive for a single node configuration. The job is configured with num_nodes=1 on line 355, but torch distributed appears to be creating 30 worker processes. This could be intentional for data parallelism, but it's not clearly documented and may lead to resource contention. Consider adding documentation about why 30 processes are used or making this configurable.
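For reference, a small sketch (illustrative, not from the notebook) of logging the launcher-provided topology; with torchrun, WORLD_SIZE equals nnodes × nproc_per_node, so 30 processes on a single node implies 30 processes per node.

```python
import os

# These variables are set by the launcher (e.g. torchrun) for every worker.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
print(f"world_size={world_size} rank={rank} local_rank={local_rank}")
```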

@coveralls

Pull Request Test Coverage Report for Build 20945115994


  • 0 of 0 changed or added relevant lines in 0 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage remained the same at 51.435%

Totals Coverage Status
  • Change from base Build 20862290829: 0.0%
  • Covered Lines: 1237
  • Relevant Lines: 2405

💛 - Coveralls

@jaiakash
Member Author

Closing this, as a PR for the audio example has already been raised in #3063.

@jaiakash jaiakash closed this Jan 13, 2026
