DocumentBatch is the primary task type for text document processing in NeMo Curator.
Import
from nemo_curator.tasks import DocumentBatch
Class Definition
from dataclasses import dataclass
import pandas as pd
import pyarrow as pa
@dataclass
class DocumentBatch(Task[pa.Table | pd.DataFrame]):
    """Task containing a batch of text documents.

    The batch payload is tabular and may be held either as a pandas
    DataFrame or as a PyArrow Table.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: DataFrame or PyArrow Table containing documents.
    """

    task_id: str
    dataset_name: str
    data: pa.Table | pd.DataFrame
Expected Data Schema
The `data` attribute typically contains the following columns:

| Column | Type | Description |
|---|---|---|
| `text` | str | The document text content |
| `id` | str | Optional document identifier |
| `url` | str | Optional source URL |
| Additional columns | Various | Task-specific metadata |
Properties
num_items
Get the number of documents in the batch.
@property
def num_items(self) -> int:
    """Return the number of documents in this batch."""
Methods
to_pyarrow()
Convert data to PyArrow table.
def to_pyarrow(self) -> pa.Table:
    """Convert data to PyArrow table.

    Returns:
        The batch's data as a ``pa.Table``.
    """
to_pandas()
Convert data to Pandas DataFrame.
def to_pandas(self) -> pd.DataFrame:
    """Convert data to Pandas DataFrame.

    Returns:
        The batch's data as a ``pd.DataFrame``.
    """
get_columns()
Get column names from the data.
def get_columns(self) -> list[str]:
    """Get column names from the data.

    Returns:
        The column names as a list of strings.
    """
validate()
Validate the batch structure.
def validate(self) -> bool:
    """Validate that the batch has the required structure.

    Returns:
        True if valid, False if empty or has no columns (logs warning).
    """
Creating DocumentBatch
import pandas as pd
from nemo_curator.tasks import DocumentBatch
# Describe the example documents column by column; the mapping's key order
# becomes the DataFrame's column order.
columns = {
    "text": ["Document 1 content...", "Document 2 content..."],
    "id": ["doc_001", "doc_002"],
    "url": ["https://example.com/1", "https://example.com/2"],
}
df = pd.DataFrame(columns)

# Wrap the DataFrame in a DocumentBatch, identifying the batch and the
# dataset it belongs to.
batch = DocumentBatch(task_id="batch_001", dataset_name="my_dataset", data=df)

# num_items reports how many documents the batch holds.
print(f"Batch contains {batch.num_items} documents")
Usage in Stages
from dataclasses import dataclass
from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch
@dataclass
class TextFilterStage(ProcessingStage[DocumentBatch, DocumentBatch]):
    """Filter documents based on text length.

    Keeps only the rows whose ``text`` value is at least ``min_length``
    characters long.

    Attributes:
        name: Name of this stage.
        min_length: Minimum number of characters a document's text must
            have for the document to be kept.
    """

    name: str = "TextFilter"
    min_length: int = 100

    def inputs(self) -> tuple[list[str], list[str]]:
        """Return the stage's declared inputs: ``["data"]`` and ``["text"]``.

        NOTE(review): presumably (required task attributes, required data
        columns) -- confirm against ProcessingStage.
        """
        return ["data"], ["text"]

    def outputs(self) -> tuple[list[str], list[str]]:
        """Return the stage's declared outputs: ``["data"]`` and ``["text"]``.

        NOTE(review): presumably (produced task attributes, produced data
        columns) -- confirm against ProcessingStage.
        """
        return ["data"], ["text"]

    def process(self, task: DocumentBatch) -> DocumentBatch | None:
        """Drop documents whose text is shorter than ``min_length``.

        Args:
            task: Batch whose data has a ``text`` column.

        Returns:
            A new DocumentBatch holding only the surviving rows, or None
            when no row survives.
        """
        # DocumentBatch.data may be a PyArrow Table, which has no ``.str``
        # accessor, boolean-mask indexing or ``.empty``; normalise to pandas.
        df = task.to_pandas()

        # Filter by text length
        mask = df["text"].str.len() >= self.min_length
        filtered_df = df[mask]

        if filtered_df.empty:
            return None

        return DocumentBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=filtered_df,
            # Forward the incoming task's bookkeeping fields unchanged.
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
Common Patterns
Adding Columns
def process(self, task: DocumentBatch) -> DocumentBatch:
    """Add a ``word_count`` column derived from the ``text`` column.

    ``word_count`` is the number of whitespace-separated tokens in each
    document's text.

    Args:
        task: Batch whose data has a ``text`` column.

    Returns:
        A new DocumentBatch containing the original columns plus
        ``word_count``.
    """
    # DocumentBatch.data may be a PyArrow Table, which is immutable and does
    # not support pandas-style column assignment, so normalise to pandas
    # first. to_pandas() may hand back the task's own DataFrame, so copy
    # before mutating to leave the incoming task untouched.
    df = task.to_pandas().copy()
    df["word_count"] = df["text"].str.split().str.len()

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=df,
        # Forward the incoming task's bookkeeping fields unchanged.
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
Filtering Rows
def process(self, task: DocumentBatch) -> DocumentBatch | None:
    """Keep only the rows whose ``score`` exceeds ``self.threshold``.

    Args:
        task: Batch whose data has a ``score`` column.

    Returns:
        A new DocumentBatch with the surviving rows, or None when no row
        survives.
    """
    # DocumentBatch.data may be a PyArrow Table, which does not support the
    # pandas comparison / boolean-mask indexing used below; normalise first.
    df = task.to_pandas()
    filtered = df[df["score"] > self.threshold]

    if filtered.empty:
        return None  # Filter out entire batch

    return DocumentBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=filtered,
        # Forward the incoming task's bookkeeping fields unchanged.
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )