How to use a Torchvision dataset in the Grain pipeline#


In this tutorial we’ll learn how to process and extend Torchvision datasets with Grain, using the Fashion-MNIST dataset as an example.

Setup#

This tutorial requires the grain, torch, and torchvision packages; the last is used primarily for dataset access. We also want matplotlib in our environment for displaying samples.

%pip install grain
%pip install numpy matplotlib torch torchvision
import grain
import matplotlib.pyplot as plt
import numpy as np
from PIL.Image import Image
import torch
# PyTorch imports
from torchvision import datasets
from torchvision.transforms import Lambda, ToTensor

rng = np.random.default_rng(0)

Loading the dataset with Torchvision#

The Torchvision FashionMNIST class provides access to the dataset. It contains 60,000 training samples, each of which is a Pillow image paired with an integer label.

fashion_mnist = datasets.FashionMNIST(root="data", train=True, download=True)
print(fashion_mnist)
print(fashion_mnist[0])
fashion_mnist[0][0]
Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
(<PIL.Image.Image image mode=L size=28x28 at 0x72BC7848F1F0>, 9)
[Figure: the first sample displayed as a 28×28 grayscale image of an ankle boot]

In this example, we have an ankle boot (label 9). Both the sample and the label require dedicated preprocessing:

  • The sample should be a PyTorch tensor;

  • The label needs to use one-hot encoding, so 9 becomes [0,0,0,0,0,0,0,0,0,1] (a quick illustration follows below).
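
To make the second bullet concrete, here is a standalone sketch of the one-hot encoding (the same scatter_ call appears in the transform code later in this tutorial):

# Turn the integer label 9 into a one-hot vector of length 10.
label = 9
one_hot = torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(label), value=1)
print(one_hot)  # tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])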

Let’s inspect a few more samples first.

nrows, ncols = 4, 8
fig, axs = plt.subplots(nrows, ncols)
for x in range(ncols):
  for y in range(nrows):
    axs[y, x].imshow(fashion_mnist.data[rng.integers(len(fashion_mnist))])
    axs[y, x].set_axis_off()
[Figure: a 4×8 grid of randomly selected Fashion-MNIST samples]

The Torchvision dataset constructor already offers transform and target_transform arguments. Let’s confirm that ToTensor and Lambda do what we expect.

fashion_mnist_2 = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(
            0, torch.tensor(y), value=1
        )
    ),
)
(fashion_mnist_2[0][0].size(), fashion_mnist_2[0][1])
(torch.Size([1, 28, 28]), tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]))

The resulting contents are a PyTorch tensor and a one-hot encoded label, as expected.

Processing the dataset with Grain#

Now let’s move to Grain! Our goal is to perform the same transformations on the dataset as before, but this time using the Grain API. You might have noticed in the Dataset basics tutorial that defining a data source usually requires implementing a class that inherits from grain.sources.RandomAccessDataSource. With PyTorch datasets we don’t need to: the dataset is already compliant with the protocol.
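
As a quick sanity check (a sketch of the protocol requirements rather than an exhaustive test), the Torchvision dataset already supports len() and integer indexing, which is what a random-access source needs:

# A random-access source needs __len__ and __getitem__; the Torchvision
# dataset provides both out of the box.
print(len(fashion_mnist))      # 60000
print(type(fashion_mnist[0]))  # <class 'tuple'>: (PIL image, integer label)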

Next, let’s move to the preprocessing stage. We implement a custom class that inherits from grain.transforms.Map and implements a map method. In this method we convert the image to a PyTorch tensor and build a one-hot encoded label.

To make this rewrite more compelling, let’s use Grain’s capabilities beyond simple sample traversal. Assuming we are interested in shoes only (labels: 5 - sandals, 7 - sneakers, and 9 - ankle boots), we implement a class that inherits from grain.transforms.Filter and uses its filter method to discard unwanted samples.

class ToTensorAndOneHot(grain.transforms.Map):
  to_tensor = ToTensor()

  def map(
      self, element: tuple[Image, int]
  ) -> tuple[torch.Tensor, torch.Tensor]:
    data = self.to_tensor(element[0])
    target = torch.zeros(10, dtype=torch.float).scatter_(
        0, torch.tensor(element[1]), value=1
    )
    return (data, target)


class KeepShoesOnly(grain.transforms.Filter):
  shoes_only = {5, 7, 9}  # Sandal, Sneaker, Ankle boot

  def filter(self, element: tuple[Image, int]) -> bool:
    return element[1] in self.shoes_only
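
Before wiring these into a pipeline, we can exercise both transformations on a single raw sample (an optional sanity check):

# Apply the map and the filter to the first raw sample (an ankle boot, label 9).
sample_tensor, sample_target = ToTensorAndOneHot().map(fashion_mnist[0])
print(sample_tensor.shape, sample_target)        # torch.Size([1, 28, 28]) and a one-hot vector with 1 at index 9
print(KeepShoesOnly().filter(fashion_mnist[0]))  # True, since ankle boots are shoes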

Complete pipeline#

Now we combine all the pieces into a single pipeline. Grain exposes a fluent API that lets us chain method calls, resulting in a straightforward flow from the source to the final batching stage.

dataset = (
    grain.MapDataset.source(fashion_mnist)
    .shuffle(seed=42)
    .filter(KeepShoesOnly())  # leave only shoes
    .map(ToTensorAndOneHot())  # construct tensors and label one-hot encoding
    .to_iter_dataset()
    .batch(  # batches consecutive elements
        batch_size=5,
        batch_fn=lambda ts: tuple(torch.stack(t) for t in zip(*ts)),
    )
)
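
The batch_fn above receives a list of (image, label) tuples and stacks each component into a single tensor. A standalone sketch with dummy tensors (shapes chosen to mirror our data) shows the effect:

# Hypothetical toy input mirroring what batch_fn receives: five (image, label) tuples.
toy_elements = [(torch.zeros(1, 28, 28), torch.zeros(10)) for _ in range(5)]
toy_images, toy_labels = tuple(torch.stack(t) for t in zip(*toy_elements))
print(toy_images.shape, toy_labels.shape)  # torch.Size([5, 1, 28, 28]) torch.Size([5, 10])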

The result is exactly what we expected - both samples and labels are now PyTorch tensors, and they’re batched in groups of five.

iterator = iter(dataset)
batch_0 = next(iterator)
print(type(batch_0[0]), batch_0[0].shape)
print(batch_0[1])
<class 'torch.Tensor'> torch.Size([5, 1, 28, 28])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
nrows, ncols = 4, 8
fig, axs = plt.subplots(nrows, ncols)
for x in range(ncols):
  data, _ = next(iterator)
  for y in range(nrows):
    axs[y, x].imshow(data[y, 0])
    axs[y, x].set_axis_off()
[Figure: a 4×8 grid of batched samples, all shoes (sandals, sneakers, and ankle boots)]

In the last cell we additionally confirm visually that all items are restricted to shoes. The pipeline we’ve built not only reuses Torchvision’s dataset-loading features, but also extends them with filtering and batching capabilities.
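
If you prefer a programmatic check over a visual one, a short assertion on a fresh batch does the job (a sketch reusing the shoe label set from above):

# Decode the one-hot labels of a batch and check that they all belong to shoe classes.
shoe_labels = {5, 7, 9}
_, labels = next(iter(dataset))
assert all(int(label.argmax()) in shoe_labels for label in labels)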