Files
tinygrad/quickstart/index.html
2026-01-14 22:00:43 +00:00

1746 lines
69 KiB
HTML

<!doctype html>
<html lang="en" class="no-js">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<link rel="canonical" href="https://docs.tinygrad.org/quickstart/">
<link rel="prev" href="..">
<link rel="next" href="../showcase/">
<link rel="icon" href="../favicon.svg">
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.1">
<title>Quickstart - tinygrad docs</title>
<link rel="stylesheet" href="../assets/stylesheets/main.484c7ddc.min.css">
<link rel="stylesheet" href="../assets/stylesheets/palette.ab4e12ef.min.css">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
<link rel="stylesheet" href="../assets/_markdown_exec_pyodide.css">
<link rel="stylesheet" href="../assets/_markdown_exec_ansi.css">
<link rel="stylesheet" href="../assets/_mkdocstrings.css">
<script>
  /* Bootstrap helpers for persisted site preferences. Every name below is
     assigned without a declaration, so each one becomes a global that scripts
     further down the page can use. */

  // URL one level above the current page; its pathname prefixes every storage key.
  __md_scope = new URL("..", location);

  // Folds a string into an integer: acc = (acc << 5) - acc + char code, starting from 0.
  __md_hash = (text) =>
    [...text].reduce((acc, ch) => (acc << 5) - acc + ch.charCodeAt(0), 0);

  // Reads the entry "<scope pathname>.<key>" from the storage and JSON-decodes it.
  __md_get = (key, storage = localStorage, scope = __md_scope) => {
    return JSON.parse(storage.getItem(scope.pathname + "." + key));
  };

  // JSON-encodes the value and writes it to "<scope pathname>.<key>".
  // Any error raised while writing is deliberately swallowed.
  __md_set = (key, value, storage = localStorage, scope = __md_scope) => {
    try {
      storage.setItem(scope.pathname + "." + key, JSON.stringify(value));
    } catch (err) {}
  };
</script>
</head>
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime">
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
<label class="md-overlay" for="__drawer"></label>
<div data-md-component="skip">
<a href="#quick-start-guide" class="md-skip">
Skip to content
</a>
</div>
<div data-md-component="announce">
</div>
<header class="md-header md-header--shadow" data-md-component="header">
<nav class="md-header__inner md-grid" aria-label="Header">
<a href=".." title="tinygrad docs" class="md-header__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
<img src="../logo_tiny_dark.svg" alt="logo">
</a>
<label class="md-header__button md-icon" for="__drawer">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
</label>
<div class="md-header__title" data-md-component="header-title">
<div class="md-header__ellipsis">
<div class="md-header__topic">
<span class="md-ellipsis">
tinygrad docs
</span>
</div>
<div class="md-header__topic" data-md-component="header-topic">
<span class="md-ellipsis">
Quickstart
</span>
</div>
</div>
</div>
<form class="md-header__option" data-md-component="palette">
<input class="md-option" data-md-color-media="(prefers-color-scheme)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to light mode" type="radio" name="__palette" id="__palette_0">
<label class="md-header__button md-icon" title="Switch to light mode" for="__palette_1" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m14.3 16-.7-2h-3.2l-.7 2H7.8L11 7h2l3.2 9zM20 8.69V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12zm-9.15 3.96h2.3L12 9z"/></svg>
</label>
<input class="md-option" data-md-color-media="(prefers-color-scheme: light)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to dark mode" type="radio" name="__palette" id="__palette_1">
<label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_2" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
</label>
<input class="md-option" data-md-color-media="(prefers-color-scheme: dark)" data-md-color-scheme="slate" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to system preference" type="radio" name="__palette" id="__palette_2">
<label class="md-header__button md-icon" title="Switch to system preference" for="__palette_0" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
</label>
</form>
<script>
  // Re-apply the colour palette stored under the "__palette" key, if there is one.
  var palette = __md_get("__palette");
  if (palette && palette.color) {
    // A stored media value of "(prefers-color-scheme)" is resolved against the
    // current media query: copy the attributes of the matching light or dark
    // palette input over the stored colour settings.
    if ("(prefers-color-scheme)" === palette.color.media) {
      var media = matchMedia("(prefers-color-scheme: light)");
      var input;
      if (media.matches) {
        input = document.querySelector("[data-md-color-media='(prefers-color-scheme: light)']");
      } else {
        input = document.querySelector("[data-md-color-media='(prefers-color-scheme: dark)']");
      }
      palette.color.media = input.getAttribute("data-md-color-media");
      palette.color.scheme = input.getAttribute("data-md-color-scheme");
      palette.color.primary = input.getAttribute("data-md-color-primary");
      palette.color.accent = input.getAttribute("data-md-color-accent");
    }

    // Mirror every colour setting onto <body> as a data-md-color-* attribute.
    for (var [key, value] of Object.entries(palette.color)) {
      document.body.setAttribute("data-md-color-" + key, value);
    }
  }
</script>
<label class="md-header__button md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
</label>
<div class="md-search" data-md-component="search" role="dialog">
<label class="md-search__overlay" for="__search"></label>
<div class="md-search__inner" role="search">
<form class="md-search__form" name="search">
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
<label class="md-search__icon md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
</label>
<nav class="md-search__options" aria-label="Search">
<button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
</button>
</nav>
<div class="md-search__suggest" data-md-component="search-suggest"></div>
</form>
<div class="md-search__output">
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
<div class="md-search-result" data-md-component="search-result">
<div class="md-search-result__meta">
Initializing search
</div>
<ol class="md-search-result__list" role="presentation"></ol>
</div>
</div>
</div>
</div>
</div>
<div class="md-header__source">
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
</div>
<div class="md-source__repository">
GitHub
</div>
</a>
</div>
</nav>
</header>
<div class="md-container" data-md-component="container">
<main class="md-main" data-md-component="main">
<div class="md-main__inner md-grid">
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--primary md-nav--integrated" aria-label="Navigation" data-md-level="0">
<label class="md-nav__title" for="__drawer">
<a href=".." title="tinygrad docs" class="md-nav__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
<img src="../logo_tiny_dark.svg" alt="logo">
</a>
tinygrad docs
</label>
<div class="md-nav__source">
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
</div>
<div class="md-source__repository">
GitHub
</div>
</a>
</div>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1" checked>
<div class="md-nav__link md-nav__container">
<a href=".." class="md-nav__link ">
<span class="md-ellipsis">
Home
</span>
</a>
<label class="md-nav__link " for="__nav_1" id="__nav_1_label">
<span class="md-nav__icon md-icon"></span>
</label>
</div>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_1_label" aria-expanded="true">
<label class="md-nav__title" for="__nav_1">
<span class="md-nav__icon md-icon"></span>
Home
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--active">
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
<label class="md-nav__link md-nav__link--active" for="__toc">
<span class="md-ellipsis">
Quickstart
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<a href="./" class="md-nav__link md-nav__link--active">
<span class="md-ellipsis">
Quickstart
</span>
</a>
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
Table of contents
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#tensors" class="md-nav__link">
<span class="md-ellipsis">
Tensors
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#models" class="md-nav__link">
<span class="md-ellipsis">
Models
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#training" class="md-nav__link">
<span class="md-ellipsis">
Training
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#evaluation" class="md-nav__link">
<span class="md-ellipsis">
Evaluation
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#and-thats-it" class="md-nav__link">
<span class="md-ellipsis">
And that's it
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#extras" class="md-nav__link">
<span class="md-ellipsis">
Extras
</span>
</a>
<nav class="md-nav" aria-label="Extras">
<ul class="md-nav__list">
<li class="md-nav__item">
<a href="#jit" class="md-nav__link">
<span class="md-ellipsis">
JIT
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#saving-and-loading-models" class="md-nav__link">
<span class="md-ellipsis">
Saving and Loading Models
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#environment-variables" class="md-nav__link">
<span class="md-ellipsis">
Environment Variables
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#visualizing-the-computation-graph" class="md-nav__link">
<span class="md-ellipsis">
Visualizing the Computation Graph
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../showcase/" class="md-nav__link">
<span class="md-ellipsis">
Showcase
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../mnist/" class="md-nav__link">
<span class="md-ellipsis">
MNIST Tutorial
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5" >
<label class="md-nav__link" for="__nav_1_5" id="__nav_1_5_label" tabindex="0">
<span class="md-ellipsis">
API Reference
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_5_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1_5">
<span class="md-nav__icon md-icon"></span>
API Reference
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5_1" >
<div class="md-nav__link md-nav__container">
<a href="../tensor/" class="md-nav__link ">
<span class="md-ellipsis">
Tensor
</span>
</a>
<label class="md-nav__link " for="__nav_1_5_1" id="__nav_1_5_1_label" tabindex="0">
<span class="md-nav__icon md-icon"></span>
</label>
</div>
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_5_1_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1_5_1">
<span class="md-nav__icon md-icon"></span>
Tensor
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../tensor/properties/" class="md-nav__link">
<span class="md-ellipsis">
Properties
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/creation/" class="md-nav__link">
<span class="md-ellipsis">
Creation
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/movement/" class="md-nav__link">
<span class="md-ellipsis">
Movement
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/elementwise/" class="md-nav__link">
<span class="md-ellipsis">
Elementwise
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/ops/" class="md-nav__link">
<span class="md-ellipsis">
Complex Ops
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../dtypes/" class="md-nav__link">
<span class="md-ellipsis">
dtypes
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../nn/" class="md-nav__link">
<span class="md-ellipsis">
nn (Neural Networks)
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../env_vars/" class="md-nav__link">
<span class="md-ellipsis">
Environment Variables
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../runtime/" class="md-nav__link">
<span class="md-ellipsis">
Runtime
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6" >
<label class="md-nav__link" for="__nav_1_6" id="__nav_1_6_label" tabindex="0">
<span class="md-ellipsis">
Developer
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_6_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1_6">
<span class="md-nav__icon md-icon"></span>
Developer
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../developer/developer/" class="md-nav__link">
<span class="md-ellipsis">
Intro
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/layout/" class="md-nav__link">
<span class="md-ellipsis">
Layout
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/speed/" class="md-nav__link">
<span class="md-ellipsis">
Speed
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/uop/" class="md-nav__link">
<span class="md-ellipsis">
UOp
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6_5" >
<label class="md-nav__link" for="__nav_1_6_5" id="__nav_1_6_5_label" tabindex="0">
<span class="md-ellipsis">
Runtime
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_6_5_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1_6_5">
<span class="md-nav__icon md-icon"></span>
Runtime
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../developer/runtime/" class="md-nav__link">
<span class="md-ellipsis">
Runtime Overview
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/hcq/" class="md-nav__link">
<span class="md-ellipsis">
HCQ
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/am/" class="md-nav__link">
<span class="md-ellipsis">
AM Driver
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../tinybox/" class="md-nav__link">
<span class="md-ellipsis">
tinybox
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-content" data-md-component="content">
<article class="md-content__inner md-typeset">
<a href="https://github.com/tinygrad/tinygrad/edit/master/docs/quickstart.md" title="Edit this page" class="md-content__button md-icon" rel="edit">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M10 20H6V4h7v5h5v3.1l2-2V8l-6-6H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h4zm10.2-7c.1 0 .3.1.4.2l1.3 1.3c.2.2.2.6 0 .8l-1 1-2.1-2.1 1-1c.1-.1.2-.2.4-.2m0 3.9L14.1 23H12v-2.1l6.1-6.1z"/></svg>
</a>
<a href="https://github.com/tinygrad/tinygrad/raw/master/docs/quickstart.md" title="View source of this page" class="md-content__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M17 18c.56 0 1 .44 1 1s-.44 1-1 1-1-.44-1-1 .44-1 1-1m0-3c-2.73 0-5.06 1.66-6 4 .94 2.34 3.27 4 6 4s5.06-1.66 6-4c-.94-2.34-3.27-4-6-4m0 6.5a2.5 2.5 0 0 1-2.5-2.5 2.5 2.5 0 0 1 2.5-2.5 2.5 2.5 0 0 1 2.5 2.5 2.5 2.5 0 0 1-2.5 2.5M9.27 20H6V4h7v5h5v4.07c.7.08 1.36.25 2 .49V8l-6-6H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h4.5a8.2 8.2 0 0 1-1.23-2"/></svg>
</a>
<h1 id="quick-start-guide">Quick Start Guide<a class="headerlink" href="#quick-start-guide" title="Permanent link">¤</a></h1>
<p>This guide assumes no prior knowledge of PyTorch or any other deep learning framework, but does assume some basic knowledge of neural networks.
It is intended to be a very quick overview of the high-level API that tinygrad provides.</p>
<p>This guide is also structured as a tutorial: by the end of it, you will have a working model that can classify handwritten digits.</p>
<p>We need some imports to get started:</p>
<div class="language-python highlight"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.helpers</span><span class="w"> </span><span class="kn">import</span> <span class="n">Timing</span>
</code></pre></div>
<h2 id="tensors">Tensors<a class="headerlink" href="#tensors" title="Permanent link">¤</a></h2>
<p>Tensors are the base data structure in tinygrad. They can be thought of as a multidimensional array of a specific data type.
All high-level operations in tinygrad operate on these tensors.</p>
<p>The tensor class can be imported like so:</p>
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tensor</span>
</code></pre></div>
<p>Tensors can be created from an existing data structure like a Python list or NumPy ndarray:</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">t1</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span>
<span class="n">na</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span>
<span class="n">t2</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">na</span><span class="p">)</span>
</code></pre></div>
<p>Tensors can also be created using one of the many factory methods:</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">full</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">fill_value</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with 5</span>
<span class="n">zeros</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with 0</span>
<span class="n">ones</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with 1</span>
<span class="n">full_like</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">full_like</span><span class="p">(</span><span class="n">full</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="c1"># create a tensor of the same shape as `full` filled with 2</span>
<span class="n">zeros_like</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">full</span><span class="p">)</span> <span class="c1"># create a tensor of the same shape as `full` filled with 0</span>
<span class="n">ones_like</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">full</span><span class="p">)</span> <span class="c1"># create a tensor of the same shape as `full` filled with 1</span>
<span class="n">eye</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="c1"># create a 3x3 identity matrix</span>
<span class="n">arange</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># create a tensor of shape (10,) filled with values from 0 to 9</span>
<span class="n">rand</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with random values from a uniform distribution</span>
<span class="n">randn</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with random values from a standard normal distribution</span>
<span class="n">uniform</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with random values from a uniform distribution between 0 and 10</span>
</code></pre></div>
<p>There are even more of these factory methods; you can find them on the <a href="../tensor/creation/">Tensor Creation</a> page.</p>
<p>All the tensor creation methods can take a <code class="language-python highlight"><span class="n">dtype</span></code> argument to specify the data type of the tensor; the supported <code class="language-python highlight"><span class="n">dtype</span></code> values are listed in <a href="../dtypes/">dtypes</a>.</p>
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">dtypes</span>
<span class="n">t3</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
</code></pre></div>
<p>Tensors allow you to perform operations on them like so:</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">t4</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span>
<span class="n">t5</span> <span class="o">=</span> <span class="p">(</span><span class="n">t4</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span>
<span class="n">t6</span> <span class="o">=</span> <span class="p">(</span><span class="n">t5</span> <span class="o">*</span> <span class="n">t4</span><span class="p">)</span><span class="o">.</span><span class="n">relu</span><span class="p">()</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">()</span>
</code></pre></div>
<p>All of these operations are lazy and are only executed when you realize the tensor using <code class="language-python highlight"><span class="o">.</span><span class="n">realize</span><span class="p">()</span></code> or <code class="language-python highlight"><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></code>.</p>
<div class="language-python highlight"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">t6</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
<span class="c1"># [-56. -48. -36. -20. 0.]</span>
</code></pre></div>
<p>There are a lot more operations that can be performed on tensors; you can find them on the <a href="../tensor/ops/">Tensor Ops</a> page.
Additionally, reading through <a href="https://github.com/tinygrad/tinygrad/blob/master/docs/abstractions2.py">abstractions2.py</a> will help you understand how operations on these tensors make their way down to your hardware.</p>
<h2 id="models">Models<a class="headerlink" href="#models" title="Permanent link">¤</a></h2>
<p>Neural networks in tinygrad are really just represented by the operations performed on tensors.
These operations are commonly grouped into the <code class="language-python highlight"><span class="fm">__call__</span></code> method of a class which allows modularization and reuse of these groups of operations.
These classes do not need to inherit from any base class; in fact, if they don't need any trainable parameters, they don't even need to be a class!</p>
<p>An example of this would be the <code class="language-python highlight"><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span></code> class which represents a linear layer in a neural network.</p>
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Linear</span><span class="p">:</span>
  <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">initialization</span><span class="p">:</span> <span class="nb">str</span><span class="o">=</span><span class="s1">&#39;kaiming_uniform&#39;</span><span class="p">):</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">initialization</span><span class="p">)(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">out_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
  <span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">transpose</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
</code></pre></div>
<p>There are more neural network modules already implemented in <a href="../nn/">nn</a>, and you can also implement your own.</p>
<p>We will be implementing a simple neural network that can classify handwritten digits from the MNIST dataset.
Our classifier will be a simple 2 layer neural network with a Leaky ReLU activation function.
It will use a hidden layer size of 128 and an output layer size of 10 (one for each digit) with no bias on either Linear layer.</p>
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">TinyNet</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">TinyNet</span><span class="p">()</span>
</code></pre></div>
<p>We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor <code class="language-python highlight"><span class="n">x</span></code>.
We can also see that functional operations like <code class="language-python highlight"><span class="n">leaky_relu</span></code> are not defined as classes and are instead just methods we can call.
Finally, we just initialize an instance of our neural network, and we are ready to start training it.</p>
<h2 id="training">Training<a class="headerlink" href="#training" title="Permanent link">¤</a></h2>
<p>Now that we have our neural network defined we can start training it.
Training neural networks in tinygrad is super simple.
All we need to do is define our neural network, define our loss function, and then call <code class="language-python highlight"><span class="o">.</span><span class="n">backward</span><span class="p">()</span></code> on the resulting loss to compute the gradients.
They can then be used to update the parameters of our neural network using one of the many <a href="../nn/#optimizers">Optimizers</a>.</p>
<p>For our loss function we will be using sparse categorical cross entropy loss. The implementation below is taken from <a href="https://github.com/tinygrad/tinygrad/blob/master/tinygrad/tensor.py">tensor.py</a>; it is reproduced here to highlight an important detail of tinygrad.</p>
<div class="language-python highlight"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">sparse_categorical_crossentropy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">loss_mask</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">!=</span> <span class="n">ignore_index</span>
<span class="n">y_counter</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">Y</span><span class="o">.</span><span class="n">numel</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">y</span> <span class="o">=</span> <span class="p">((</span><span class="n">y_counter</span> <span class="o">==</span> <span class="n">Y</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">loss_mask</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">Y</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">()</span><span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="n">loss_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
</code></pre></div>
<p>As we can see in this implementation of cross entropy loss, there are certain operations that tinygrad does not support natively.
Load/store ops are not supported in tinygrad natively because they add complexity when trying to port to different backends, 90% of the models out there don't use/need them, and they can be implemented like it's done above with an <code class="language-python highlight"><span class="n">arange</span></code> mask.</p>
<p>For our optimizer we will be using the traditional stochastic gradient descent optimizer with a learning rate of 3e-4.</p>
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.nn.optim</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span>
<span class="n">opt</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">([</span><span class="n">net</span><span class="o">.</span><span class="n">l1</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">net</span><span class="o">.</span><span class="n">l2</span><span class="o">.</span><span class="n">weight</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">3e-4</span><span class="p">)</span>
</code></pre></div>
<p>We can see that we are passing in the parameters of our neural network to the optimizer.
This is due to the fact that the optimizer needs to know which parameters to update.
There is a simpler way to do this just by using <code class="language-python highlight"><span class="n">get_parameters</span><span class="p">(</span><span class="n">net</span><span class="p">)</span></code> from <code class="language-python highlight"><span class="n">tinygrad</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span></code> which will return a list of all the parameters in the neural network.
The parameters are just listed out explicitly here for clarity.</p>
<p>Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
There are a couple of dataset loaders in tinygrad located in <a href="https://github.com/tinygrad/tinygrad/blob/master/extra/datasets">/extra/datasets</a>.
We will be using the MNIST dataset loader.</p>
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">extra.datasets</span><span class="w"> </span><span class="kn">import</span> <span class="n">fetch_mnist</span>
</code></pre></div>
<p>Now we have everything we need to start training our neural network.
We will be training for 1000 steps with a batch size of 64.</p>
<p>We use <code class="language-python highlight"><span class="k">with</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></code> to set the internal flag <code class="language-python highlight"><span class="n">Tensor</span><span class="o">.</span><span class="n">training</span></code> to <code class="language-python highlight"><span class="kc">True</span></code> during training.
Upon exit, the flag is restored to its previous value by the context manager.</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">X_train</span><span class="p">,</span> <span class="n">Y_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">Y_test</span> <span class="o">=</span> <span class="n">fetch_mnist</span><span class="p">()</span>
<span class="k">with</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">train</span><span class="p">():</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
<span class="c1"># random sample a batch</span>
<span class="n">samp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">))</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">X_train</span><span class="p">[</span><span class="n">samp</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># get the corresponding labels</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">Y_train</span><span class="p">[</span><span class="n">samp</span><span class="p">])</span>
<span class="c1"># forward pass</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="c1"># compute loss</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">sparse_categorical_crossentropy</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="c1"># zero gradients</span>
<span class="n">opt</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="c1"># backward pass</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="c1"># update parameters</span>
<span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="c1"># calculate accuracy</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">acc</span> <span class="o">=</span> <span class="p">(</span><span class="n">pred</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">100</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Step </span><span class="si">{</span><span class="n">step</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2"> | Loss: </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s2"> | Accuracy: </span><span class="si">{</span><span class="n">acc</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</code></pre></div>
<h2 id="evaluation">Evaluation<a class="headerlink" href="#evaluation" title="Permanent link">¤</a></h2>
<p>Now that we have trained our neural network we can evaluate it on the test set.
We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.</p>
<div class="language-python highlight"><pre><span></span><code><span class="k">with</span> <span class="n">Timing</span><span class="p">(</span><span class="s2">&quot;Time: &quot;</span><span class="p">):</span>
<span class="n">avg_acc</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
<span class="c1"># random sample a batch</span>
<span class="n">samp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X_test</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">))</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">X_test</span><span class="p">[</span><span class="n">samp</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># get the corresponding labels</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">Y_test</span><span class="p">[</span><span class="n">samp</span><span class="p">]</span>
<span class="c1"># forward pass</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="c1"># calculate accuracy</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">avg_acc</span> <span class="o">+=</span> <span class="p">(</span><span class="n">pred</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Test Accuracy: </span><span class="si">{</span><span class="n">avg_acc</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1000</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</code></pre></div>
<h2 id="and-thats-it">And that's it<a class="headerlink" href="#and-thats-it" title="Permanent link">¤</a></h2>
<p>We highly recommend that you check out the <a href="https://github.com/tinygrad/tinygrad/blob/master/examples">examples/</a> folder for more examples of using tinygrad.
Reading the source code of tinygrad is also a great way to learn how it works.
Specifically, the tests in <a href="https://github.com/tinygrad/tinygrad/blob/master/test">test/</a> are a great place to see how to use the different operations and what their semantics are.
There are also a bunch of models implemented in <a href="https://github.com/tinygrad/tinygrad/blob/master/extra/models">models/</a> that you can use as a reference.</p>
<p>Additionally, feel free to ask questions in the <code class="language-python highlight"><span class="c1">#learn-tinygrad</span></code> channel on the <a href="https://discord.gg/beYbxwxVdx">discord</a>. Don't ask to ask, just ask!</p>
<h2 id="extras">Extras<a class="headerlink" href="#extras" title="Permanent link">¤</a></h2>
<h3 id="jit">JIT<a class="headerlink" href="#jit" title="Permanent link">¤</a></h3>
<p>Additionally, it is possible to speed up the computation of certain neural networks by using the JIT.
Currently, this does not support models with varying input sizes or non-tinygrad operations.</p>
<p>To use the JIT we just need to add a function decorator to the forward pass of our neural network and ensure that the input and output are realized tensors.
Or in this case we will create a wrapper function and decorate the wrapper function to speed up the evaluation of our neural network.</p>
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">TinyJit</span>
<span class="nd">@TinyJit</span>
<span class="k">def</span><span class="w"> </span><span class="nf">jit</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">realize</span><span class="p">()</span>
<span class="k">with</span> <span class="n">Timing</span><span class="p">(</span><span class="s2">&quot;Time: &quot;</span><span class="p">):</span>
<span class="n">avg_acc</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
<span class="c1"># random sample a batch</span>
<span class="n">samp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X_test</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">))</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">X_test</span><span class="p">[</span><span class="n">samp</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># get the corresponding labels</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">Y_test</span><span class="p">[</span><span class="n">samp</span><span class="p">]</span>
<span class="c1"># forward pass with jit</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">jit</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
<span class="c1"># calculate accuracy</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="n">avg_acc</span> <span class="o">+=</span> <span class="p">(</span><span class="n">pred</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Test Accuracy: </span><span class="si">{</span><span class="n">avg_acc</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1000</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</code></pre></div>
<p>You will find that the evaluation time is much faster than before and that your accelerator utilization is much higher.</p>
<h3 id="saving-and-loading-models">Saving and Loading Models<a class="headerlink" href="#saving-and-loading-models" title="Permanent link">¤</a></h3>
<p>The standard weight format for tinygrad is <a href="https://github.com/huggingface/safetensors">safetensors</a>. This means that you can load the weights of any model that also uses safetensors into tinygrad.
There are functions in <a href="https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/state.py">state.py</a> to save and load models to and from this format.</p>
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.nn.state</span><span class="w"> </span><span class="kn">import</span> <span class="n">safe_save</span><span class="p">,</span> <span class="n">safe_load</span><span class="p">,</span> <span class="n">get_state_dict</span><span class="p">,</span> <span class="n">load_state_dict</span>
<span class="c1"># first we need the state dict of our model</span>
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
<span class="c1"># then we can just save it to a file</span>
<span class="n">safe_save</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="s2">&quot;model.safetensors&quot;</span><span class="p">)</span>
<span class="c1"># and load it back in</span>
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">safe_load</span><span class="p">(</span><span class="s2">&quot;model.safetensors&quot;</span><span class="p">)</span>
<span class="n">load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">)</span>
</code></pre></div>
<p>Many of the models in the <a href="https://github.com/tinygrad/tinygrad/tree/master/extra/models">models/</a> folder have a <code class="language-python highlight"><span class="n">load_from_pretrained</span></code> method that will download and load the weights for you. These are usually PyTorch weights, meaning that you would need PyTorch installed to load them.</p>
<h3 id="environment-variables">Environment Variables<a class="headerlink" href="#environment-variables" title="Permanent link">¤</a></h3>
<p>There exist a bunch of environment variables that control the runtime behavior of tinygrad.
Some of the common ones are <code class="language-python highlight"><span class="n">DEBUG</span></code> and the different backend enablement variables.</p>
<p>You can find a full list and their descriptions in <a href="../env_vars/">env_vars.md</a>.</p>
<h3 id="visualizing-the-computation-graph">Visualizing the Computation Graph<a class="headerlink" href="#visualizing-the-computation-graph" title="Permanent link">¤</a></h3>
<p>It is possible to visualize the computation graph of a neural network using VIZ=1.</p>
</article>
</div>
<script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
</div>
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
Back to top
</button>
</main>
<footer class="md-footer">
<nav class="md-footer__inner md-grid" aria-label="Footer" >
<a href=".." class="md-footer__link md-footer__link--prev" aria-label="Previous: tinygrad documentation">
<div class="md-footer__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
</div>
<div class="md-footer__title">
<span class="md-footer__direction">
Previous
</span>
<div class="md-ellipsis">
tinygrad documentation
</div>
</div>
</a>
<a href="../showcase/" class="md-footer__link md-footer__link--next" aria-label="Next: Showcase">
<div class="md-footer__title">
<span class="md-footer__direction">
Next
</span>
<div class="md-ellipsis">
Showcase
</div>
</div>
<div class="md-footer__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
</div>
</a>
</nav>
<div class="md-footer-meta md-typeset">
<div class="md-footer-meta__inner md-grid">
<div class="md-copyright">
Made with
<a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
Material for MkDocs
</a>
</div>
</div>
</div>
</footer>
</div>
<div class="md-dialog" data-md-component="dialog">
<div class="md-dialog__inner md-typeset"></div>
</div>
<script id="__config" type="application/json">{"annotate": null, "base": "..", "features": ["announce.dismiss", "content.action.edit", "content.action.view", "content.code.annotate", "content.code.copy", "content.tooltips", "navigation.footer", "navigation.indexes", "navigation.sections", "navigation.expand", "navigation.top", "navigation.path", "search.highlight", "search.suggest", "toc.follow", "toc.integrate"], "search": "../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}, "version": null}</script>
<script src="../assets/javascripts/bundle.79ae519e.min.js"></script>
<script src="../assets/_markdown_exec_pyodide.js"></script>
</body>
</html>