Loading and transforming HuggingFace datasets

The HuggingFace (HF) platform hosts a wide variety of ML models, datasets, and transformers for the worldwide community. These assets are easy to access through Python packages such as datasets and transformers, both available on PyPI.

In this tutorial you will learn how to use HF datasets and tools with Grain: how to load HF datasets and how to use HF transformers in your Grain pipeline.

Setup#

To run the notebook you need a few packages installed in your environment: grain, numpy, and two HF packages, datasets and transformers. The install command below also upgrades huggingface_hub and fsspec.

!pip install grain
!pip install -U numpy datasets transformers huggingface_hub fsspec
# Python standard library
from pprint import pprint
# Third-party packages: HF datasets, python-dateutil, Grain, NumPy, HF transformers
from datasets import load_dataset
from dateutil.parser import parse
import grain
import numpy as np
from transformers import AutoTokenizer

Loading dataset#

Let’s first load an HF dataset. For simplicity, we use lhoestq/demo1, a minimal dataset whose train and test splits each contain five rows and six columns.

hf_dataset = load_dataset("lhoestq/demo1")
hf_train, hf_test = hf_dataset["train"], hf_dataset["test"]
hf_dataset
/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
DatasetDict({
    train: Dataset({
        features: ['id', 'package_name', 'review', 'date', 'star', 'version_id'],
        num_rows: 5
    })
    test: Dataset({
        features: ['id', 'package_name', 'review', 'date', 'star', 'version_id'],
        num_rows: 5
    })
})

Each sample is a Python dictionary with string or integer data.

hf_train[0]
{'id': '7bd227d9-afc9-11e6-aba1-c4b301cdf627',
 'package_name': 'com.mantz_it.rfanalyzer',
 'review': "Great app! The new version now works on my Bravia Android TV which is great as it's right by my rooftop aerial cable. The scan feature would be useful...any ETA on when this will be available? Also the option to import a list of bookmarks e.g. from a simple properties file would be useful.",
 'date': 'October 12 2016',
 'star': 4,
 'version_id': 1487}

Preprocessing#

Let’s assume that our preprocessing pipeline should convert the string date field to a timestamp and the sample values to NumPy arrays.

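def process_date(sample: dict) -> dict:
  # Parse the date string (e.g. "October 12 2016") into a POSIX timestamp.
  sample["date"] = parse(sample["date"]).timestamp()
  return sample


def process_sample_to_np(sample: dict) -> dict:
  # Convert every value in the sample to a NumPy array.
  for name, value in sample.items():
    sample[name] = np.asarray(value)
  return sample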
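
One caveat worth illustrating: for a string such as "October 12 2016", parse returns a naive datetime, and timestamp() interprets naive values in the machine's local time zone, so the resulting number can differ between machines. The sketch below is a hypothetical variant, process_date_utc, that is not part of the tutorial pipeline. It uses only the standard library, assumes the "%B %d %Y" format seen in this dataset, and pins the result to UTC.

```python
from datetime import datetime, timezone


def process_date_utc(sample: dict) -> dict:
    """Hypothetical variant of process_date that always interprets dates as UTC."""
    # Assumes dates look like "October 12 2016" (full month name, day, year).
    parsed = datetime.strptime(sample["date"], "%B %d %Y")
    # Return a new dict instead of mutating the input sample.
    return {**sample, "date": parsed.replace(tzinfo=timezone.utc).timestamp()}


sample = {"date": "October 12 2016", "star": 4}  # made-up sample
converted = process_date_utc(sample)
print(converted["date"])  # 1476230400.0
```

Returning a new dictionary rather than mutating the input is a design choice of this sketch; it keeps the original sample unchanged if the same record is read again.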

Building a pipeline is as simple as chaining map calls. An HF dataset supports random access, so we can pass it directly to the source method. The resulting object is a grain.MapDataset, which also supports random access.

dataset = (
    grain.MapDataset.source(hf_train)
    .shuffle(seed=42)  # shuffles globally
    .map(process_date)  # maps each element
    .map(process_sample_to_np)  # maps each element
)
list(dataset)
[{'id': array('7bd22aba-afc9-11e6-8293-c4b301cdf627', dtype='<U36'),
  'package_name': array('com.mantz_it.rfanalyzer', dtype='<U23'),
  'review': array('Works well with my Hackrf Hopefully new updates will arrive for extra functions',
        dtype='<U79'),
  'date': array(1.4691456e+09),
  'star': array(5),
  'version_id': array(1487)},
 {'id': array('7bd227d9-afc9-11e6-aba1-c4b301cdf627', dtype='<U36'),
  'package_name': array('com.mantz_it.rfanalyzer', dtype='<U23'),
  'review': array("Great app! The new version now works on my Bravia Android TV which is great as it's right by my rooftop aerial cable. The scan feature would be useful...any ETA on when this will be available? Also the option to import a list of bookmarks e.g. from a simple properties file would be useful.",
        dtype='<U290'),
  'date': array(1.4762304e+09),
  'star': array(4),
  'version_id': array(1487)},
 {'id': array('7bd22905-afc9-11e6-a5dc-c4b301cdf627', dtype='<U36'),
  'package_name': array('com.mantz_it.rfanalyzer', dtype='<U23'),
  'review': array("Great It's not fully optimised and has some issues with crashing but still a nice app  especially considering the price and it's open source.",
        dtype='<U141'),
  'date': array(1.4719104e+09),
  'star': array(4),
  'version_id': array(1487)},
 {'id': array('7bd22a26-afc9-11e6-9309-c4b301cdf627', dtype='<U36'),
  'package_name': array('com.mantz_it.rfanalyzer', dtype='<U23'),
  'review': array('The bandwidth seemed to be limited to maximum 2 MHz or so. I tried to increase the bandwidth but not possible. I purchased this is because one of the pictures in the advertisement showed the 2.4GHz band with around 10MHz or more bandwidth. Is it not possible to increase the bandwidth? If not  it is just the same performance as other free APPs.',
        dtype='<U345'),
  'date': array(1.4694048e+09),
  'star': array(3),
  'version_id': array(1487)},
 {'id': array('7bd2299c-afc9-11e6-85d6-c4b301cdf627', dtype='<U36'),
  'package_name': array('com.mantz_it.rfanalyzer', dtype='<U23'),
  'review': array("Works on a Nexus 6p I'm still messing around with my hackrf but it works with my Nexus 6p  Trond usb-c to usb host adapter. Thanks!",
        dtype='<U131'),
  'date': array(1.4702688e+09),
  'star': array(5),
  'version_id': array(1487)}]
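
The array('...', dtype='<U36') and array(5) entries in the output above are zero-dimensional NumPy arrays: np.asarray wraps a Python string or number without adding an axis. The short check below uses made-up values (not rows from the dataset) to show the shapes involved and how stacking such values produces a batch axis.

```python
import numpy as np

text = np.asarray("com.example.app")  # made-up package name
stars = np.asarray(4)

# Both results are zero-dimensional: the shape is the empty tuple.
print(text.shape, text.dtype.kind)  # () U
print(stars.shape, stars.item())    # () 4

# Stacking zero-dimensional arrays adds the leading (batch) axis.
batch = np.stack([np.asarray(5), np.asarray(4)])
print(batch.shape)  # (2,)
```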

Tokenizer#

Next we would like to tokenize the review field. Language models operate on integers (token IDs) rather than raw strings. The generic AutoTokenizer class provides the from_pretrained method, which gives access to tokenizers for the models hosted on HF services.

Let’s use the tokenizer of bert-base-uncased, the uncased (case-insensitive) BERT base model.

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

Tokenizing a single review string yields a dictionary with three keys. We’re only interested in input_ids, since that is the encoded review.

review = hf_train[0]["review"]
pprint(review)
print("\n", tokenizer(review).keys(), "\n")
pprint(np.asarray(tokenizer(review)["input_ids"]))
('Great app! The new version now works on my Bravia Android TV which is great '
 "as it's right by my rooftop aerial cable. The scan feature would be "
 'useful...any ETA on when this will be available? Also the option to import a '
 'list of bookmarks e.g. from a simple properties file would be useful.')

 dict_keys(['input_ids', 'token_type_ids', 'attention_mask']) 

array([  101,  2307, 10439,   999,  1996,  2047,  2544,  2085,  2573,
        2006,  2026, 11655,  9035, 11924,  2694,  2029,  2003,  2307,
        2004,  2009,  1005,  1055,  2157,  2011,  2026, 23308,  9682,
        5830,  1012,  1996, 13594,  3444,  2052,  2022,  6179,  1012,
        1012,  1012,  2151, 27859,  2006,  2043,  2023,  2097,  2022,
        2800,  1029,  2036,  1996,  5724,  2000, 12324,  1037,  2862,
        1997,  2338, 27373,  1041,  1012,  1043,  1012,  2013,  1037,
        3722,  5144,  5371,  2052,  2022,  6179,  1012,   102])

Plugging in the selected tokenizer is as easy as before. We implement the process_transformer function and pass it to the map method.

Note that the tokenized reviews have different lengths, and accelerators such as GPUs and TPUs typically require static rectangular batch shapes. For simplicity, in this tutorial we truncate or pad them to the same length before batching. For advanced use cases, please take a look at our example packing implementations, first-fit and concat-and-split, which minimize padding or avoid it altogether.

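target_length = 70

def process_transformer(sample: dict) -> dict:
  # Tokenize the review and keep at most target_length token IDs.
  tokenized = tokenizer(sample["review"])["input_ids"][:target_length]
  # Right-pad with zeros (the ID of the [PAD] token) up to target_length.
  sample["review"] = np.pad(
      tokenized, pad_width=(0, target_length - len(tokenized)))
  return sample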


dataset = (
    grain.MapDataset.source(hf_train)
    .shuffle(seed=42)
    .map(process_date)
    .map(process_transformer)
)

Now the samples are less human-friendly but more machine-friendly.

dataset[2]
{'id': '7bd2299c-afc9-11e6-85d6-c4b301cdf627',
 'package_name': 'com.mantz_it.rfanalyzer',
 'review': array([  101,  2573,  2006,  1037, 26041,  1020,  2361,  1045,  1005,
         1049,  2145, 22308,  2105,  2007,  2026, 20578, 12881,  2021,
         2009,  2573,  2007,  2026, 26041,  1020,  2361, 19817, 15422,
        18833,  1011,  1039,  2000, 18833,  3677, 15581,  2121,  1012,
         4283,   999,   102,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0]),
 'date': 1470268800.0,
 'star': 5,
 'version_id': 1487}
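
The truncate-and-pad step inside process_transformer can also be checked in isolation. The sketch below is a hypothetical helper, pad_or_truncate, which is not part of the tutorial pipeline. It assumes the padding ID is 0, matching the [PAD] entry in the tokenizer output shown earlier, and additionally returns a mask marking real tokens, similar in spirit to the attention_mask key the tokenizer produces.

```python
import numpy as np

PAD_ID = 0  # assumption: [PAD] has ID 0, as in the tokenizer output shown earlier


def pad_or_truncate(token_ids, target_length, pad_id=PAD_ID):
    """Returns (ids, mask), each a 1-D array of length target_length."""
    ids = np.asarray(token_ids[:target_length], dtype=np.int64)
    padding = target_length - len(ids)
    padded = np.pad(ids, pad_width=(0, padding), constant_values=pad_id)
    # The mask is 1 for real tokens and 0 for padding positions.
    mask = np.pad(np.ones(len(ids), dtype=np.int64), pad_width=(0, padding))
    return padded, mask


short_ids, short_mask = pad_or_truncate([101, 2573, 102], target_length=5)
print(short_ids.tolist())   # [101, 2573, 102, 0, 0]
print(short_mask.tolist())  # [1, 1, 1, 0, 0]

long_ids, long_mask = pad_or_truncate([101, 2307, 10439, 999, 102], target_length=3)
print(long_ids.tolist())   # [101, 2307, 10439]
print(long_mask.tolist())  # [1, 1, 1]
```

As the second call shows, plain slicing can cut off the closing [SEP] token (ID 102) of a long review. Whether that matters depends on the downstream model.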

Complete Pipeline#

Time to build our final pipeline! The pipeline doesn’t need to be restricted to shuffle and map. Grain has a rich API that offers further transformations such as filter, random_map, and repeat. Check out the Grain API page to learn more.

On top of the tokenizer, we want to discard reviews that are rated three stars or fewer. It’s important to note that filtering changes the number of samples in the following steps, so random access is no longer available. To perform batching as the final step, we add .to_iter_dataset(), which converts the MapDataset to an IterDataset, a dataset that gives us an iterator-like interface.

dataset = (
    grain.MapDataset.source(hf_train)
    .shuffle(seed=42)
    .filter(lambda x: x["star"] > 3)  # filters samples
    .map(process_date)
    .map(process_transformer)
    .map(process_sample_to_np)
    .to_iter_dataset()
    .batch(batch_size=2)  # batches consecutive elements
)

With IterDataset we can use the Python built-ins iter and next to interact with the dataset.

ds_iter = iter(dataset)
next(ds_iter)
{'date': array([1.4691456e+09, 1.4719104e+09]),
 'id': array(['7bd22aba-afc9-11e6-8293-c4b301cdf627',
        '7bd22905-afc9-11e6-a5dc-c4b301cdf627'], dtype='<U36'),
 'package_name': array(['com.mantz_it.rfanalyzer', 'com.mantz_it.rfanalyzer'], dtype='<U23'),
 'review': array([[  101,  2573,  2092,  2007,  2026, 20578, 12881, 11504,  2047,
         14409,  2097,  7180,  2005,  4469,  4972,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0],
        [  101,  2307,  2009,  1005,  1055,  2025,  3929, 23569, 27605,
          6924,  1998,  2038,  2070,  3314,  2007, 12894,  2021,  2145,
          1037,  3835, 10439,  2926,  6195,  1996,  3976,  1998,  2009,
          1005,  1055,  2330,  3120,  1012,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]]),
 'star': array([5, 4]),
 'version_id': array([1487, 1487])}

And that’s it! We ended up with a batch containing processed dates, tokenized reviews, and filtered ratings.
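
As a rough mental model of the filter and batch steps in the pipeline above, the plain-Python sketch below streams made-up records through a lazy filter and groups the survivors into batches. It illustrates the idea only and is not how Grain implements these transformations. The helper name batched and the decision to keep a short final batch are assumptions of this sketch.

```python
from itertools import islice


def batched(iterable, batch_size):
    """Yields lists of up to batch_size consecutive items (hypothetical helper)."""
    iterator = iter(iterable)
    # islice pulls the next batch_size items; an empty list ends the loop.
    while batch := list(islice(iterator, batch_size)):
        yield batch


records = [{"star": s} for s in (5, 4, 4, 3, 5)]  # made-up ratings
kept = (r for r in records if r["star"] > 3)      # lazy filter, like .filter(...)
batches = [[r["star"] for r in batch] for batch in batched(kept, 2)]
print(batches)  # [[5, 4], [4, 5]]
```

Because the filter decides per element whether it survives, the position of a record in the output is only known after everything before it has been consumed, which is why an iterator-style interface is a natural fit for batching here.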