mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 17:38:06 -05:00
1746 lines
69 KiB
HTML
1746 lines
69 KiB
HTML
|
|
<!doctype html>
|
|
<html lang="en" class="no-js">
|
|
<head>
|
|
|
|
<meta charset="utf-8">
|
|
<meta name="viewport" content="width=device-width,initial-scale=1">
|
|
|
|
|
|
|
|
<link rel="canonical" href="https://docs.tinygrad.org/quickstart/">
|
|
|
|
|
|
<link rel="prev" href="..">
|
|
|
|
|
|
<link rel="next" href="../showcase/">
|
|
|
|
|
|
|
|
|
|
|
|
<link rel="icon" href="../favicon.svg">
|
|
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.1">
|
|
|
|
|
|
|
|
<title>Quickstart - tinygrad docs</title>
|
|
|
|
|
|
|
|
<link rel="stylesheet" href="../assets/stylesheets/main.484c7ddc.min.css">
|
|
|
|
|
|
<link rel="stylesheet" href="../assets/stylesheets/palette.ab4e12ef.min.css">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
|
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
|
|
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
|
|
|
|
|
|
|
|
<link rel="stylesheet" href="../assets/_markdown_exec_pyodide.css">
|
|
|
|
<link rel="stylesheet" href="../assets/_markdown_exec_ansi.css">
|
|
|
|
<link rel="stylesheet" href="../assets/_mkdocstrings.css">
|
|
|
|
<script>
  // Base URL one level above this page; its pathname is used to namespace storage keys.
  __md_scope = new URL("..", location);

  // Fold a string into an integer: for each character, acc = (acc << 5) - acc + charCode.
  __md_hash = text => {
    let acc = 0;
    for (const ch of text) {
      acc = (acc << 5) - acc + ch.charCodeAt(0);
    }
    return acc;
  };

  // Read and JSON-decode the value stored under "<scope pathname>.<key>".
  // Defaults: localStorage and the site scope defined above.
  __md_get = (key, storage = localStorage, scope = __md_scope) => {
    const raw = storage.getItem(scope.pathname + "." + key);
    return JSON.parse(raw);
  };

  // JSON-encode and store a value under "<scope pathname>.<key>".
  // Any error raised while encoding or writing is deliberately swallowed.
  __md_set = (key, value, storage = localStorage, scope = __md_scope) => {
    try {
      storage.setItem(scope.pathname + "." + key, JSON.stringify(value));
    } catch (err) {}
  };
</script>
|
|
|
|
|
|
|
|
|
|
|
|
</head>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime">
|
|
|
|
|
|
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
|
|
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
|
|
<label class="md-overlay" for="__drawer"></label>
|
|
<div data-md-component="skip">
|
|
|
|
|
|
<a href="#quick-start-guide" class="md-skip">
|
|
Skip to content
|
|
</a>
|
|
|
|
</div>
|
|
<div data-md-component="announce">
|
|
|
|
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<header class="md-header md-header--shadow" data-md-component="header">
|
|
<nav class="md-header__inner md-grid" aria-label="Header">
|
|
<a href=".." title="tinygrad docs" class="md-header__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
|
|
|
<img src="../logo_tiny_dark.svg" alt="logo">
|
|
|
|
</a>
|
|
<label class="md-header__button md-icon" for="__drawer">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
|
|
</label>
|
|
<div class="md-header__title" data-md-component="header-title">
|
|
<div class="md-header__ellipsis">
|
|
<div class="md-header__topic">
|
|
<span class="md-ellipsis">
|
|
tinygrad docs
|
|
</span>
|
|
</div>
|
|
<div class="md-header__topic" data-md-component="header-topic">
|
|
<span class="md-ellipsis">
|
|
|
|
Quickstart
|
|
|
|
</span>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
|
|
<form class="md-header__option" data-md-component="palette">
|
|
|
|
|
|
|
|
|
|
<input class="md-option" data-md-color-media="(prefers-color-scheme)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to light mode" type="radio" name="__palette" id="__palette_0">
|
|
|
|
<label class="md-header__button md-icon" title="Switch to light mode" for="__palette_1" hidden>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m14.3 16-.7-2h-3.2l-.7 2H7.8L11 7h2l3.2 9zM20 8.69V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12zm-9.15 3.96h2.3L12 9z"/></svg>
|
|
</label>
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-option" data-md-color-media="(prefers-color-scheme: light)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to dark mode" type="radio" name="__palette" id="__palette_1">
|
|
|
|
<label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_2" hidden>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
|
</label>
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-option" data-md-color-media="(prefers-color-scheme: dark)" data-md-color-scheme="slate" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to system preference" type="radio" name="__palette" id="__palette_2">
|
|
|
|
<label class="md-header__button md-icon" title="Switch to system preference" for="__palette_0" hidden>
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
|
</label>
|
|
|
|
|
|
</form>
|
|
|
|
|
|
|
|
<script>var palette=__md_get("__palette");if(palette&&palette.color){if("(prefers-color-scheme)"===palette.color.media){var media=matchMedia("(prefers-color-scheme: light)"),input=document.querySelector(media.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");palette.color.media=input.getAttribute("data-md-color-media"),palette.color.scheme=input.getAttribute("data-md-color-scheme"),palette.color.primary=input.getAttribute("data-md-color-primary"),palette.color.accent=input.getAttribute("data-md-color-accent")}for(var[key,value]of Object.entries(palette.color))document.body.setAttribute("data-md-color-"+key,value)}</script>
|
|
|
|
|
|
|
|
|
|
|
|
<label class="md-header__button md-icon" for="__search">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
|
</label>
|
|
<div class="md-search" data-md-component="search" role="dialog">
|
|
<label class="md-search__overlay" for="__search"></label>
|
|
<div class="md-search__inner" role="search">
|
|
<form class="md-search__form" name="search">
|
|
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
|
|
<label class="md-search__icon md-icon" for="__search">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
|
|
</label>
|
|
<nav class="md-search__options" aria-label="Search">
|
|
|
|
<button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
|
|
</button>
|
|
</nav>
|
|
|
|
<div class="md-search__suggest" data-md-component="search-suggest"></div>
|
|
|
|
</form>
|
|
<div class="md-search__output">
|
|
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
|
|
<div class="md-search-result" data-md-component="search-result">
|
|
<div class="md-search-result__meta">
|
|
Initializing search
|
|
</div>
|
|
<ol class="md-search-result__list" role="presentation"></ol>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
|
|
|
|
<div class="md-header__source">
|
|
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
|
<div class="md-source__icon md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
|
</div>
|
|
<div class="md-source__repository">
|
|
GitHub
|
|
</div>
|
|
</a>
|
|
</div>
|
|
|
|
</nav>
|
|
|
|
</header>
|
|
|
|
<div class="md-container" data-md-component="container">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<main class="md-main" data-md-component="main">
|
|
<div class="md-main__inner md-grid">
|
|
|
|
|
|
|
|
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
|
|
<div class="md-sidebar__scrollwrap">
|
|
<div class="md-sidebar__inner">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<nav class="md-nav md-nav--primary md-nav--integrated" aria-label="Navigation" data-md-level="0">
|
|
<label class="md-nav__title" for="__drawer">
|
|
<a href=".." title="tinygrad docs" class="md-nav__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
|
|
|
<img src="../logo_tiny_dark.svg" alt="logo">
|
|
|
|
</a>
|
|
tinygrad docs
|
|
</label>
|
|
|
|
<div class="md-nav__source">
|
|
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
|
<div class="md-source__icon md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
|
</div>
|
|
<div class="md-source__repository">
|
|
GitHub
|
|
</div>
|
|
</a>
|
|
</div>
|
|
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1" checked>
|
|
|
|
|
|
<div class="md-nav__link md-nav__container">
|
|
<a href=".." class="md-nav__link ">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Home
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
|
|
|
|
<label class="md-nav__link " for="__nav_1" id="__nav_1_label">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
</div>
|
|
|
|
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_1_label" aria-expanded="true">
|
|
<label class="md-nav__title" for="__nav_1">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Home
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--active">
|
|
|
|
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
|
|
|
|
|
|
|
|
|
|
|
|
<label class="md-nav__link md-nav__link--active" for="__toc">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Quickstart
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<a href="./" class="md-nav__link md-nav__link--active">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Quickstart
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
|
|
|
|
|
|
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<label class="md-nav__title" for="__toc">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
Table of contents
|
|
</label>
|
|
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#tensors" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Tensors
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#models" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Models
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#training" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Training
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#evaluation" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Evaluation
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#and-thats-it" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
And that's it
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#extras" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Extras
|
|
|
|
</span>
|
|
</a>
|
|
|
|
<nav class="md-nav" aria-label="Extras">
|
|
<ul class="md-nav__list">
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#jit" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
JIT
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#saving-and-loading-models" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Saving and Loading Models
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#environment-variables" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Environment Variables
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
<li class="md-nav__item">
|
|
<a href="#visualizing-the-computation-graph" class="md-nav__link">
|
|
<span class="md-ellipsis">
|
|
|
|
Visualizing the Computation Graph
|
|
|
|
</span>
|
|
</a>
|
|
|
|
</li>
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
</ul>
|
|
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../showcase/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Showcase
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../mnist/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
MNIST Tutorial
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5" >
|
|
|
|
|
|
<label class="md-nav__link" for="__nav_1_5" id="__nav_1_5_label" tabindex="0">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
API Reference
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_5_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_5">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
API Reference
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5_1" >
|
|
|
|
|
|
<div class="md-nav__link md-nav__container">
|
|
<a href="../tensor/" class="md-nav__link ">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Tensor
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
|
|
|
|
<label class="md-nav__link " for="__nav_1_5_1" id="__nav_1_5_1_label" tabindex="0">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
</div>
|
|
|
|
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_5_1_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_5_1">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Tensor
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/properties/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Properties
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/creation/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Creation
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/movement/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Movement
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/elementwise/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Elementwise
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tensor/ops/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Complex Ops
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../dtypes/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
dtypes
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../nn/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
nn (Neural Networks)
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../env_vars/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Environment Variables
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../runtime/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Runtime
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6" >
|
|
|
|
|
|
<label class="md-nav__link" for="__nav_1_6" id="__nav_1_6_label" tabindex="0">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Developer
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_6_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_6">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Developer
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/developer/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Intro
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/layout/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Layout
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/speed/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Speed
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/uop/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
UOp
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item md-nav__item--nested">
|
|
|
|
|
|
|
|
|
|
|
|
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6_5" >
|
|
|
|
|
|
<label class="md-nav__link" for="__nav_1_6_5" id="__nav_1_6_5_label" tabindex="0">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Runtime
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
<span class="md-nav__icon md-icon"></span>
|
|
</label>
|
|
|
|
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_6_5_label" aria-expanded="false">
|
|
<label class="md-nav__title" for="__nav_1_6_5">
|
|
<span class="md-nav__icon md-icon"></span>
|
|
|
|
|
|
Runtime
|
|
|
|
|
|
</label>
|
|
<ul class="md-nav__list" data-md-scrollfix>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/runtime/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
Runtime Overview
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/hcq/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
HCQ
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../developer/am/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
AM Driver
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<li class="md-nav__item">
|
|
<a href="../tinybox/" class="md-nav__link">
|
|
|
|
|
|
|
|
<span class="md-ellipsis">
|
|
|
|
|
|
tinybox
|
|
|
|
|
|
|
|
</span>
|
|
|
|
|
|
|
|
</a>
|
|
</li>
|
|
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
|
|
</li>
|
|
|
|
|
|
|
|
</ul>
|
|
</nav>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
|
|
|
|
|
|
<div class="md-content" data-md-component="content">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<article class="md-content__inner md-typeset">
|
|
|
|
|
|
|
|
|
|
|
|
<a href="https://github.com/tinygrad/tinygrad/edit/master/docs/quickstart.md" title="Edit this page" class="md-content__button md-icon" rel="edit">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M10 20H6V4h7v5h5v3.1l2-2V8l-6-6H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h4zm10.2-7c.1 0 .3.1.4.2l1.3 1.3c.2.2.2.6 0 .8l-1 1-2.1-2.1 1-1c.1-.1.2-.2.4-.2m0 3.9L14.1 23H12v-2.1l6.1-6.1z"/></svg>
|
|
</a>
|
|
|
|
|
|
|
|
|
|
|
|
<a href="https://github.com/tinygrad/tinygrad/raw/master/docs/quickstart.md" title="View source of this page" class="md-content__button md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M17 18c.56 0 1 .44 1 1s-.44 1-1 1-1-.44-1-1 .44-1 1-1m0-3c-2.73 0-5.06 1.66-6 4 .94 2.34 3.27 4 6 4s5.06-1.66 6-4c-.94-2.34-3.27-4-6-4m0 6.5a2.5 2.5 0 0 1-2.5-2.5 2.5 2.5 0 0 1 2.5-2.5 2.5 2.5 0 0 1 2.5 2.5 2.5 2.5 0 0 1-2.5 2.5M9.27 20H6V4h7v5h5v4.07c.7.08 1.36.25 2 .49V8l-6-6H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h4.5a8.2 8.2 0 0 1-1.23-2"/></svg>
|
|
</a>
|
|
|
|
|
|
|
|
<h1 id="quick-start-guide">Quick Start Guide<a class="headerlink" href="#quick-start-guide" title="Permanent link">¤</a></h1>
|
|
<p>This guide assumes no prior knowledge of PyTorch or any other deep learning framework, but does assume some basic knowledge of neural networks.
|
|
It is intended to be a very quick overview of the high-level API that tinygrad provides.</p>
|
|
<p>This guide is also structured as a tutorial: by the end of it, you will have a working model that can classify handwritten digits.</p>
|
|
<p>We need some imports to get started:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.helpers</span><span class="w"> </span><span class="kn">import</span> <span class="n">Timing</span>
|
|
</code></pre></div>
|
|
<h2 id="tensors">Tensors<a class="headerlink" href="#tensors" title="Permanent link">¤</a></h2>
|
|
<p>Tensors are the base data structure in tinygrad. They can be thought of as a multidimensional array of a specific data type.
|
|
All high-level operations in tinygrad operate on these tensors.</p>
|
|
<p>The tensor class can be imported like so:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tensor</span>
|
|
</code></pre></div>
|
|
<p>Tensors can be created from an existing data structure like a Python list or NumPy ndarray:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="n">t1</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span>
|
|
<span class="n">na</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span>
|
|
<span class="n">t2</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">na</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>Tensors can also be created using one of the many factory methods:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="n">full</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">fill_value</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with 5</span>
|
|
<span class="n">zeros</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with 0</span>
|
|
<span class="n">ones</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with 1</span>
|
|
|
|
<span class="n">full_like</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">full_like</span><span class="p">(</span><span class="n">full</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="c1"># create a tensor of the same shape as `full` filled with 2</span>
|
|
<span class="n">zeros_like</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">full</span><span class="p">)</span> <span class="c1"># create a tensor of the same shape as `full` filled with 0</span>
|
|
<span class="n">ones_like</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">full</span><span class="p">)</span> <span class="c1"># create a tensor of the same shape as `full` filled with 1</span>
|
|
|
|
<span class="n">eye</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="c1"># create a 3x3 identity matrix</span>
|
|
<span class="n">arange</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># create a tensor of shape (10,) filled with values from 0 to 9</span>
|
|
|
|
<span class="n">rand</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with random values from a uniform distribution</span>
|
|
<span class="n">randn</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with random values from a standard normal distribution</span>
|
|
<span class="n">uniform</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="c1"># create a tensor of shape (2, 3) filled with random values from a uniform distribution between 0 and 10</span>
|
|
</code></pre></div>
|
|
<p>There are even more of these factory methods, you can find them in the <a href="../tensor/creation/">Tensor Creation</a> file.</p>
|
|
<p>All the tensor creation methods can take a <code class="language-python highlight"><span class="n">dtype</span></code> argument to specify the data type of the tensor; find the supported <code class="language-python highlight"><span class="n">dtype</span></code> in <a href="../dtypes/">dtypes</a>.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">dtypes</span>
|
|
|
|
<span class="n">t3</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>Tensors allow you to perform operations on them like so:</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="n">t4</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span>
|
|
<span class="n">t5</span> <span class="o">=</span> <span class="p">(</span><span class="n">t4</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span>
|
|
<span class="n">t6</span> <span class="o">=</span> <span class="p">(</span><span class="n">t5</span> <span class="o">*</span> <span class="n">t4</span><span class="p">)</span><span class="o">.</span><span class="n">relu</span><span class="p">()</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">()</span>
|
|
</code></pre></div>
|
|
<p>All of these operations are lazy and are only executed when you realize the tensor using <code class="language-python highlight"><span class="o">.</span><span class="n">realize</span><span class="p">()</span></code> or <code class="language-python highlight"><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></code>.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">t6</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
|
<span class="c1"># [-56. -48. -36. -20. 0.]</span>
|
|
</code></pre></div>
|
|
<p>There are a lot more operations that can be performed on tensors, you can find them in the <a href="../tensor/ops/">Tensor Ops</a> file.
|
|
Additionally reading through <a href="https://github.com/tinygrad/tinygrad/blob/master/docs/abstractions2.py">abstractions2.py</a> will help you understand how operations on these tensors make their way down to your hardware.</p>
|
|
<h2 id="models">Models<a class="headerlink" href="#models" title="Permanent link">¤</a></h2>
|
|
<p>Neural networks in tinygrad are really just represented by the operations performed on tensors.
|
|
These operations are commonly grouped into the <code class="language-python highlight"><span class="fm">__call__</span></code> method of a class which allows modularization and reuse of these groups of operations.
|
|
These classes do not need to inherit from any base class, in fact if they don't need any trainable parameters they don't even need to be a class!</p>
|
|
<p>An example of this would be the <code class="language-python highlight"><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span></code> class which represents a linear layer in a neural network.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Linear</span><span class="p">:</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">initialization</span><span class="p">:</span> <span class="nb">str</span><span class="o">=</span><span class="s1">'kaiming_uniform'</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">initialization</span><span class="p">)(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">out_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">transpose</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>There are more neural network modules already implemented in <a href="../nn/">nn</a>, and you can also implement your own.</p>
|
|
<p>We will be implementing a simple neural network that can classify handwritten digits from the MNIST dataset.
|
|
Our classifier will be a simple 2-layer neural network with a Leaky ReLU activation function.
|
|
It will use a hidden layer size of 128 and an output layer size of 10 (one for each digit) with no bias on either Linear layer.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">TinyNet</span><span class="p">:</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
|
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
|
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">()</span>
|
|
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">x</span>
|
|
|
|
<span class="n">net</span> <span class="o">=</span> <span class="n">TinyNet</span><span class="p">()</span>
|
|
</code></pre></div>
|
|
<p>We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor <code class="language-python highlight"><span class="n">x</span></code>.
|
|
We can also see that functional operations like <code class="language-python highlight"><span class="n">leaky_relu</span></code> are not defined as classes and are instead just methods we can call.
|
|
Finally, we just initialize an instance of our neural network, and we are ready to start training it.</p>
|
|
<h2 id="training">Training<a class="headerlink" href="#training" title="Permanent link">¤</a></h2>
|
|
<p>Now that we have our neural network defined we can start training it.
|
|
Training neural networks in tinygrad is super simple.
|
|
All we need to do is define our neural network, define our loss function, and then call <code class="language-python highlight"><span class="o">.</span><span class="n">backward</span><span class="p">()</span></code> on the loss it produces to compute the gradients.
|
|
They can then be used to update the parameters of our neural network using one of the many <a href="../nn/#optimizers">Optimizers</a>.</p>
|
|
<p>For our loss function we will be using sparse categorical cross entropy loss. The implementation below is taken from <a href="https://github.com/tinygrad/tinygrad/blob/master/tinygrad/tensor.py">tensor.py</a>; it's copied here to highlight an important detail of tinygrad.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">sparse_categorical_crossentropy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
|
<span class="n">loss_mask</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">!=</span> <span class="n">ignore_index</span>
|
|
<span class="n">y_counter</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">Y</span><span class="o">.</span><span class="n">numel</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="n">y</span> <span class="o">=</span> <span class="p">((</span><span class="n">y_counter</span> <span class="o">==</span> <span class="n">Y</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">loss_mask</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="n">Y</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">()</span><span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="n">loss_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
|
|
</code></pre></div>
|
|
<p>As we can see in this implementation of cross entropy loss, there are certain operations that tinygrad does not support natively.
|
|
Load/store ops are not supported in tinygrad natively because they add complexity when trying to port to different backends, 90% of the models out there don't use/need them, and they can be implemented like it's done above with an <code class="language-python highlight"><span class="n">arange</span></code> mask.</p>
|
|
<p>For our optimizer we will be using the traditional stochastic gradient descent optimizer with a learning rate of 3e-4.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.nn.optim</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span>
|
|
|
|
<span class="n">opt</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">([</span><span class="n">net</span><span class="o">.</span><span class="n">l1</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">net</span><span class="o">.</span><span class="n">l2</span><span class="o">.</span><span class="n">weight</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">3e-4</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>We can see that we are passing in the parameters of our neural network to the optimizer.
|
|
This is due to the fact that the optimizer needs to know which parameters to update.
|
|
There is a simpler way to do this just by using <code class="language-python highlight"><span class="n">get_parameters</span><span class="p">(</span><span class="n">net</span><span class="p">)</span></code> from <code class="language-python highlight"><span class="n">tinygrad</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span></code> which will return a list of all the parameters in the neural network.
|
|
The parameters are just listed out explicitly here for clarity.</p>
|
|
<p>Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
|
|
There are a couple of dataset loaders in tinygrad located in <a href="https://github.com/tinygrad/tinygrad/blob/master/extra/datasets">/extra/datasets</a>.
|
|
We will be using the MNIST dataset loader.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">extra.datasets</span><span class="w"> </span><span class="kn">import</span> <span class="n">fetch_mnist</span>
|
|
</code></pre></div>
|
|
<p>Now we have everything we need to start training our neural network.
|
|
We will be training for 1000 steps with a batch size of 64.</p>
|
|
<p>We use <code class="language-python highlight"><span class="k">with</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></code> to set the internal flag <code class="language-python highlight"><span class="n">Tensor</span><span class="o">.</span><span class="n">training</span></code> to <code class="language-python highlight"><span class="kc">True</span></code> during training.
|
|
Upon exit, the flag is restored to its previous value by the context manager.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="n">X_train</span><span class="p">,</span> <span class="n">Y_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">Y_test</span> <span class="o">=</span> <span class="n">fetch_mnist</span><span class="p">()</span>
|
|
|
|
<span class="k">with</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">train</span><span class="p">():</span>
|
|
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
|
|
<span class="c1"># random sample a batch</span>
|
|
<span class="n">samp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">))</span>
|
|
<span class="n">batch</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">X_train</span><span class="p">[</span><span class="n">samp</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="c1"># get the corresponding labels</span>
|
|
<span class="n">labels</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">Y_train</span><span class="p">[</span><span class="n">samp</span><span class="p">])</span>
|
|
|
|
<span class="c1"># forward pass</span>
|
|
<span class="n">out</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
|
|
|
|
<span class="c1"># compute loss</span>
|
|
<span class="n">loss</span> <span class="o">=</span> <span class="n">sparse_categorical_crossentropy</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
|
|
|
|
<span class="c1"># zero gradients</span>
|
|
<span class="n">opt</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
|
|
|
<span class="c1"># backward pass</span>
|
|
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
|
|
|
<span class="c1"># update parameters</span>
|
|
<span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
|
|
|
<span class="c1"># calculate accuracy</span>
|
|
<span class="n">pred</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">acc</span> <span class="o">=</span> <span class="p">(</span><span class="n">pred</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
|
|
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">100</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Step </span><span class="si">{</span><span class="n">step</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2"> | Loss: </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s2"> | Accuracy: </span><span class="si">{</span><span class="n">acc</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<h2 id="evaluation">Evaluation<a class="headerlink" href="#evaluation" title="Permanent link">¤</a></h2>
|
|
<p>Now that we have trained our neural network we can evaluate it on the test set.
|
|
We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="k">with</span> <span class="n">Timing</span><span class="p">(</span><span class="s2">"Time: "</span><span class="p">):</span>
|
|
<span class="n">avg_acc</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
|
|
<span class="c1"># random sample a batch</span>
|
|
<span class="n">samp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X_test</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">))</span>
|
|
<span class="n">batch</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">X_test</span><span class="p">[</span><span class="n">samp</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="c1"># get the corresponding labels</span>
|
|
<span class="n">labels</span> <span class="o">=</span> <span class="n">Y_test</span><span class="p">[</span><span class="n">samp</span><span class="p">]</span>
|
|
|
|
<span class="c1"># forward pass</span>
|
|
<span class="n">out</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
|
|
|
|
<span class="c1"># calculate accuracy</span>
|
|
<span class="n">pred</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
|
|
<span class="n">avg_acc</span> <span class="o">+=</span> <span class="p">(</span><span class="n">pred</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test Accuracy: </span><span class="si">{</span><span class="n">avg_acc</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1000</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<h2 id="and-thats-it">And that's it<a class="headerlink" href="#and-thats-it" title="Permanent link">¤</a></h2>
|
|
<p>Highly recommend you check out the <a href="https://github.com/tinygrad/tinygrad/blob/master/examples">examples/</a> folder for more examples of using tinygrad.
|
|
Reading the source code of tinygrad is also a great way to learn how it works.
|
|
Specifically, the tests in <a href="https://github.com/tinygrad/tinygrad/blob/master/test">test/</a> are a great place to see how to use the different operations and what their semantics are.
|
|
There are also a bunch of models implemented in <a href="https://github.com/tinygrad/tinygrad/blob/master/extra/models">models/</a> that you can use as a reference.</p>
|
|
<p>Additionally, feel free to ask questions in the <code class="language-python highlight"><span class="c1">#learn-tinygrad</span></code> channel on the <a href="https://discord.gg/beYbxwxVdx">discord</a>. Don't ask to ask, just ask!</p>
|
|
<h2 id="extras">Extras<a class="headerlink" href="#extras" title="Permanent link">¤</a></h2>
|
|
<h3 id="jit">JIT<a class="headerlink" href="#jit" title="Permanent link">¤</a></h3>
|
|
<p>Additionally, it is possible to speed up the computation of certain neural networks by using the JIT.
|
|
Currently, this does not support models with varying input sizes or non-tinygrad operations.</p>
|
|
<p>To use the JIT we just need to add a function decorator to the forward pass of our neural network and ensure that the input and output are realized tensors.
|
|
Or in this case we will create a wrapper function and decorate the wrapper function to speed up the evaluation of our neural network.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">TinyJit</span>
|
|
|
|
<span class="nd">@TinyJit</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">jit</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">realize</span><span class="p">()</span>
|
|
|
|
<span class="k">with</span> <span class="n">Timing</span><span class="p">(</span><span class="s2">"Time: "</span><span class="p">):</span>
|
|
<span class="n">avg_acc</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
|
|
<span class="c1"># random sample a batch</span>
|
|
<span class="n">samp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X_test</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">))</span>
|
|
<span class="n">batch</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">X_test</span><span class="p">[</span><span class="n">samp</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="c1"># get the corresponding labels</span>
|
|
<span class="n">labels</span> <span class="o">=</span> <span class="n">Y_test</span><span class="p">[</span><span class="n">samp</span><span class="p">]</span>
|
|
|
|
<span class="c1"># forward pass with jit</span>
|
|
<span class="n">out</span> <span class="o">=</span> <span class="n">jit</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
|
|
|
|
<span class="c1"># calculate accuracy</span>
|
|
<span class="n">pred</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
|
|
<span class="n">avg_acc</span> <span class="o">+=</span> <span class="p">(</span><span class="n">pred</span> <span class="o">==</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test Accuracy: </span><span class="si">{</span><span class="n">avg_acc</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1000</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>You will find that the evaluation time is much faster than before and that your accelerator utilization is much higher.</p>
|
|
<h3 id="saving-and-loading-models">Saving and Loading Models<a class="headerlink" href="#saving-and-loading-models" title="Permanent link">¤</a></h3>
|
|
<p>The standard weight format for tinygrad is <a href="https://github.com/huggingface/safetensors">safetensors</a>. This means that you can load the weights of any model that also uses safetensors into tinygrad.
|
|
There are functions in <a href="https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/state.py">state.py</a> to save and load models to and from this format.</p>
|
|
<div class="language-python highlight"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.nn.state</span><span class="w"> </span><span class="kn">import</span> <span class="n">safe_save</span><span class="p">,</span> <span class="n">safe_load</span><span class="p">,</span> <span class="n">get_state_dict</span><span class="p">,</span> <span class="n">load_state_dict</span>
|
|
|
|
<span class="c1"># first we need the state dict of our model</span>
|
|
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
|
|
|
|
<span class="c1"># then we can just save it to a file</span>
|
|
<span class="n">safe_save</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="s2">"model.safetensors"</span><span class="p">)</span>
|
|
|
|
<span class="c1"># and load it back in</span>
|
|
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">safe_load</span><span class="p">(</span><span class="s2">"model.safetensors"</span><span class="p">)</span>
|
|
<span class="n">load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">)</span>
|
|
</code></pre></div>
|
|
<p>Many of the models in the <a href="https://github.com/tinygrad/tinygrad/tree/master/extra/models">models/</a> folder have a <code class="language-python highlight"><span class="n">load_from_pretrained</span></code> method that will download and load the weights for you. These are usually PyTorch weights, meaning that you need PyTorch installed to load them.</p>
|
|
<h3 id="environment-variables">Environment Variables<a class="headerlink" href="#environment-variables" title="Permanent link">¤</a></h3>
|
|
<p>There exist a bunch of environment variables that control the runtime behavior of tinygrad.
|
|
Some of the common ones are <code class="language-python highlight"><span class="n">DEBUG</span></code> and the different backend enablement variables.</p>
|
|
<p>You can find a full list and their descriptions in <a href="../env_vars/">env_vars.md</a>.</p>
|
|
<h3 id="visualizing-the-computation-graph">Visualizing the Computation Graph<a class="headerlink" href="#visualizing-the-computation-graph" title="Permanent link">¤</a></h3>
|
|
<p>It is possible to visualize the computation graph of a neural network using VIZ=1.</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
</article>
|
|
</div>
|
|
|
|
|
|
<script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
|
|
</div>
|
|
|
|
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
|
|
Back to top
|
|
</button>
|
|
|
|
</main>
|
|
|
|
<footer class="md-footer">
|
|
|
|
|
|
|
|
<nav class="md-footer__inner md-grid" aria-label="Footer" >
|
|
|
|
|
|
<a href=".." class="md-footer__link md-footer__link--prev" aria-label="Previous: tinygrad documentation">
|
|
<div class="md-footer__button md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
|
|
</div>
|
|
<div class="md-footer__title">
|
|
<span class="md-footer__direction">
|
|
Previous
|
|
</span>
|
|
<div class="md-ellipsis">
|
|
tinygrad documentation
|
|
</div>
|
|
</div>
|
|
</a>
|
|
|
|
|
|
|
|
<a href="../showcase/" class="md-footer__link md-footer__link--next" aria-label="Next: Showcase">
|
|
<div class="md-footer__title">
|
|
<span class="md-footer__direction">
|
|
Next
|
|
</span>
|
|
<div class="md-ellipsis">
|
|
Showcase
|
|
</div>
|
|
</div>
|
|
<div class="md-footer__button md-icon">
|
|
|
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
|
|
</div>
|
|
</a>
|
|
|
|
</nav>
|
|
|
|
|
|
<div class="md-footer-meta md-typeset">
|
|
<div class="md-footer-meta__inner md-grid">
|
|
<div class="md-copyright">
|
|
|
|
|
|
Made with
|
|
<a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
|
|
Material for MkDocs
|
|
</a>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
</div>
|
|
</footer>
|
|
|
|
</div>
|
|
<div class="md-dialog" data-md-component="dialog">
|
|
<div class="md-dialog__inner md-typeset"></div>
|
|
</div>
|
|
|
|
|
|
|
|
|
|
|
|
<script id="__config" type="application/json">{"annotate": null, "base": "..", "features": ["announce.dismiss", "content.action.edit", "content.action.view", "content.code.annotate", "content.code.copy", "content.tooltips", "navigation.footer", "navigation.indexes", "navigation.sections", "navigation.expand", "navigation.top", "navigation.path", "search.highlight", "search.suggest", "toc.follow", "toc.integrate"], "search": "../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}, "version": null}</script>
|
|
|
|
|
|
<script src="../assets/javascripts/bundle.79ae519e.min.js"></script>
|
|
|
|
<script src="../assets/_markdown_exec_pyodide.js"></script>
|
|
|
|
|
|
</body>
|
|
</html> |