mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
1648 lines
51 KiB
HTML
1648 lines
51 KiB
HTML
|
|
<!doctype html>
|
|
<html lang="en" class="no-js">
|
|
<head>
|
|
|
|
<meta charset="utf-8">
|
|
<meta name="viewport" content="width=device-width,initial-scale=1">
|
|
|
|
|
|
|
|
<link rel="canonical" href="https://docs.tinygrad.org/mnist/">
|
|
|
|
|
|
<link rel="prev" href="../showcase/">
|
|
|
|
|
|
<link rel="next" href="../tensor/">
|
|
|
|
|
|
|
|
|
|
|
|
<link rel="icon" href="../favicon.svg">
|
|
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.1">
|
|
|
|
|
|
|
|
<title>MNIST Tutorial - tinygrad docs</title>
|
|
|
|
|
|
|
|
<link rel="stylesheet" href="../assets/stylesheets/main.484c7ddc.min.css">
|
|
|
|
|
|
<link rel="stylesheet" href="../assets/stylesheets/palette.ab4e12ef.min.css">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
|
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
|
|
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
|
|
|
|
|
|
|
|
<link rel="stylesheet" href="../assets/_markdown_exec_pyodide.css">
|
|
|
|
<link rel="stylesheet" href="../assets/_markdown_exec_ansi.css">
|
|
|
|
<link rel="stylesheet" href="../assets/_mkdocstrings.css">
|
|
|
|
<script>
  // Storage helpers shared by the theme's inline scripts. They are assigned
  // without a declaration on purpose so they become globals that later
  // scripts on the page can call.

  // URL resolved one level up from the current page; its pathname is used
  // below as the prefix that namespaces every storage key.
  __md_scope = new URL("..", location);

  // 32-bit rolling string hash: hash = hash * 31 + charCode, written with a
  // shift so the intermediate value is truncated to an int32 on every step.
  __md_hash = (text) => {
    let hash = 0;
    for (const ch of text) {
      hash = (hash << 5) - hash + ch.charCodeAt(0);
    }
    return hash;
  };

  // Read and JSON-decode the value stored under "<scope pathname>.<key>".
  // JSON.parse(null) yields null, so a missing entry comes back as null.
  __md_get = (key, storage = localStorage, scope = __md_scope) => {
    const raw = storage.getItem(scope.pathname + "." + key);
    return JSON.parse(raw);
  };

  // JSON-encode and store a value under "<scope pathname>.<key>". Any error
  // raised while serialising or writing is deliberately swallowed.
  __md_set = (key, value, storage = localStorage, scope = __md_scope) => {
    try {
      storage.setItem(scope.pathname + "." + key, JSON.stringify(value));
    } catch (err) {
      // Best effort only: persisting the preference is optional.
    }
  };
</script>
|
|
|
|
|
|
|
|
|
|
|
|
</head>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime">
|
|
|
|
|
|
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
|
|
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
|
|
<label class="md-overlay" for="__drawer"></label>
|
|
<div data-md-component="skip">
|
|
|
|
|
|
<a href="#mnist-tutorial" class="md-skip">
|
|
Skip to content
|
|
</a>
|
|
|
|
</div>
|
|
<div data-md-component="announce">
|
|
|
|
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<header class="md-header md-header--shadow" data-md-component="header">
|
|
<nav class="md-header__inner md-grid" aria-label="Header">
|
|
<a href=".." title="tinygrad docs" class="md-header__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
|
|
|
<img src="../logo_tiny_dark.svg" alt="logo">
|
|
|
|
</a>
|
|
<label class="md-header__button md-icon" for="__drawer">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
|
|
</label>
|
|
<div class="md-header__title" data-md-component="header-title">
|
|
<div class="md-header__ellipsis">
|
|
<div class="md-header__topic">
|
|
<span class="md-ellipsis">
|
|
tinygrad docs
|
|
</span>
|
|
</div>
|
|
<div class="md-header__topic" data-md-component="header-topic">
|
|
<span class="md-ellipsis">
|
|
|
|
MNIST Tutorial
|
|
|
|
</span>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
|
|
<form class="md-header__option" data-md-component="palette">
|
|
|
|
|
|
|
|
|
|
<input class="md-option" data-md-color-media="(prefers-color-scheme)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to light mode" type="radio" name="__palette" id="__palette_0">
|
|
|
|
<label class="md-header__button md-icon" title="Switch to light mode" for="__palette_1" hidden>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m14.3 16-.7-2h-3.2l-.7 2H7.8L11 7h2l3.2 9zM20 8.69V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12zm-9.15 3.96h2.3L12 9z"/></svg>
|
|
</label>
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-option" data-md-color-media="(prefers-color-scheme: light)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to dark mode" type="radio" name="__palette" id="__palette_1">
|
|
|
|
<label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_2" hidden>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
|
</label>
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-option" data-md-color-media="(prefers-color-scheme: dark)" data-md-color-scheme="slate" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to system preference" type="radio" name="__palette" id="__palette_2">
|
|
|
|
<label class="md-header__button md-icon" title="Switch to system preference" for="__palette_0" hidden>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
|
</label>
|
|
|
|
|
|
</form>
|
|
|
|
|
|
|
|
<script>
  // Re-apply the colour palette saved in storage before the page paints.
  // The top-level `var` declarations are kept (rather than let/const) so the
  // same global bindings exist as in the minified original.
  var palette = __md_get("__palette");

  if (palette && palette.color) {
    // "(prefers-color-scheme)" means "follow the system": resolve it to the
    // concrete light or dark option by copying that option's attributes.
    if (palette.color.media === "(prefers-color-scheme)") {
      var media = matchMedia("(prefers-color-scheme: light)");

      let selector;
      if (media.matches) {
        selector = "[data-md-color-media='(prefers-color-scheme: light)']";
      } else {
        selector = "[data-md-color-media='(prefers-color-scheme: dark)']";
      }
      var input = document.querySelector(selector);

      palette.color.media = input.getAttribute("data-md-color-media");
      palette.color.scheme = input.getAttribute("data-md-color-scheme");
      palette.color.primary = input.getAttribute("data-md-color-primary");
      palette.color.accent = input.getAttribute("data-md-color-accent");
    }

    // Mirror every palette entry onto <body> as a data-md-color-* attribute.
    for (var [key, value] of Object.entries(palette.color)) {
      document.body.setAttribute("data-md-color-" + key, value);
    }
  }
</script>
|
|
|
|
|
|
|
|
|
|
|
|
<label class="md-header__button md-icon" for="__search">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
|
</label>
|
|
<div class="md-search" data-md-component="search" role="dialog">
|
|
<label class="md-search__overlay" for="__search"></label>
|
|
<div class="md-search__inner" role="search">
|
|
<form class="md-search__form" name="search">
|
|
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
|
|
<label class="md-search__icon md-icon" for="__search">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
|
|
</label>
|
|
<nav class="md-search__options" aria-label="Search">
|
|
|
|
<button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
|
|
</button>
|
|
</nav>
|
|
|
|
<div class="md-search__suggest" data-md-component="search-suggest"></div>
|
|
|
|
</form>
|
|
<div class="md-search__output">
|
|
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
|
|
<div class="md-search-result" data-md-component="search-result">
|
|
<div class="md-search-result__meta">
|
|
Initializing search
|
|
</div>
|
|
<ol class="md-search-result__list" role="presentation"></ol>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
|
|
|
|
<div class="md-header__source">
|
|
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
|
<div class="md-source__icon md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
|
</div>
|
|
<div class="md-source__repository">
|
|
GitHub
|
|
</div>
|
|
</a>
|
|
</div>
|
|
|
|
</nav>
|
|
|
|
</header>
|
|
|
|
<div class="md-container" data-md-component="container">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<main class="md-main" data-md-component="main">
|
|
<div class="md-main__inner md-grid">
|
|
|
|
|
|
|
|
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
|
|
<div class="md-sidebar__scrollwrap">
|
|
<div class="md-sidebar__inner">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<nav class="md-nav md-nav--primary md-nav--integrated" aria-label="Navigation" data-md-level="0">
|
|
<label class="md-nav__title" for="__drawer">
|
|
<a href=".." title="tinygrad docs" class="md-nav__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
|
|
|
<img src="../logo_tiny_dark.svg" alt="logo">
|
|
|
|
</a>
|
|
tinygrad docs
|
|
</label>
|
|
|
|
<div class="md-nav__source">
|
|
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
|
<div class="md-source__icon md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
|
</div>
|
|
<div class="md-source__repository">
|
|
GitHub
|
|
</div>
|
|
</a>
|
|
</div>
|
|
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1" checked>
|
|
|
|
|
|
<div class="md-nav__link md-nav__container">
|
|
<a href=".." class="md-nav__link ">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Home
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
|
|
|
|
<label class="md-nav__link " for="__nav_1" id="__nav_1_label">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
</div>
|
|
|
|
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_1_label" aria-expanded="true">
|
|
<label class="md-nav__title" for="__nav_1">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Home
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../quickstart/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Quickstart
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../showcase/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Showcase
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--active">
|
|
|
|
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
|
|
|
|
|
|
|
|
|
|
|
|
<label class="md-nav__link md-nav__link--active" for="__toc">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
MNIST Tutorial
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<a href="./" class="md-nav__link md-nav__link--active">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
MNIST Tutorial
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
|
|
|
|
|
|
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<label class="md-nav__title" for="__toc">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
Table of contents
|
|
</label>
|
|
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#one-liner-to-install-tinygrad-in-colab" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
One-liner to install tinygrad in colab
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#whats-the-default-device" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
What's the default device?
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#a-simple-model" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
A simple model
|
|
|
|
</span>
|
|
</a>
|
|
|
|
<nav class="md-nav" aria-label="A simple model">
|
|
<ul class="md-nav__list">
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#getting-the-dataset" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Getting the dataset
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#using-the-model" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Using the model
|
|
|
|
</span>
|
|
</a>
|
|
|
|
<nav class="md-nav" aria-label="Using the model">
|
|
<ul class="md-nav__list">
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#training-the-model" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Training the model
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#why-so-slow" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Why so slow?
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#putting-it-together" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Putting it together
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#from-here" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
From here?
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
</ul>
|
|
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5" >
|
|
|
|
|
|
<label class="md-nav__link" for="__nav_1_5" id="__nav_1_5_label" tabindex="0">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
API Reference
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_5_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_5">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
API Reference
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5_1" >
|
|
|
|
|
|
<div class="md-nav__link md-nav__container">
|
|
<a href="../tensor/" class="md-nav__link ">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Tensor
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
|
|
|
|
<label class="md-nav__link " for="__nav_1_5_1" id="__nav_1_5_1_label" tabindex="0">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
</div>
|
|
|
|
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_5_1_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_5_1">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Tensor
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/properties/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Properties
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/creation/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Creation
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/movement/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Movement
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/elementwise/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Elementwise
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/ops/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Complex Ops
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../dtypes/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
dtypes
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../nn/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
nn (Neural Networks)
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../env_vars/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Environment Variables
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../runtime/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Runtime
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6" >
|
|
|
|
|
|
<label class="md-nav__link" for="__nav_1_6" id="__nav_1_6_label" tabindex="0">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Developer
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_6_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_6">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Developer
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/developer/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Intro
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/layout/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Layout
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/speed/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Speed
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/uop/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
UOp
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6_5" >
|
|
|
|
|
|
<label class="md-nav__link" for="__nav_1_6_5" id="__nav_1_6_5_label" tabindex="0">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Runtime
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_6_5_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_6_5">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Runtime
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/runtime/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Runtime Overview
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/hcq/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
HCQ
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/am/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
AM Driver
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tinybox/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
tinybox
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
|
|
|
|
|
|
<div class="md-content" data-md-component="content">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<article class="md-content__inner md-typeset">
|
|
|
|
|
|
|
|
|
|
|
|
<a href="https://github.com/tinygrad/tinygrad/edit/master/docs/mnist.md" title="Edit this page" class="md-content__button md-icon" rel="edit">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M10 20H6V4h7v5h5v3.1l2-2V8l-6-6H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h4zm10.2-7c.1 0 .3.1.4.2l1.3 1.3c.2.2.2.6 0 .8l-1 1-2.1-2.1 1-1c.1-.1.2-.2.4-.2m0 3.9L14.1 23H12v-2.1l6.1-6.1z"/></svg>
|
|
</a>
|
|
|
|
|
|
|
|
|
|
|
|
<a href="https://github.com/tinygrad/tinygrad/raw/master/docs/mnist.md" title="View source of this page" class="md-content__button md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M17 18c.56 0 1 .44 1 1s-.44 1-1 1-1-.44-1-1 .44-1 1-1m0-3c-2.73 0-5.06 1.66-6 4 .94 2.34 3.27 4 6 4s5.06-1.66 6-4c-.94-2.34-3.27-4-6-4m0 6.5a2.5 2.5 0 0 1-2.5-2.5 2.5 2.5 0 0 1 2.5-2.5 2.5 2.5 0 0 1 2.5 2.5 2.5 2.5 0 0 1-2.5 2.5M9.27 20H6V4h7v5h5v4.07c.7.08 1.36.25 2 .49V8l-6-6H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h4.5a8.2 8.2 0 0 1-1.23-2"/></svg>
|
|
</a>
|
|
|
|
|
|
|
|
<h1 id="mnist-tutorial">MNIST Tutorial<a class="headerlink" href="#mnist-tutorial" title="Permanent link">¤</a></h1>
|
|
<p>After you have installed tinygrad, this is a great first tutorial.</p>
|
|
<p>Start up a notebook locally, or use <a href="https://colab.research.google.com/">colab</a>. tinygrad is very lightweight, so it's easy to install anywhere and doesn't need a special colab image, but for speed we recommend a T4 GPU image.</p>
|
|
<h3 id="one-liner-to-install-tinygrad-in-colab">One-liner to install tinygrad in colab<a class="headerlink" href="#one-liner-to-install-tinygrad-in-colab" title="Permanent link">¤</a></h3>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="n">git</span><span class="o">+</span><span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">tinygrad</span><span class="o">/</span><span class="n">tinygrad</span><span class="o">.</span><span class="n">git</span>
|
|
</code></pre></div>
|
|
<h3 id="whats-the-default-device">What's the default device?<a class="headerlink" href="#whats-the-default-device" title="Permanent link">¤</a></h3>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">Device</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="n">Device</span><span class="o">.</span><span class="n">DEFAULT</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>You will see <code class="language-python highlight"><span class="n">CUDA</span></code> here on a GPU instance, or <code class="language-python highlight"><span class="n">CPU</span></code> here on a CPU instance.</p>
|
|
<h2 id="a-simple-model">A simple model<a class="headerlink" href="#a-simple-model" title="Permanent link">¤</a></h2>
|
|
<p>We'll use the model from <a href="https://keras.io/examples/vision/mnist_convnet/">the Keras tutorial</a>.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">nn</span>
|
|
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">Model</span><span class="p">:</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">l3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">1600</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
|
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">relu</span><span class="p">()</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span>
|
|
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">relu</span><span class="p">()</span><span class="o">.</span><span class="n">max_pool2d</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">l3</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">))</span>
|
|
</code></pre></div>
|
|
<p>Two key differences from PyTorch:</p>
|
|
<ul>
|
|
<li>Only the stateful layers are declared in <code class="language-python highlight"><span class="fm">__init__</span></code></li>
|
|
<li>There's no <code class="language-python highlight"><span class="n">nn</span><span class="o">.</span><span class="n">Module</span></code> class or <code class="language-python highlight"><span class="n">forward</span></code> function, just a normal class and <code class="language-python highlight"><span class="fm">__call__</span></code></li>
|
|
</ul>
|
|
<h3 id="getting-the-dataset">Getting the dataset<a class="headerlink" href="#getting-the-dataset" title="Permanent link">¤</a></h3>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.nn.datasets</span><span class="w"> </span><span class="kn">import</span> <span class="n">mnist</span>
|
|
<span class="n">X_train</span><span class="p">,</span> <span class="n">Y_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">Y_test</span> <span class="o">=</span> <span class="n">mnist</span><span class="p">()</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="n">X_train</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">X_train</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">Y_train</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">Y_train</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
|
<span class="c1"># (60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar</span>
|
|
</code></pre></div>
|
|
<p>tinygrad includes MNIST; it only adds four lines. Feel free to read the <a href="https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/datasets.py">function</a>.</p>
|
|
<h2 id="using-the-model">Using the model<a class="headerlink" href="#using-the-model" title="Permanent link">¤</a></h2>
|
|
<p>MNIST is small enough that the <code class="language-python highlight"><span class="n">mnist</span><span class="p">()</span></code> function copies the dataset to the default device.</p>
|
|
<p>So creating the model and evaluating it is a matter of:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span>
|
|
<span class="n">acc</span> <span class="o">=</span> <span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">Y_test</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
|
|
<span class="c1"># NOTE: tinygrad is lazy, and hasn't actually run anything by this point</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="n">acc</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="c1"># ~10% accuracy, as expected from a random model</span>
|
|
</code></pre></div>
|
|
<h3 id="training-the-model">Training the model<a class="headerlink" href="#training-the-model" title="Permanent link">¤</a></h3>
|
|
<p>We'll use the Adam optimizer. The <code class="language-python highlight"><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_parameters</span></code> function will walk the model class and pull out the parameters for the optimizer. Also, in tinygrad, it's typical to write a function to do the training step so it can be jitted.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="n">optim</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_parameters</span><span class="p">(</span><span class="n">model</span><span class="p">))</span>
|
|
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">():</span>
|
|
<span class="n">Tensor</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># makes dropout work</span>
|
|
<span class="n">samples</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">X_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
|
<span class="n">X</span><span class="p">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">X_train</span><span class="p">[</span><span class="n">samples</span><span class="p">],</span> <span class="n">Y_train</span><span class="p">[</span><span class="n">samples</span><span class="p">]</span>
|
|
<span class="n">optim</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
|
<span class="n">loss</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">)</span><span class="o">.</span><span class="n">sparse_categorical_crossentropy</span><span class="p">(</span><span class="n">Y</span><span class="p">)</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
|
<span class="n">optim</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
|
<span class="k">return</span> <span class="n">loss</span>
|
|
</code></pre></div>
|
|
<p>You can time a step with:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">timeit</span>
|
|
<span class="n">timeit</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">repeat</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="c1">#[0.08268719699981375,</span>
|
|
<span class="c1"># 0.07478952900009972,</span>
|
|
<span class="c1"># 0.07714716600003158,</span>
|
|
<span class="c1"># 0.07785399599970333,</span>
|
|
<span class="c1"># 0.07605237000007037]</span>
|
|
</code></pre></div>
|
|
<p>So around 75 ms per step on a T4 Colab instance.</p>
|
|
<p>If you want to see a breakdown of the time by kernel:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">GlobalCounters</span><span class="p">,</span> <span class="n">Context</span>
|
|
<span class="n">GlobalCounters</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
|
|
<span class="k">with</span> <span class="n">Context</span><span class="p">(</span><span class="n">DEBUG</span><span class="o">=</span><span class="mi">2</span><span class="p">):</span> <span class="n">step</span><span class="p">()</span>
|
|
</code></pre></div>
|
|
<h3 id="why-so-slow">Why so slow?<a class="headerlink" href="#why-so-slow" title="Permanent link">¤</a></h3>
|
|
<p>Unlike PyTorch, tinygrad isn't designed to be fast like that. While 75 ms for one step is plenty fast for debugging, it's not great for training. Here, we introduce the first quintessentially tinygrad concept, the <code class="language-python highlight"><span class="n">TinyJit</span></code>.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">TinyJit</span>
|
|
<span class="n">jit_step</span> <span class="o">=</span> <span class="n">TinyJit</span><span class="p">(</span><span class="n">step</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<div class="admonition note">
|
|
<p class="admonition-title">Note</p>
|
|
<p>It can also be used as a decorator <code class="language-python highlight"><span class="nd">@TinyJit</span></code></p>
|
|
</div>
|
|
<p>Now when we time it:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">timeit</span>
|
|
<span class="n">timeit</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">jit_step</span><span class="p">,</span> <span class="n">repeat</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">number</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="c1"># [0.2596786549997887,</span>
|
|
<span class="c1"># 0.08989566299987928,</span>
|
|
<span class="c1"># 0.0012115650001760514,</span>
|
|
<span class="c1"># 0.001010227999813651,</span>
|
|
<span class="c1"># 0.0012164899999334011]</span>
|
|
</code></pre></div>
|
|
<p>1.0 ms is 75x faster! Note that we aren't syncing the GPU, so the actual GPU time may be longer.</p>
|
|
<p>The first two runs of the function execute normally, with the JIT capturing the kernels. Starting from the third run, only the tinygrad operations are replayed, removing the overhead by skipping Python code execution. So be aware that any non-tinygrad Python values affecting the kernels will be "frozen" from the second run. Note that <code class="language-python highlight"><span class="n">Tensor</span></code> randomness functions work as expected.</p>
|
|
<p>Unlike other JITs, we JIT everything, including the optimizer. Think of it as a dumb replay on different data.</p>
|
|
<h2 id="putting-it-together">Putting it together<a class="headerlink" href="#putting-it-together" title="Permanent link">¤</a></h2>
|
|
<p>Since we are just randomly sampling from the dataset, there's no real concept of an epoch. We have a batch size of 128, so the Keras example is taking about 7000 steps.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">7000</span><span class="p">):</span>
|
|
<span class="n">loss</span> <span class="o">=</span> <span class="n">jit_step</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="n">step</span><span class="o">%</span><span class="mi">100</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">Tensor</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">acc</span> <span class="o">=</span> <span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">Y_test</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"step </span><span class="si">{</span><span class="n">step</span><span class="si">:</span><span class="s2">4d</span><span class="si">}</span><span class="s2">, loss </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">, acc </span><span class="si">{</span><span class="n">acc</span><span class="o">*</span><span class="mf">100.</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>It doesn't take long to reach 98%, and it usually reaches 99%.</p>
|
|
<div class="language-text highlight"><pre><span></span><code>step 0, loss 4.03, acc 71.43%
|
|
step 100, loss 0.34, acc 93.86%
|
|
step 200, loss 0.23, acc 95.97%
|
|
step 300, loss 0.18, acc 96.32%
|
|
step 400, loss 0.18, acc 96.76%
|
|
step 500, loss 0.13, acc 97.46%
|
|
step 600, loss 0.14, acc 97.45%
|
|
step 700, loss 0.10, acc 97.27%
|
|
step 800, loss 0.23, acc 97.49%
|
|
step 900, loss 0.13, acc 97.51%
|
|
step 1000, loss 0.13, acc 97.88%
|
|
step 1100, loss 0.11, acc 97.72%
|
|
step 1200, loss 0.14, acc 97.65%
|
|
step 1300, loss 0.12, acc 98.04%
|
|
step 1400, loss 0.25, acc 98.17%
|
|
step 1500, loss 0.11, acc 97.86%
|
|
step 1600, loss 0.21, acc 98.21%
|
|
step 1700, loss 0.14, acc 98.34%
|
|
...
|
|
</code></pre></div>
|
|
<h2 id="from-here">From here?<a class="headerlink" href="#from-here" title="Permanent link">¤</a></h2>
|
|
<p>tinygrad is yours to play with now. It's pure Python and short, so unlike PyTorch, fixing library bugs is well within your abilities.</p>
|
|
<ul>
|
|
<li>It's two lines to add multi-GPU support to this example (can you find them?). You have to <code class="language-python highlight"><span class="o">.</span><span class="n">shard</span></code> the model to all GPUs, and <code class="language-python highlight"><span class="o">.</span><span class="n">shard</span></code> the dataset by batch.</li>
|
|
<li><code class="language-python highlight"><span class="k">with</span> <span class="n">Context</span><span class="p">(</span><span class="n">DEBUG</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></code> shows the running kernels, <code class="language-python highlight"><span class="n">DEBUG</span><span class="o">=</span><span class="mi">4</span></code> shows the code. All <code class="language-python highlight"><span class="n">Context</span></code> variables can also be environment variables.</li>
|
|
<li><code class="language-python highlight"><span class="k">with</span> <span class="n">Context</span><span class="p">(</span><span class="n">BEAM</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span></code> will do a BEAM search on the kernels, searching many possible implementations for what runs the fastest on your hardware. After this search, tinygrad is usually speed competitive with PyTorch, and the results are cached so you won't have to search next time.</li>
|
|
</ul>
|
|
<p><a href="https://discord.gg/ZjZadyC7PK">Join our Discord</a> for help, or if you want to become a tinygrad developer. Please read the Discord rules when you get there.</p>
|
|
<p><a href="https://twitter.com/__tinygrad__">Follow us on Twitter</a> to keep up with the project.</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
</article>
|
|
</div>
|
|
|
|
|
|
<script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
|
|
</div>
|
|
|
|
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
|
|
Back to top
|
|
</button>
|
|
|
|
</main>
|
|
|
|
<footer class="md-footer">
|
|
|
|
|
|
|
|
<nav class="md-footer__inner md-grid" aria-label="Footer" >
|
|
|
|
|
|
<a href="../showcase/" class="md-footer__link md-footer__link--prev" aria-label="Previous: Showcase">
|
|
<div class="md-footer__button md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
|
|
</div>
|
|
<div class="md-footer__title">
|
|
<span class="md-footer__direction">
|
|
Previous
|
|
</span>
|
|
<div class="md-ellipsis">
|
|
Showcase
|
|
</div>
|
|
</div>
|
|
</a>
|
|
|
|
|
|
|
|
<a href="../tensor/" class="md-footer__link md-footer__link--next" aria-label="Next: Tensor">
|
|
<div class="md-footer__title">
|
|
<span class="md-footer__direction">
|
|
Next
|
|
</span>
|
|
<div class="md-ellipsis">
|
|
Tensor
|
|
</div>
|
|
</div>
|
|
<div class="md-footer__button md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
|
|
</div>
|
|
</a>
|
|
|
|
</nav>
|
|
|
|
|
|
<div class="md-footer-meta md-typeset">
|
|
<div class="md-footer-meta__inner md-grid">
|
|
<div class="md-copyright">
|
|
|
|
|
|
Made with
|
|
<a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
|
|
Material for MkDocs
|
|
</a>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
</div>
|
|
</footer>
|
|
|
|
</div>
|
|
<div class="md-dialog" data-md-component="dialog">
|
|
<div class="md-dialog__inner md-typeset"></div>
|
|
</div>
|
|
|
|
|
|
|
|
|
|
|
|
<script id="__config" type="application/json">{"annotate": null, "base": "..", "features": ["announce.dismiss", "content.action.edit", "content.action.view", "content.code.annotate", "content.code.copy", "content.tooltips", "navigation.footer", "navigation.indexes", "navigation.sections", "navigation.expand", "navigation.top", "navigation.path", "search.highlight", "search.suggest", "toc.follow", "toc.integrate"], "search": "../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}, "version": null}</script>
|
|
|
|
|
|
<script src="../assets/javascripts/bundle.79ae519e.min.js"></script>
|
|
|
|
<script src="../assets/_markdown_exec_pyodide.js"></script>
|
|
|
|
|
|
</body>
|
|
</html> |