AudioTask is the task type for audio processing in NeMo Curator. Each AudioTask holds a single manifest entry as a dict, matching the convention used by VideoTask and FileGroupTask.
Import
from nemo_curator.tasks import AudioTask
Class Definition
from dataclasses import dataclass
@dataclass
class AudioTask(Task[dict]):
    """A task that wraps exactly one audio manifest entry.

    Attributes:
        task_id: Unique identifier of the task.
        dataset_name: Name of the dataset the entry was read from.
        data: The manifest entry itself (one dict, stored as _AttrDict).
    """

    task_id: str
    dataset_name: str
    # _AttrDict subclass — attribute-style access is supported
    data: dict
Audio Manifest Format
Audio data follows the NeMo manifest format:
{
"audio_filepath": "/path/to/audio.wav",
"duration": 5.2,
"text": "Transcription text...",
"speaker": "speaker_001",
"metadata": {
"sample_rate": 16000,
"channels": 1
}
}
Properties
num_items
Always returns 1 — each AudioTask holds exactly one manifest entry.
@property
def num_items(self) -> int:
    """Return the number of manifest entries held by the task.

    Returns:
        Always 1, since each task holds exactly one manifest entry.
    """
    # The docstring-only body returned None; return the documented value.
    return 1
Creating AudioTask
from nemo_curator.tasks import AudioTask
# Build one manifest entry, then wrap it in a task
manifest_entry = {
    "audio_filepath": "/data/audio/sample.wav",
    "duration": 5.2,
    "text": "Hello world",
}
task = AudioTask(
    task_id="audio_001",
    dataset_name="speech_dataset",
    data=manifest_entry,
)

# Fields can be read with dict-style or attribute-style access
task.data["audio_filepath"]  # "/data/audio/sample.wav"
task.data.audio_filepath  # "/data/audio/sample.wav"
Usage in Stages
All audio stages subclass ProcessingStage[AudioTask, AudioTask] directly — there is no intermediate base class.
CPU Stage (per-task processing)
from dataclasses import dataclass
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import AudioTask
@dataclass
class DurationFilterStage(ProcessingStage[AudioTask, AudioTask]):
    """Keep only the tasks whose duration lies inside a configured range.

    Attributes:
        name: Name of the stage.
        min_duration: Smallest accepted duration (inclusive).
        max_duration: Largest accepted duration (inclusive).
    """

    name: str = "DurationFilter"
    min_duration: float = 1.0
    max_duration: float = 30.0

    def inputs(self) -> tuple[list[str], list[str]]:
        """Return the stage's input declaration: ``["data"]`` and an empty list."""
        return (["data"], [])

    def outputs(self) -> tuple[list[str], list[str]]:
        """Return the stage's output declaration: ``["data"]`` and an empty list."""
        return (["data"], [])

    def process(self, task: AudioTask) -> AudioTask | None:
        """Pass the task through when its duration is in range.

        A manifest entry without a ``"duration"`` key is treated as having
        a duration of 0.

        Returns:
            The same task when ``min_duration <= duration <= max_duration``,
            otherwise ``None``.
        """
        duration = task.data.get("duration", 0)
        within_range = self.min_duration <= duration <= self.max_duration
        return task if within_range else None
Batch Stage (GPU/IO processing)
@dataclass
class MyGpuStage(ProcessingStage[AudioTask, AudioTask]):
    """GPU stage using process_batch.

    NOTE(review): ``self.model`` is not defined in this snippet; it is
    assumed to be created elsewhere before ``process_batch`` runs — confirm.
    """

    name: str = "MyGpuStage"

    def process(self, task: AudioTask) -> AudioTask:
        """Reject per-task processing; this stage only works on batches.

        Raises:
            NotImplementedError: On every call.
        """
        raise NotImplementedError("Use process_batch for GPU stages")

    def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]:
        """Run inference on a batch and store each result as ``pred_text``.

        Args:
            tasks: Tasks whose entries each carry an ``"audio_filepath"``.

        Returns:
            The same task objects, each with ``data["pred_text"]`` set.

        Raises:
            ValueError: If the model does not return exactly one result
                per task.
        """
        # Batched GPU inference
        paths = [t.data["audio_filepath"] for t in tasks]
        results = self.model.infer(paths)
        # strict=True: a plain zip would silently stop at the shorter input
        # and leave the remaining tasks without "pred_text".
        for task, result in zip(tasks, results, strict=True):
            task.data["pred_text"] = result
        return tasks
Common Operations
ASR Transcription
def process(self, task: AudioTask) -> AudioTask:
    """Transcribe the task's audio file and store the output as ``pred_text``.

    Returns:
        The same task, with ``data["pred_text"]`` set to whatever
        ``self.asr_model.transcribe`` returned for the audio path.
    """
    entry = task.data
    entry["pred_text"] = self.asr_model.transcribe(entry["audio_filepath"])
    return task
Quality Scoring
def process(self, task: AudioTask) -> AudioTask:
    """Store a ``wer`` value when both transcripts are available.

    When the entry has both a ``"text"`` and a ``"pred_text"`` key, the
    result of ``compute_wer(text, pred_text)`` is written to ``"wer"``.
    Otherwise the entry is left untouched.

    Returns:
        The same task in either case.
    """
    entry = task.data
    if all(key in entry for key in ("text", "pred_text")):
        entry["wer"] = compute_wer(entry["text"], entry["pred_text"])
    return task