
tinygrad: For something between PyTorch and karpathy/micrograd. Maintained by tiny corp.

Homepage | Documentation | Discord



This may not be the best deep learning framework, but it is a deep learning framework.

Due to its extreme simplicity, it aims to be the easiest framework to add new accelerators to, with support for both inference and training. If XLA is CISC, tinygrad is RISC.

tinygrad is still alpha software, but we raised some money to make it good. Someday, we will tape out chips.

Features

LLaMA and Stable Diffusion

tinygrad can run LLaMA and Stable Diffusion!

Laziness

Try a matmul. See how, despite the style, it is fused into one kernel with the power of laziness.

DEBUG=3 python3 -c "from tinygrad import Tensor;
N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N);
c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2);
print((c.numpy() - (a.numpy() @ b.numpy())).mean())"

And we can change DEBUG to 4 to see the generated code.
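To see why the broadcasted expression in the one-liner is a matmul at all, here is a pure-Python sketch on nested lists. It does not use tinygrad and says nothing about how kernels are fused; the helper names `matmul_reference` and `matmul_broadcast` are made up for this illustration.

```python
def matmul_reference(a, b):
    # Textbook definition: c[i][j] = sum over p of a[i][p] * b[p][j]
    inner = len(b)
    return [[sum(a[i][p] * b[p][j] for p in range(inner)) for j in range(len(b[0]))]
            for i in range(len(a))]

def matmul_broadcast(a, b):
    # Mirrors (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2)
    bt = [list(col) for col in zip(*b)]  # bt[j][p] == b[p][j]
    # Broadcasted product: prod[i][j][p] = a[i][p] * bt[j][p]
    prod = [[[a_row[p] * bt_row[p] for p in range(len(a_row))] for bt_row in bt]
            for a_row in a]
    # Reduce over the last axis (axis=2)
    return [[sum(cell) for cell in row] for row in prod]

a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(matmul_broadcast(a, b) == matmul_reference(a, b))
```

The intermediate `prod` is an N x N x N object. The point of the laziness example is that tinygrad never has to materialize it.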

Neural networks

As it turns out, 90% of what you need for neural networks is a decent autograd/tensor library. Throw in an optimizer, a data loader, and some compute, and you have all you need.

from tinygrad import Tensor, nn

class LinearNet:
  def __init__(self):
    self.l1 = Tensor.kaiming_uniform(784, 128)
    self.l2 = Tensor.kaiming_uniform(128, 10)
  def __call__(self, x:Tensor) -> Tensor:
    return x.flatten(1).dot(self.l1).relu().dot(self.l2)

model = LinearNet()
optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)

x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7])  # replace with real mnist dataloader

with Tensor.train():
  for i in range(10):
    optim.zero_grad()
    loss = model(x).sparse_categorical_crossentropy(y).backward()
    optim.step()
    print(i, loss.item())

See examples/beautiful_mnist.py for the full version that gets 98% in ~5 seconds.
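The loss in the loop above comes from sparse_categorical_crossentropy. As a rough pure-Python sketch of what such a loss computes, assuming the conventional definition (the mean over the batch of the negative log-softmax at each integer label), and not modeling any extra options the real method may accept:

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]

def sparse_categorical_crossentropy(batch_logits, labels):
    # Assumption: mean reduction over the batch
    losses = [-log_softmax(row)[y] for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

# Four samples, ten classes, all logits equal: the model is maximally unsure
uniform = [[0.0] * 10 for _ in range(4)]
loss = sparse_categorical_crossentropy(uniform, [2, 4, 3, 7])
print(loss, math.log(10))
```

With uniform logits every class gets probability 1/10, so the loss equals log(10). That gives a baseline for reading the first numbers a 10-class training loop prints.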

Accelerators

tinygrad already supports numerous accelerators, and it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.

To check the default accelerator, run: python3 -c "from tinygrad import Device; print(Device.DEFAULT)"

Installation

The current recommended way to install tinygrad is from source.

From source

git clone https://github.com/tinygrad/tinygrad.git
cd tinygrad
python3 -m pip install -e .

Direct (master)

python3 -m pip install git+https://github.com/tinygrad/tinygrad.git

Documentation

Documentation along with a quick start guide can be found on the docs website built from the docs/ directory.

Quick example comparing to PyTorch

from tinygrad import Tensor

x = Tensor.eye(3, requires_grad=True)
y = Tensor([[2.0,0,-2.0]], requires_grad=True)
z = y.matmul(x).sum()
z.backward()

print(x.grad.tolist())  # dz/dx
print(y.grad.tolist())  # dz/dy

The same thing but in PyTorch:

import torch

x = torch.eye(3, requires_grad=True)
y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
z = y.matmul(x).sum()
z.backward()

print(x.grad.tolist())  # dz/dx
print(y.grad.tolist())  # dz/dy
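Both snippets should report the same gradients. For z = (y @ x).sum(), the analytic result is dz/dx[i][j] = y[0][i] and dz/dy[0][i] = sum over j of x[i][j]. The following sketch checks this with central finite differences in pure Python, independent of either framework. The helpers `z` and `numeric_grad` are hypothetical names for this illustration.

```python
def z(x, y):
    # z = (y @ x).sum() for y of shape (1, n) and x of shape (n, n)
    n = len(x)
    return sum(y[0][i] * x[i][j] for i in range(n) for j in range(n))

def numeric_grad(f, mat, eps=1e-6):
    # Central differences: perturb one entry of mat in place, re-evaluate f, restore
    grad = [[0.0] * len(row) for row in mat]
    for i, row in enumerate(mat):
        for j in range(len(row)):
            old = row[j]
            row[j] = old + eps
            hi = f()
            row[j] = old - eps
            lo = f()
            row[j] = old
            grad[i][j] = (hi - lo) / (2 * eps)
    return grad

x = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
y = [[2.0, 0.0, -2.0]]
dz_dx = [[round(g, 6) for g in row] for row in numeric_grad(lambda: z(x, y), x)]
dz_dy = [[round(g, 6) for g in row] for row in numeric_grad(lambda: z(x, y), y)]
print(dz_dx)
print(dz_dy)
```

Each row i of dz/dx is filled with y[0][i], and every entry of dz/dy is 1 because each row of the identity matrix sums to 1.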

Contributing

There has been a lot of interest in tinygrad lately. Following these guidelines will help your PR get accepted.

We'll start with what will get your PR closed with a pointer to this section:

  • No code golf! While low line count is a guiding light of this project, anything that remotely looks like code golf will be closed. The true goal is reducing complexity and increasing readability, and deleting \ns does nothing to help with that.
  • All docs and whitespace changes will be closed unless you are a well-known contributor. The people writing the docs should be those who know the codebase the absolute best. People who have not demonstrated that shouldn't be messing with docs. Whitespace changes are both useless and carry a risk of introducing bugs.
  • Anything you claim is a "speedup" must be benchmarked. In general, the goal is simplicity, so even if your PR makes things marginally faster, you have to consider the tradeoff with maintainability and readability.
  • In general, the code outside the core tinygrad/ folder is not well tested, so unless the current code there is broken, you shouldn't be changing it.
  • If your PR looks "complex", is a big diff, or adds lots of lines, it won't be reviewed or merged. Consider breaking it up into smaller PRs that are individually clear wins. A common pattern I see is prerequisite refactors before adding new functionality. If you can (cleanly) refactor to the point that the feature is a 3 line change, this is great, and something easy for us to review.

Now, what we want:

  • Bug fixes (with a regression test) are great! This library isn't 1.0 yet, so if you stumble upon a bug, fix it, write a test, and submit a PR, this is valuable work.
  • Solving bounties! tinygrad offers cash bounties for certain improvements to the library. All new code should be high quality and well tested.
  • Features. However, if you are adding a feature, consider the line tradeoff. If it's 3 lines, there's less of a bar of usefulness it has to meet over something that's 30 or 300 lines. All features must have regression tests. In general, with no other constraints, your feature's API should match torch or numpy.
  • Refactors that are clear wins. In general, if your refactor isn't a clear win it will be closed. But some refactors are amazing! Think about readability in a deep core sense. A whitespace change or moving a few functions around is useless, but if you realize that two 100 line functions can actually use the same 110 line function with arguments while also improving readability, this is a big win. Refactors should pass process replay.
  • Tests/fuzzers. If you can add tests that are non-brittle, they are welcome. We have some fuzzers in here too, and there's a plethora of bugs that can be found with them and by improving them. Finding bugs, even writing broken tests (that should pass) with @unittest.expectedFailure is great. This is how we make progress.
  • Dead code removal from core tinygrad/ folder. We don't care about the code in extra, but removing dead code from the core library is great. Less for new people to read and be confused by.
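As an illustration of the Tests/fuzzers point above, here is a minimal sketch of documenting a known bug with @unittest.expectedFailure from the Python standard library. The `clamp` helper and its bug are hypothetical and not taken from tinygrad.

```python
import unittest

def clamp(value, low, high):
    # Hypothetical helper with a deliberate bug: the upper bound is ignored
    return max(value, low)

class TestClamp(unittest.TestCase):
    def test_lower_bound(self):
        self.assertEqual(clamp(-5, 0, 10), 0)

    @unittest.expectedFailure
    def test_upper_bound(self):
        # Asserts the correct behavior; fails today because of the bug above.
        # Once clamp is fixed, this is reported as an unexpected success,
        # which is the cue to remove the decorator.
        self.assertEqual(clamp(50, 0, 10), 10)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestClamp)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful(), len(result.expectedFailures))
```

The run still counts as successful, so the suite stays green while the bug remains recorded in the test file.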

Running tests

You should install the pre-commit hooks with pre-commit install. This will run the linter, mypy, and a subset of the tests on every commit.

For more examples on how to run the full test suite please refer to the CI workflow.

Some examples of running tests locally:

python3 -m pip install -e '.[testing]'  # install extra deps for testing
python3 test/test_ops.py                # just the ops tests
python3 -m pytest test/                 # whole test suite

Process replay tests

Process replay compares your PR's generated kernels against master. If your PR is a refactor or speedup without any expected behavior change, it should include [pr] in the pull request title.
