NVIDIA NeMo Curator
Menu

DocumentBatch is the primary task type for text document processing in NeMo Curator.

Import

from nemo_curator.tasks import DocumentBatch

Class Definition

from dataclasses import dataclass
import pandas as pd
import pyarrow as pa

@dataclass
class DocumentBatch(Task[pa.Table | pd.DataFrame]):
    """Task containing a batch of text documents.

    The documents are held in tabular form in ``data``, as either a pandas
    DataFrame or a PyArrow Table. The table typically has a ``text`` column
    with the document content, plus optional columns such as ``id`` and
    ``url``.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: DataFrame or PyArrow Table containing documents.
    """

    task_id: str
    dataset_name: str
    # May be either representation; code that needs a specific one should go
    # through to_pandas() / to_pyarrow() rather than assume the stored type.
    data: pa.Table | pd.DataFrame

Expected Data Schema

The data attribute typically contains:

| Column | Type | Description |
| --- | --- | --- |
| text | str | The document text content |
| id | str | Optional document identifier |
| url | str | Optional source URL |
| Additional columns | Various | Task-specific metadata |

Properties

num_items

Get the number of documents in the batch.

@property
def num_items(self) -> int:
    """Return the number of documents in this batch.

    This is a read-only property, so it is accessed without parentheses
    (``batch.num_items``).
    """

Methods

to_pyarrow()

Convert data to PyArrow table.

def to_pyarrow(self) -> pa.Table:
    """Return the batch data as a PyArrow table.

    ``data`` may be stored as either a ``pa.Table`` or a ``pd.DataFrame``;
    use this method when the caller needs a ``pa.Table`` specifically.
    """

to_pandas()

Convert data to Pandas DataFrame.

def to_pandas(self) -> pd.DataFrame:
    """Return the batch data as a pandas DataFrame.

    ``data`` may be stored as either a ``pa.Table`` or a ``pd.DataFrame``;
    use this method when the caller needs a ``pd.DataFrame`` specifically.
    """

get_columns()

Get column names from the data.

def get_columns(self) -> list[str]:
    """Return the column names of the batch data as a list of strings."""

validate()

Validate the batch structure.

def validate(self) -> bool:
    """Check that the batch has the required structure.

    Returns:
        True if the batch is valid. False if the batch is empty or has no
        columns; a warning is logged in that case.
    """

Creating DocumentBatch

import pandas as pd
from nemo_curator.tasks import DocumentBatch

# Build a small DataFrame holding the documents: the text itself plus the
# optional identifier and source URL columns.
documents = pd.DataFrame(
    {
        "text": ["Document 1 content...", "Document 2 content..."],
        "id": ["doc_001", "doc_002"],
        "url": ["https://example.com/1", "https://example.com/2"],
    }
)

# Wrap the DataFrame in a DocumentBatch so it can flow through stages.
batch = DocumentBatch(task_id="batch_001", dataset_name="my_dataset", data=documents)

# num_items is a property, so no call parentheses are needed.
print(f"Batch contains {batch.num_items} documents")

Usage in Stages

from dataclasses import dataclass
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch

@dataclass
class TextFilterStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Filter documents based on text length.

    Documents whose ``text`` has fewer than ``min_length`` characters are
    dropped from the batch.

    Attributes:
        name: Name of this stage.
        min_length: Minimum number of characters a document's text must have
            to be kept.
    """

    name: str = "TextFilter"
    min_length: int = 100

    def inputs(self) -> tuple[list[str], list[str]]:
        """Return what this stage reads: the ``data`` attribute and ``text`` column.

        NOTE(review): the tuple is presumably (task attributes, data columns);
        confirm against ProcessingStage.
        """
        return ["data"], ["text"]

    def outputs(self) -> tuple[list[str], list[str]]:
        """Return what this stage produces: the ``data`` attribute and ``text`` column.

        NOTE(review): the tuple is presumably (task attributes, data columns);
        confirm against ProcessingStage.
        """
        return ["data"], ["text"]

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        """Keep only the documents whose text is at least ``min_length`` long.

        Args:
            task: Batch of documents to filter. Its data must contain a
                ``text`` column.

        Returns:
            A new DocumentBatch holding the surviving documents, or None when
            no document in the batch passes the filter.
        """
        # task.data may be a PyArrow Table, which has no ``.str`` accessor and
        # no ``.empty``; normalise to pandas before filtering.
        df = task.to_pandas()

        # Filter by text length
        mask = df["text"].str.len() >= self.min_length
        filtered_df = df[mask]

        if filtered_df.empty:
            return None

        # Carry the metadata and per-stage performance records over to the
        # new task so they are not lost when the batch is rebuilt.
        return DocumentBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=filtered_df,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )

Common Patterns

Adding Columns

def process(self, task: DocumentBatch) -> DocumentBatch:
    """Add a ``word_count`` column with each document's whitespace-separated word count.

    Args:
        task: Batch of documents to annotate. Its data must contain a
            ``text`` column.

    Returns:
        A new DocumentBatch whose data has the extra ``word_count`` column.
    """
    # task.data may be a PyArrow Table, which supports neither ``.copy()`` nor
    # column assignment; normalise to pandas first. The copy keeps the input
    # task's DataFrame untouched when the new column is added.
    df = task.to_pandas().copy()
    df["word_count"] = df["text"].str.split().str.len()

    # Carry the metadata and per-stage performance records over to the new task.
    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=df,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Filtering Rows

def process(self, task: DocumentBatch) -> DocumentBatch | None:
    """Keep only the documents whose ``score`` is above ``self.threshold``.

    Args:
        task: Batch of documents to filter. Its data must contain a
            ``score`` column.

    Returns:
        A new DocumentBatch holding the surviving documents, or None when no
        document in the batch passes the filter.
    """
    # task.data may be a PyArrow Table, which does not support boolean-mask
    # indexing or ``.empty``; normalise to pandas before filtering.
    df = task.to_pandas()
    filtered = df[df["score"] > self.threshold]

    if filtered.empty:
        return None  # Filter out entire batch

    # Carry the metadata and per-stage performance records over to the new task.
    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )

Source Code

View source on GitHub