grain 0.2.11
Grain: A library for loading and transforming data for ML training.
pip install grain
Requires Python: >=3.10
Dependencies
- absl-py
- array-record
- cloudpickle
- dm-tree
- etils[epath,epy]
- more-itertools>=9.1.0
- numpy
- protobuf>=3.20.3
- attrs; extra == "testing"
- dill; extra == "testing"
- jax; extra == "testing"
- jaxlib; extra == "testing"
- jaxtyping; extra == "testing"
- pyarrow; extra == "testing"
- tensorflow-datasets; extra == "testing"
- pyarrow; extra == "parquet"
Grain - Feeding JAX Models
Grain is a Python library for reading and processing data for training and evaluating JAX models. It is flexible, fast and deterministic.
Grain lets you define data processing steps in a simple, declarative way:
import grain

dataset = (
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .shuffle(seed=42)  # Shuffles elements globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)

for batch in dataset:
    ...  # Training step goes here.
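To make the semantics of the three steps concrete, here is a plain-Python analogue using only the standard library. It is an illustration, not Grain's implementation: random.Random stands in for Grain's own shuffling algorithm, so the resulting order will differ from what Grain produces, and the function names (global_shuffle, consecutive_batches, toy_pipeline) are hypothetical. It assumes a final partial batch is kept rather than dropped.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def global_shuffle(elements: Sequence[T], seed: int) -> List[T]:
    """Returns a permutation of all elements that depends only on `seed`."""
    rng = random.Random(seed)  # Private RNG: no global state, same seed -> same order.
    shuffled = list(elements)
    rng.shuffle(shuffled)
    return shuffled


def consecutive_batches(elements: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Groups consecutive elements; keeps a smaller final batch if any remain."""
    for start in range(0, len(elements), batch_size):
        yield list(elements[start:start + batch_size])


def toy_pipeline(source: Sequence[int], seed: int, batch_size: int = 2) -> List[List[int]]:
    """Hypothetical stand-in for the shuffle -> map -> batch chain above."""
    shuffled = global_shuffle(source, seed)  # Analogue of .shuffle(seed=...)
    mapped = [x + 1 for x in shuffled]  # Analogue of .map(lambda x: x + 1)
    return list(consecutive_batches(mapped, batch_size))  # Analogue of .batch(...)


if __name__ == "__main__":
    first = toy_pipeline(list(range(11)), seed=42)
    second = toy_pipeline(list(range(11)), seed=42)
    assert first == second  # Same seed reproduces the same batches.
    for batch in first:
        print(batch)
```

With 11 source elements and a batch size of 2, this yields six batches, the last holding a single element. Using a private, seeded RNG rather than the global one is what keeps reruns reproducible.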
Grain is designed to work with JAX models, but it does not require JAX to run and can be used with other frameworks as well.
Installation
Grain is available on PyPI and can be installed with: pip install grain
Supported platforms
Grain does not directly use the GPU or TPU in its transformations; by default, all processing within Grain runs on the CPU.
| Architecture | Linux | Mac | Windows |
|---|---|---|---|
| x86_64 | yes | no | no |
| aarch64 | yes | yes | n/a |
Quickstart
Existing users
Grain is used by MaxText, Gemma, kauldron and multiple internal Google projects.