ImageBatch is the task type for image processing in NeMo Curator.
Import
from nemo_curator.tasks import ImageBatch
Class Definition
from dataclasses import dataclass

from nemo_curator.tasks import Task
from nemo_curator.tasks.image import ImageObject


@dataclass
class ImageBatch(Task[list[ImageObject]]):
    """Task containing a batch of images.

    Attributes:
        task_id: Unique identifier for this batch.
        dataset_name: Name of the source dataset.
        data: List of ImageObject instances.
    """

    # NOTE(review): the stage examples also pass ``_metadata`` and
    # ``_stage_perf`` to this constructor; presumably those fields are
    # inherited from ``Task`` — confirm against the Task definition.
    task_id: str
    dataset_name: str
    data: list[ImageObject]
ImageObject
Each image in the batch is represented by an ImageObject:
@dataclass
class ImageObject:
"""Represents a single image with metadata.
Attributes:
path: Path to the image file.
caption: Optional text caption for the image.
metadata: Additional metadata dictionary.
embeddings: Optional embedding vector.
"""
path: str
caption: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
embeddings: np.ndarray | None = None
Properties
num_items
Returns the number of images in the batch.
@property
def num_items(self) -> int:
    """Returns the number of images in this batch.

    This is the length of ``self.data``. The original snippet had only a
    docstring here, so it returned ``None`` despite the ``int`` annotation.
    """
    return len(self.data)
Creating ImageBatch
from nemo_curator.tasks import ImageBatch
from nemo_curator.tasks.image import ImageObject

# (path, caption) pairs for the images that will go into one batch.
samples = [
    ("/data/images/image1.jpg", "A cat sitting on a couch"),
    ("/data/images/image2.jpg", "A dog playing in the park"),
]

# Wrap each pair in an ImageObject. The metadata dict literal is evaluated
# once per image, so no two images share a metadata dict.
images = [
    ImageObject(path=path, caption=caption, metadata={"source": "dataset_a"})
    for path, caption in samples
]

# Bundle the images into a single task.
batch = ImageBatch(
    task_id="img_batch_001",
    dataset_name="image_dataset",
    data=images,
)
Usage in Stages
from dataclasses import dataclass

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import ImageBatch


@dataclass
class ImageFilterStage(ProcessingStage[ImageBatch, ImageBatch]):
    """Drop images whose recorded width or height is below ``min_resolution``.

    Dimensions are read from each image's ``metadata`` under the ``"width"``
    and ``"height"`` keys. A missing key is treated as 0.
    """

    name: str = "ImageFilter"
    min_resolution: int = 256

    def inputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return ["data"], []

    def _meets_min_resolution(self, img) -> bool:
        """Return True when both width and height reach ``min_resolution``."""
        # all() short-circuits in tuple order: width is checked before height.
        return all(
            img.metadata.get(dim, 0) >= self.min_resolution
            for dim in ("width", "height")
        )

    def process(self, task: ImageBatch) -> ImageBatch | None:
        """Return a new batch of the images that pass the resolution check.

        Returns ``None`` instead of an empty batch when no image passes.
        """
        kept = [img for img in task.data if self._meets_min_resolution(img)]
        if not kept:
            return None
        # The new batch keeps the parent's metadata and per-stage perf records.
        return ImageBatch(
            task_id=f"{task.task_id}_filtered",
            dataset_name=task.dataset_name,
            data=kept,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
Common Operations
Adding Embeddings
def process(self, task: ImageBatch) -> ImageBatch:
    """Compute an embedding for every image and return them in a new batch.

    Each ``ImageObject`` in ``task.data`` is modified in place: its
    ``embeddings`` attribute is overwritten. The returned batch:
      - reuses the same ``data`` list;
      - carries over the input task's ``_metadata`` and ``_stage_perf``;
      - has this stage's name appended to its ``task_id``.
    """
    images = task.data
    for image in images:
        # Assumes self.model.encode accepts an image path — confirm against
        # the model actually used by the stage.
        image.embeddings = self.model.encode(image.path)
    return ImageBatch(
        task_id=f"{task.task_id}_{self.name}",
        dataset_name=task.dataset_name,
        data=images,
        _metadata=task._metadata,
        _stage_perf=task._stage_perf,
    )
Filtering by Score
def process(self, task: ImageBatch) -> ImageBatch | None:
    """Keep only images whose ``aesthetic_score`` reaches ``self.threshold``.

    The score is read from each image's ``metadata``. An image with no
    ``"aesthetic_score"`` entry is treated as scoring 0.

    Returns ``None`` when no image passes. Otherwise returns a new batch that
    has this stage's name appended to its ``task_id`` and carries over the
    input task's ``_metadata`` and ``_stage_perf``.
    """
    survivors = [
        image
        for image in task.data
        if image.metadata.get("aesthetic_score", 0) >= self.threshold
    ]
    if survivors:
        return ImageBatch(
            task_id=f"{task.task_id}_{self.name}",
            dataset_name=task.dataset_name,
            data=survivors,
            _metadata=task._metadata,
            _stage_perf=task._stage_perf,
        )
    return None