Reading ArrayRecord Files#
This tutorial shows how to retrieve records from ArrayRecord files using grain.sources.ArrayRecordDataSource, and how to process and transform the data with Grain.
Install and Load Dependencies#
!pip install grain array_record
import pickle
import grain
import tensorflow_datasets as tfds
from array_record.python import array_record_module
Write a temporary ArrayRecord file#
# Load a public tensorflow dataset.
test_tfds = tfds.data_source("bool_q", split="train")

# Write the dataset into a test array_record file.
example_file_path = "./test.array_record"
writer = array_record_module.ArrayRecordWriter(
    example_file_path, "group_size:1"
)
record_count = 0
for record in test_tfds:
    writer.write(pickle.dumps(record))
    record_count += 1
writer.close()

print(
    f"Number of records written to array_record file {example_file_path} :"
    f" {record_count}"
)
Load Data Source#
example_array_record_data_source = grain.sources.ArrayRecordDataSource(
    example_file_path
)
print(f"Number of records: {len(example_array_record_data_source)}")
print(example_array_record_data_source[0])
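ArrayRecordDataSource supports random access by index, so individual records can be inspected directly. The short sketch below unpickles a few records (they were serialized with pickle when the file was written) and prints their feature keys.
# Illustrative sketch: random access into the data source.
# Each element is the raw pickled bytes written above.
for idx in (0, 1, len(example_array_record_data_source) - 1):
    record = pickle.loads(example_array_record_data_source[idx])
    print(idx, sorted(record.keys()))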
Define a Transformation Function#
# Load a pre-trained tokenizer.
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_pretrained("bert-base-cased")
class ParseAndTokenizeText(grain.transforms.Map):
    """Takes a pickled dict (as bytes), decodes it, applies a tokenizer
    to a specified feature within the dict, and returns the first 10
    tokens from the result.
    """

    def __init__(self, tokenizer, feature_name):
        self._tokenizer = tokenizer
        self._feature_name = feature_name

    def map(self, element: bytes) -> list[str]:
        parsed_element = pickle.loads(element)
        # Only keep the first 10 tokens from the tokenized text for testing.
        return self._tokenizer.encode(
            parsed_element[self._feature_name].decode('utf-8')
        ).tokens[:10]
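Before wiring the transform into a pipeline, you can call it on a single raw record to confirm it behaves as expected. The example below is a quick sketch; the exact tokens printed depend on the dataset split and tokenizer.
# Quick sketch: apply the transform to one raw record and inspect the output.
parse_and_tokenize = ParseAndTokenizeText(tokenizer, "question")
sample_tokens = parse_and_tokenize.map(example_array_record_data_source[0])
print(sample_tokens)  # First 10 tokens of the first record's "question" feature.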
Load and process data via the Dataset API#
# Example using Grain's MapDataset with an ArrayRecord file source.
example_datasets = (
    grain.MapDataset.source(example_array_record_data_source)
    .shuffle(seed=42)
    .map(ParseAndTokenizeText(tokenizer, "question"))
    .batch(batch_size=10)
)
# Output the batch at index 100.
print(example_datasets[100])
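To stream through the pipeline instead of indexing into it, the MapDataset can be converted to an IterDataset and iterated batch by batch. The loop below is a sketch assuming MapDataset.to_iter_dataset(); consult the Grain API reference for your version.
import itertools

# Sketch: iterate over the first two batches of the pipeline.
example_iter_dataset = example_datasets.to_iter_dataset()
for batch in itertools.islice(example_iter_dataset, 2):
    print(batch)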