The ProcessingStage class is the base class for all data processing stages in NeMo Curator. Each stage defines a single step in a data curation pipeline.
Import
from nemo_curator.stages.base import ProcessingStage
Class Definition
from dataclasses import dataclass, field
from typing import Generic, TypeVar

InputT = TypeVar("InputT", bound=Task)
OutputT = TypeVar("OutputT", bound=Task)


@dataclass
class ProcessingStage(Generic[InputT, OutputT]):
    """Base class for all processing stages.

    Type Parameters:
        InputT: The input task type this stage accepts.
        OutputT: The output task type this stage produces.

    Class Attributes:
        name: String identifier for the stage.
        resources: Resource configuration (CPUs, GPUs).
        batch_size: Number of tasks to process per batch.
    """

    name: str = "ProcessingStage"
    # default_factory builds a fresh Resources object for each instance
    # instead of sharing one default object across all stages.
    resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))
    batch_size: int = 1
Abstract Methods
inputs()
Define stage input requirements.
def inputs(self) -> tuple[list[str], list[str]]:
    """Define required task and data attributes.

    Returns:
        Tuple of (required_task_attributes, required_data_attributes),
        for example ``(["data"], ["text"])``.
    """
outputs()
Define stage output requirements.
def outputs(self) -> tuple[list[str], list[str]]:
    """Define output task and data attributes.

    Returns:
        Tuple of (output_task_attributes, output_data_attributes),
        for example ``(["data"], ["text", "score"])``.
    """
process()
Process a single task.
def process(self, task: InputT) -> OutputT | list[OutputT] | None:
    """Process a single task.

    Args:
        task: The input task to process.

    Returns:
        One of:

        - Single task: For 1-to-1 transformations
        - List of tasks: For splitting/reading operations
        - None: To filter out the task
    """
Optional Lifecycle Methods
setup_on_node()
Node-level initialization (e.g., download models).
def setup_on_node(
    self,
    node_info: NodeInfo,
    worker_metadata: dict[str, Any],
) -> None:
    """Initialize resources on a compute node.

    Node-level initialization (e.g., download models).
    Called once per node before any workers start.

    Args:
        node_info: NodeInfo object for the node being initialized
            (presumably identifies the node; its fields are not shown here).
        worker_metadata: String-keyed metadata dictionary; its contents
            are not specified in this reference.
    """
setup()
Worker-level initialization (e.g., load models).
def setup(self, worker_metadata: dict[str, Any]) -> None:
    """Initialize resources for a worker.

    Worker-level initialization (e.g., load models).
    Called once per worker before processing begins.

    Args:
        worker_metadata: String-keyed metadata dictionary; its contents
            are not specified in this reference.
    """
teardown()
Cleanup after processing.
def teardown(self) -> None:
    """Clean up resources after processing completes."""
process_batch()
Vectorized batch processing for better performance.
def process_batch(self, tasks: list[InputT]) -> list[OutputT | None]:
    """Process a batch of tasks.

    Override for vectorized operations.

    Args:
        tasks: List of input tasks.

    Returns:
        List of output tasks (None entries are filtered out).
    """
Creating Custom Stages
from dataclasses import dataclass, field

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.stages.resources import Resources
from nemo_curator.tasks import DocumentBatch


@dataclass
class MyCustomStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Custom stage that scores documents and drops low-scoring batches.

    Attributes:
        name: String identifier for the stage.
        resources: Resources requested by the stage (2 CPUs by default).
        threshold: Minimum mean score a batch needs in order to be kept.
    """

    name: str = "MyCustomStage"
    resources: Resources = field(default_factory=lambda: Resources(cpus=2.0))
    # Custom parameters
    threshold: float = 0.5

    def inputs(self) -> tuple[list[str], list[str]]:
        """Require the ``data`` task attribute and the ``text`` column."""
        return ["data"], ["text"]

    def outputs(self) -> tuple[list[str], list[str]]:
        """Declare ``data`` plus the ``text`` and ``score`` columns."""
        return ["data"], ["text", "score"]

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        """Score every document and filter the batch on its mean score.

        Args:
            task: The input batch; ``task.data`` must expose a ``text``
                column (assumed to be a pandas DataFrame).

        Returns:
            A new DocumentBatch carrying the scored data, or None when the
            mean score is below ``threshold``.
        """
        # Process the task. The score column is written onto task.data
        # itself, not onto a copy.
        df = task.data
        df["score"] = df["text"].apply(self._compute_score)

        # Filter based on threshold
        if df["score"].mean() < self.threshold:
            return None

        return DocumentBatch(
            task_id=f"{task.task_id}_{self.name}",
            dataset_name=task.dataset_name,
            data=df,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )

    def _compute_score(self, text: str) -> float:
        """Return the text length divided by 1000."""
        # Custom scoring logic
        return len(text) / 1000.0
Per-Stage Runtime Environments
Stages can declare isolated Python dependencies using Ray’s native runtime_env. Set runtime_env as a class variable to specify packages that should be installed in an isolated virtualenv for that stage’s workers:
from dataclasses import dataclass
from typing import Any, ClassVar


@dataclass
class IsolatedStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Stage whose workers run with a pinned ``transformers`` version."""

    # Declared as an annotated dataclass field: ProcessingStage is a
    # dataclass with a ``name`` field, so a bare class attribute would be
    # replaced by the base-class default in the inherited ``__init__``.
    name: str = "isolated_stage"
    # ClassVar keeps runtime_env a class-level setting rather than a
    # dataclass field.
    runtime_env: ClassVar[dict[str, Any] | None] = {"pip": ["transformers==4.40.0"]}

    def inputs(self) -> tuple[list[str], list[str]]:
        """Require the ``data`` task attribute and no data attributes."""
        return ["data"], []

    def outputs(self) -> tuple[list[str], list[str]]:
        """Declare the ``data`` task attribute and no data attributes."""
        return ["data"], []

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        """Process one task inside the stage's isolated environment."""
        # Imported here so the version from runtime_env is the one loaded.
        import transformers  # sees 4.40.0
        ...
You can also override runtime_env at instantiation time using with_():
# Override the class-level runtime_env at instantiation time.
stage = IsolatedStage().with_(runtime_env={"pip": ["transformers==4.45.0"]})
All three execution backends (XennaExecutor, RayDataExecutor, RayActorPoolExecutor) support per-stage runtime environments. See the Per-Stage Runtime Environments reference for details.
Configuration with with_()
Stages can be configured using the with_() method:
from nemo_curator.stages.resources import Resources

# Build the stage with a non-default threshold.
stage = MyCustomStage(threshold=0.7)
# Request 4 CPUs and 1 GPU for the stage through with_().
configured_stage = stage.with_(resources=Resources(cpus=4.0, gpus=1.0))