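<!-- Repository-viewer header captured together with this file; it is not part of the generated page.
Kept inside a comment because any text before the doctype puts browsers into quirks mode.
Files
tinygrad/nn/index.html
2026-01-05 03:41:19 +00:00

3965 lines
264 KiB
HTML
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
-->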
<!doctype html>
<html lang="en" class="no-js">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<link rel="canonical" href="https://docs.tinygrad.org/nn/">
<link rel="prev" href="../dtypes/">
<link rel="next" href="../env_vars/">
<link rel="icon" href="../favicon.svg">
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.1">
<title>nn (Neural Networks) - tinygrad docs</title>
<link rel="stylesheet" href="../assets/stylesheets/main.484c7ddc.min.css">
<link rel="stylesheet" href="../assets/stylesheets/palette.ab4e12ef.min.css">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
<link rel="stylesheet" href="../assets/_markdown_exec_pyodide.css">
<link rel="stylesheet" href="../assets/_markdown_exec_ansi.css">
<link rel="stylesheet" href="../assets/_mkdocstrings.css">
<script>
/* Persisted UI state is namespaced by the site root (resolved relative to this page):
   storage keys have the form "<root pathname>.<key>". */
__md_scope=new URL("..",location),
/* Deterministic string hash: h = (h << 5) - h + c, folded over each code point (its first UTF-16 unit). */
__md_hash=e=>[...e].reduce(((e,_)=>(e<<5)-e+_.charCodeAt(0)),0),
/* Read and JSON-decode the value stored under key e; null when the key is absent.
   Storage defaults to localStorage. The storage lookup happens inside the try, so that a
   throwing localStorage getter (e.g. a SecurityError when storage access is denied) or a
   corrupt entry also yields null, like an absent key, instead of throwing into the caller. */
__md_get=(e,_,t=__md_scope)=>{try{return JSON.parse((_===undefined?localStorage:_).getItem(t.pathname+"."+e))}catch(e){return null}},
/* JSON-encode value _ and store it under key e. Best effort: any failure is ignored.
   localStorage is resolved inside the try (not as a default parameter, which would be
   evaluated outside it), so a throwing getter is swallowed as well as setItem errors. */
__md_set=(e,_,t,a=__md_scope)=>{try{(t===undefined?localStorage:t).setItem(a.pathname+"."+e,JSON.stringify(_))}catch(e){}}
</script>
</head>
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime">
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
<label class="md-overlay" for="__drawer"></label>
<div data-md-component="skip">
<a href="#neural-network-classes" class="md-skip">
Skip to content
</a>
</div>
<div data-md-component="announce">
</div>
<header class="md-header md-header--shadow" data-md-component="header">
<nav class="md-header__inner md-grid" aria-label="Header">
<a href=".." title="tinygrad docs" class="md-header__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
<img src="../logo_tiny_dark.svg" alt="logo">
</a>
<label class="md-header__button md-icon" for="__drawer">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
</label>
<div class="md-header__title" data-md-component="header-title">
<div class="md-header__ellipsis">
<div class="md-header__topic">
<span class="md-ellipsis">
tinygrad docs
</span>
</div>
<div class="md-header__topic" data-md-component="header-topic">
<span class="md-ellipsis">
nn (Neural Networks)
</span>
</div>
</div>
</div>
<form class="md-header__option" data-md-component="palette">
<input class="md-option" data-md-color-media="(prefers-color-scheme)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to light mode" type="radio" name="__palette" id="__palette_0">
<label class="md-header__button md-icon" title="Switch to light mode" for="__palette_1" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m14.3 16-.7-2h-3.2l-.7 2H7.8L11 7h2l3.2 9zM20 8.69V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12zm-9.15 3.96h2.3L12 9z"/></svg>
</label>
<input class="md-option" data-md-color-media="(prefers-color-scheme: light)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to dark mode" type="radio" name="__palette" id="__palette_1">
<label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_2" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
</label>
<input class="md-option" data-md-color-media="(prefers-color-scheme: dark)" data-md-color-scheme="slate" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to system preference" type="radio" name="__palette" id="__palette_2">
<label class="md-header__button md-icon" title="Switch to system preference" for="__palette_0" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
</label>
</form>
<script>
/* Restore the color palette saved under "__palette" (read with __md_get from the head script)
   by mirroring every palette.color entry onto <body> as a data-md-color-* attribute.
   A saved automatic choice (media "(prefers-color-scheme)") is first resolved, via matchMedia,
   to the light or dark option of the palette form above, whose data attributes are copied in.
   These are top-level var declarations, so palette, media, input, key and value are page globals. */
var palette = __md_get("__palette");
if (palette && palette.color) {
  if (palette.color.media === "(prefers-color-scheme)") {
    var media = matchMedia("(prefers-color-scheme: light)");
    var input;
    if (media.matches) {
      input = document.querySelector("[data-md-color-media='(prefers-color-scheme: light)']");
    } else {
      input = document.querySelector("[data-md-color-media='(prefers-color-scheme: dark)']");
    }
    palette.color.media = input.getAttribute("data-md-color-media");
    palette.color.scheme = input.getAttribute("data-md-color-scheme");
    palette.color.primary = input.getAttribute("data-md-color-primary");
    palette.color.accent = input.getAttribute("data-md-color-accent");
  }
  for (var [key, value] of Object.entries(palette.color)) {
    document.body.setAttribute("data-md-color-" + key, value);
  }
}
</script>
<label class="md-header__button md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
</label>
<div class="md-search" data-md-component="search" role="dialog">
<label class="md-search__overlay" for="__search"></label>
<div class="md-search__inner" role="search">
<form class="md-search__form" name="search">
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
<label class="md-search__icon md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
</label>
<nav class="md-search__options" aria-label="Search">
<button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
</button>
</nav>
<div class="md-search__suggest" data-md-component="search-suggest"></div>
</form>
<div class="md-search__output">
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
<div class="md-search-result" data-md-component="search-result">
<div class="md-search-result__meta">
Initializing search
</div>
<ol class="md-search-result__list" role="presentation"></ol>
</div>
</div>
</div>
</div>
</div>
<div class="md-header__source">
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
</div>
<div class="md-source__repository">
GitHub
</div>
</a>
</div>
</nav>
</header>
<div class="md-container" data-md-component="container">
<main class="md-main" data-md-component="main">
<div class="md-main__inner md-grid">
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--primary md-nav--integrated" aria-label="Navigation" data-md-level="0">
<label class="md-nav__title" for="__drawer">
<a href=".." title="tinygrad docs" class="md-nav__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
<img src="../logo_tiny_dark.svg" alt="logo">
</a>
tinygrad docs
</label>
<div class="md-nav__source">
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
</div>
<div class="md-source__repository">
GitHub
</div>
</a>
</div>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1" checked>
<div class="md-nav__link md-nav__container">
<a href=".." class="md-nav__link ">
<span class="md-ellipsis">
Home
</span>
</a>
<label class="md-nav__link " for="__nav_1" id="__nav_1_label">
<span class="md-nav__icon md-icon"></span>
</label>
</div>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_1_label" aria-expanded="true">
<label class="md-nav__title" for="__nav_1">
<span class="md-nav__icon md-icon"></span>
Home
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../quickstart/" class="md-nav__link">
<span class="md-ellipsis">
Quickstart
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../showcase/" class="md-nav__link">
<span class="md-ellipsis">
Showcase
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../mnist/" class="md-nav__link">
<span class="md-ellipsis">
MNIST Tutorial
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--active md-nav__item--nested">
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1_5" checked>
<label class="md-nav__link" for="__nav_1_5" id="__nav_1_5_label" tabindex="0">
<span class="md-ellipsis">
API Reference
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_5_label" aria-expanded="true">
<label class="md-nav__title" for="__nav_1_5">
<span class="md-nav__icon md-icon"></span>
API Reference
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5_1" >
<div class="md-nav__link md-nav__container">
<a href="../tensor/" class="md-nav__link ">
<span class="md-ellipsis">
Tensor
</span>
</a>
<label class="md-nav__link " for="__nav_1_5_1" id="__nav_1_5_1_label" tabindex="0">
<span class="md-nav__icon md-icon"></span>
</label>
</div>
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_5_1_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1_5_1">
<span class="md-nav__icon md-icon"></span>
Tensor
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../tensor/properties/" class="md-nav__link">
<span class="md-ellipsis">
Properties
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/creation/" class="md-nav__link">
<span class="md-ellipsis">
Creation
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/movement/" class="md-nav__link">
<span class="md-ellipsis">
Movement
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/elementwise/" class="md-nav__link">
<span class="md-ellipsis">
Elementwise
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../tensor/ops/" class="md-nav__link">
<span class="md-ellipsis">
Complex Ops
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../dtypes/" class="md-nav__link">
<span class="md-ellipsis">
dtypes
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--active">
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
<label class="md-nav__link md-nav__link--active" for="__toc">
<span class="md-ellipsis">
nn (Neural Networks)
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<a href="./" class="md-nav__link md-nav__link--active">
<span class="md-ellipsis">
nn (Neural Networks)
</span>
</a>
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
Table of contents
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#neural-network-classes" class="md-nav__link">
<span class="md-ellipsis">
Neural Network classes
</span>
</a>
<nav class="md-nav" aria-label="Neural Network classes">
<ul class="md-nav__list">
<li class="md-nav__item">
<a href="#tinygrad.nn.BatchNorm" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;BatchNorm
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.Conv1d" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;Conv1d
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.Conv2d" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;Conv2d
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.ConvTranspose1d" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;ConvTranspose1d
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.ConvTranspose2d" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;ConvTranspose2d
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.Linear" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;Linear
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.GroupNorm" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;GroupNorm
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.InstanceNorm" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;InstanceNorm
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.LayerNorm" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;LayerNorm
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.LayerNorm2d" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;LayerNorm2d
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.RMSNorm" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;RMSNorm
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.Embedding" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;Embedding
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.LSTMCell" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;LSTMCell
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="#optimizers" class="md-nav__link">
<span class="md-ellipsis">
Optimizers
</span>
</a>
<nav class="md-nav" aria-label="Optimizers">
<ul class="md-nav__list">
<li class="md-nav__item">
<a href="#tinygrad.nn.optim.SGD" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;SGD
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.optim.LARS" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;LARS
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.optim.AdamW" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;AdamW
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.optim.Adam" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;Adam
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.optim.LAMB" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code>&nbsp;LAMB
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="#loadsave" class="md-nav__link">
<span class="md-ellipsis">
Load/Save
</span>
</a>
<nav class="md-nav" aria-label="Load/Save">
<ul class="md-nav__list">
<li class="md-nav__item">
<a href="#tinygrad.nn.state.safe_load" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;safe_load
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.state.safe_save" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;safe_save
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.state.get_state_dict" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;get_state_dict
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.state.get_parameters" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;get_parameters
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.state.load_state_dict" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;load_state_dict
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.state.tar_extract" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;tar_extract
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.state.torch_load" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;torch_load
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#tinygrad.nn.state.gguf_load" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code>&nbsp;gguf_load
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../env_vars/" class="md-nav__link">
<span class="md-ellipsis">
Environment Variables
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../runtime/" class="md-nav__link">
<span class="md-ellipsis">
Runtime
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6" >
<label class="md-nav__link" for="__nav_1_6" id="__nav_1_6_label" tabindex="0">
<span class="md-ellipsis">
Developer
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_6_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1_6">
<span class="md-nav__icon md-icon"></span>
Developer
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../developer/developer/" class="md-nav__link">
<span class="md-ellipsis">
Intro
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/layout/" class="md-nav__link">
<span class="md-ellipsis">
Layout
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/speed/" class="md-nav__link">
<span class="md-ellipsis">
Speed
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/uop/" class="md-nav__link">
<span class="md-ellipsis">
UOp
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6_5" >
<label class="md-nav__link" for="__nav_1_6_5" id="__nav_1_6_5_label" tabindex="0">
<span class="md-ellipsis">
Runtime
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_6_5_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1_6_5">
<span class="md-nav__icon md-icon"></span>
Runtime
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../developer/runtime/" class="md-nav__link">
<span class="md-ellipsis">
Runtime Overview
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/hcq/" class="md-nav__link">
<span class="md-ellipsis">
HCQ
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../developer/am/" class="md-nav__link">
<span class="md-ellipsis">
AM Driver
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../tinybox/" class="md-nav__link">
<span class="md-ellipsis">
tinybox
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-content" data-md-component="content">
<nav class="md-path" aria-label="Breadcrumb">
<ol class="md-path__list">
<li class="md-path__item">
<a href=".." class="md-path__link">
<span class="md-ellipsis">
Home
</span>
</a>
</li>
<li class="md-path__item">
<a href="../tensor/" class="md-path__link">
<span class="md-ellipsis">
API Reference
</span>
</a>
</li>
</ol>
</nav>
<article class="md-content__inner md-typeset">
<a href="https://github.com/tinygrad/tinygrad/edit/master/docs/nn.md" title="Edit this page" class="md-content__button md-icon" rel="edit">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M10 20H6V4h7v5h5v3.1l2-2V8l-6-6H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h4zm10.2-7c.1 0 .3.1.4.2l1.3 1.3c.2.2.2.6 0 .8l-1 1-2.1-2.1 1-1c.1-.1.2-.2.4-.2m0 3.9L14.1 23H12v-2.1l6.1-6.1z"/></svg>
</a>
<a href="https://github.com/tinygrad/tinygrad/raw/master/docs/nn.md" title="View source of this page" class="md-content__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M17 18c.56 0 1 .44 1 1s-.44 1-1 1-1-.44-1-1 .44-1 1-1m0-3c-2.73 0-5.06 1.66-6 4 .94 2.34 3.27 4 6 4s5.06-1.66 6-4c-.94-2.34-3.27-4-6-4m0 6.5a2.5 2.5 0 0 1-2.5-2.5 2.5 2.5 0 0 1 2.5-2.5 2.5 2.5 0 0 1 2.5 2.5 2.5 2.5 0 0 1-2.5 2.5M9.27 20H6V4h7v5h5v4.07c.7.08 1.36.25 2 .49V8l-6-6H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h4.5a8.2 8.2 0 0 1-1.23-2"/></svg>
</a>
<h1>nn (Neural Networks)</h1>
<h2 id="neural-network-classes">Neural Network classes<a class="headerlink" href="#neural-network-classes" title="Permanent link">¤</a></h2>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.BatchNorm" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">BatchNorm</span>
<a href="#tinygrad.nn.BatchNorm" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">BatchNorm</span><span class="p">(</span>
<span class="n">sz</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-05</span><span class="p">,</span>
<span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">track_running_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
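<p>Applies Batch Normalization over an input whose channel dimension (of size <code>sz</code>) is axis 1, e.g. the 4D <code>(N, C, H, W)</code> tensor in the example below.</p>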
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1502.03167v3">https://arxiv.org/abs/1502.03167v3</a></li>
</ul>
<p>See: <code class="language-python highlight"><span class="n">Tensor</span><span class="o">.</span><span class="n">batchnorm</span></code></p>
<p></p>
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mf">0.5005428791046143</span> <span class="mf">0.31660139560699463</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mf">0.5005403757095337</span> <span class="mf">0.3165997266769409</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">33</span>
<span class="normal">34</span>
<span class="normal">35</span>
<span class="normal">36</span>
<span class="normal">37</span>
<span class="normal">38</span>
<span class="normal">39</span>
<span class="normal">40</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sz</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">eps</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">,</span> <span class="n">momentum</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">sz</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">sz</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_batches_tracked</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;long&#39;</span> <span class="k">if</span> <span class="n">is_dtype_supported</span><span class="p">(</span><span class="n">dtypes</span><span class="o">.</span><span class="n">long</span><span class="p">)</span> <span class="k">else</span> <span class="s1">&#39;int&#39;</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">sz</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">sz</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.Conv1d" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">Conv1d</span>
<a href="#tinygrad.nn.Conv1d" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Conv1d</span><span class="p">(</span>
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">padding</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Conv2d&lt;/span&gt; (&lt;code&gt;tinygrad.nn.Conv2d&lt;/code&gt;)" href="#tinygrad.nn.Conv2d">Conv2d</a></span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies a 1D convolution over an input signal composed of several input planes.</p>
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv1d">https://pytorch.org/docs/stable/generated/torch.nn.Conv1d</a></p>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv1d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.8887</span> <span class="mf">0.9379</span> <span class="mf">0.6705</span> <span class="mf">0.2281</span><span class="p">]]]</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.4804</span> <span class="mf">0.141</span> <span class="p">]]]</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">63</span>
<span class="normal">64</span>
<span class="normal">65</span>
<span class="normal">66</span>
<span class="normal">67</span>
<span class="normal">68</span>
<span class="normal">69</span>
<span class="normal">70</span>
<span class="normal">71</span>
<span class="normal">72</span>
<span class="normal">73</span>
<span class="normal">74</span>
<span class="normal">75</span>
<span class="normal">76</span>
<span class="normal">77</span>
<span class="normal">78</span>
<span class="normal">79</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">Conv1d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">str</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Conv2d</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Applies a 1D convolution over an input signal composed of several input planes.</span>
<span class="sd"> See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d</span>
<span class="sd"> ```python exec=&quot;true&quot; source=&quot;above&quot; session=&quot;tensor&quot; result=&quot;python&quot;</span>
<span class="sd"> conv = nn.Conv1d(1, 1, 3)</span>
<span class="sd"> t = Tensor.rand(1, 1, 4)</span>
<span class="sd"> print(t.numpy())</span>
<span class="sd"> ```</span>
<span class="sd"> ```python exec=&quot;true&quot; source=&quot;above&quot; session=&quot;tensor&quot; result=&quot;python&quot;</span>
<span class="sd"> t = conv(t)</span>
<span class="sd"> print(t.numpy())</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="p">(</span><span class="n">kernel_size</span><span class="p">,),</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.Conv2d" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Conv2d</span>
<a href="#tinygrad.nn.Conv2d" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Conv2d</span><span class="p">(</span>
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">padding</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies a 2D convolution over an input signal composed of several input planes.</p>
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv2d">https://pytorch.org/docs/stable/generated/torch.nn.Conv2d</a></p>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="mf">0.4906</span> <span class="mf">0.2963</span> <span class="mf">0.0639</span> <span class="mf">0.0127</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.8699</span> <span class="mf">0.8575</span> <span class="mf">0.5628</span> <span class="mf">0.6926</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.7058</span> <span class="mf">0.1117</span> <span class="mf">0.634</span> <span class="mf">0.614</span> <span class="p">]</span>
<span class="p">[</span><span class="mf">0.4943</span> <span class="mf">0.7191</span> <span class="mf">0.0912</span> <span class="mf">0.9734</span><span class="p">]]]]</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="o">-</span><span class="mf">0.5142</span> <span class="o">-</span><span class="mf">0.5431</span><span class="p">]</span>
<span class="p">[</span><span class="o">-</span><span class="mf">0.5355</span> <span class="o">-</span><span class="mf">0.1517</span><span class="p">]]]]</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span><span class="o">|</span><span class="nb">str</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="k">if</span> <span class="n">padding</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="o">!=</span> <span class="s1">&#39;same&#39;</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid padding string </span><span class="si">{</span><span class="n">padding</span><span class="si">!r}</span><span class="s2">, only &#39;same&#39; is supported&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;padding=&#39;same&#39; is not supported for strided convolutions&quot;</span><span class="p">)</span>
<span class="n">pad</span> <span class="o">=</span> <span class="p">[(</span><span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">//</span><span class="mi">2</span><span class="p">,</span> <span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">//</span><span class="mi">2</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span><span class="p">,</span><span class="n">k</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">make_tuple</span><span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">])]</span>
<span class="n">padding</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">flatten</span><span class="p">(</span><span class="n">pad</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">=</span> <span class="n">stride</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">padding</span>
<span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">*</span> <span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">in_channels</span><span class="o">//</span><span class="n">groups</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.ConvTranspose1d" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">ConvTranspose1d</span>
<a href="#tinygrad.nn.ConvTranspose1d" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">ConvTranspose1d</span><span class="p">(</span>
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;ConvTranspose2d&lt;/span&gt; (&lt;code&gt;tinygrad.nn.ConvTranspose2d&lt;/code&gt;)" href="#tinygrad.nn.ConvTranspose2d">ConvTranspose2d</a></span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies a 1D transposed convolution operator over an input signal composed of several input planes.</p>
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d">https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d</a></p>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose1d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.832</span> <span class="mf">0.3402</span> <span class="mf">0.9718</span> <span class="mf">0.5209</span><span class="p">]]]</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="o">-</span><span class="mf">0.7385</span> <span class="o">-</span><span class="mf">0.0549</span> <span class="o">-</span><span class="mf">0.1784</span> <span class="mf">0.1086</span> <span class="mf">0.5088</span> <span class="o">-</span><span class="mf">0.0056</span><span class="p">]]]</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span>
<span class="normal">125</span>
<span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span>
<span class="normal">129</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">ConvTranspose1d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">ConvTranspose2d</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Applies a 1D transposed convolution operator over an input signal composed of several input planes.</span>
<span class="sd"> See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d</span>
<span class="sd"> ```python exec=&quot;true&quot; source=&quot;above&quot; session=&quot;tensor&quot; result=&quot;python&quot;</span>
<span class="sd"> conv = nn.ConvTranspose1d(1, 1, 3)</span>
<span class="sd"> t = Tensor.rand(1, 1, 4)</span>
<span class="sd"> print(t.numpy())</span>
<span class="sd"> ```</span>
<span class="sd"> ```python exec=&quot;true&quot; source=&quot;above&quot; session=&quot;tensor&quot; result=&quot;python&quot;</span>
<span class="sd"> t = conv(t)</span>
<span class="sd"> print(t.numpy())</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="p">(</span><span class="n">kernel_size</span><span class="p">,),</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">output_padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.ConvTranspose2d" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">ConvTranspose2d</span>
<a href="#tinygrad.nn.ConvTranspose2d" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">ConvTranspose2d</span><span class="p">(</span>
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p class="doc doc-class-bases">
Bases: <code><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Conv2d&lt;/span&gt; (&lt;code&gt;tinygrad.nn.Conv2d&lt;/code&gt;)" href="#tinygrad.nn.Conv2d">Conv2d</a></code></p>
<p>Applies a 2D transposed convolution operator over an input image.</p>
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d">https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d</a></p>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="mf">0.1347</span> <span class="mf">0.9967</span> <span class="mf">0.0782</span> <span class="mf">0.1288</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.649</span> <span class="mf">0.9984</span> <span class="mf">0.827</span> <span class="mf">0.7346</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.5111</span> <span class="mf">0.4726</span> <span class="mf">0.7987</span> <span class="mf">0.8159</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.9239</span> <span class="mf">0.5857</span> <span class="mf">0.8851</span> <span class="mf">0.3674</span><span class="p">]]]]</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span> <span class="mf">0.281</span> <span class="mf">0.4304</span> <span class="mf">0.0086</span> <span class="mf">0.5784</span> <span class="mf">0.2373</span> <span class="mf">0.2929</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.3606</span> <span class="mf">0.0571</span> <span class="mf">0.5456</span> <span class="mf">0.2427</span> <span class="mf">0.3103</span> <span class="mf">0.4593</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.2271</span> <span class="mf">0.2836</span> <span class="mf">0.1235</span> <span class="mf">0.032</span> <span class="mf">0.2217</span> <span class="mf">0.3312</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.4558</span> <span class="mf">0.1647</span> <span class="mf">0.2354</span> <span class="o">-</span><span class="mf">0.0112</span> <span class="mf">0.1913</span> <span class="mf">0.0977</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.12</span> <span class="mf">0.3204</span> <span class="o">-</span><span class="mf">0.0401</span> <span class="mf">0.1801</span> <span class="o">-</span><span class="mf">0.1336</span> <span class="mf">0.0728</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.4357</span> <span class="mf">0.1603</span> <span class="mf">0.19</span> <span class="mf">0.0581</span> <span class="mf">0.0668</span> <span class="mf">0.209</span> <span class="p">]]]]</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">147</span>
<span class="normal">148</span>
<span class="normal">149</span>
<span class="normal">150</span>
<span class="normal">151</span>
<span class="normal">152</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">*</span> <span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">//</span><span class="n">groups</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_padding</span> <span class="o">=</span> <span class="n">output_padding</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.Linear" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Linear</span>
<a href="#tinygrad.nn.Linear" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies a linear transformation to the incoming data.</p>
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Linear">https://pytorch.org/docs/stable/generated/torch.nn.Linear</a></p>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="mf">0.6496</span> <span class="mf">0.6111</span> <span class="mf">0.5894</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.0162</span> <span class="mf">0.9957</span> <span class="mf">0.7491</span><span class="p">]]</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">lin</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="o">-</span><span class="mf">0.7194</span> <span class="mf">0.1132</span> <span class="mf">0.416</span> <span class="o">-</span><span class="mf">0.1307</span><span class="p">]</span>
<span class="p">[</span><span class="o">-</span><span class="mf">0.8916</span> <span class="mf">0.5362</span> <span class="mf">0.1431</span> <span class="mf">0.0856</span><span class="p">]]</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">173</span>
<span class="normal">174</span>
<span class="normal">175</span>
<span class="normal">176</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="n">bound</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_features</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.GroupNorm" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">GroupNorm</span>
<a href="#tinygrad.nn.GroupNorm" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">GroupNorm</span><span class="p">(</span>
<span class="n">num_groups</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">num_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-05</span><span class="p">,</span>
<span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies Group Normalization over a mini-batch of inputs.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1803.08494v3">https://arxiv.org/abs/1803.08494v3</a></li>
</ul>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">12</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.9572356939315796</span> <span class="mf">0.5820862054824829</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">3.4862459585838224e-08</span> <span class="mf">1.0012898445129395</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">196</span>
<span class="normal">197</span>
<span class="normal">198</span>
<span class="normal">199</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_groups</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">num_groups</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">,</span> <span class="n">eps</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_channels</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_channels</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.InstanceNorm" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">InstanceNorm</span>
<a href="#tinygrad.nn.InstanceNorm" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">InstanceNorm</span><span class="p">(</span>
<span class="n">num_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
<span class="n">affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies Instance Normalization over a mini-batch of inputs.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1607.08022v3">https://arxiv.org/abs/1607.08022v3</a></li>
</ul>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mf">2.0937702655792236</span> <span class="mf">0.5715298056602478</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.1038510194794071e-07</span> <span class="mf">1.0052294731140137</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">226</span>
<span class="normal">227</span>
<span class="normal">228</span>
<span class="normal">229</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">eps</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.LayerNorm" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LayerNorm</span>
<a href="#tinygrad.nn.LayerNorm" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LayerNorm</span><span class="p">(</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
<span class="n">elementwise_affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies Layer Normalization over a mini-batch of inputs.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1607.06450v1">https://arxiv.org/abs/1607.06450v1</a></li>
</ul>
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mf">2.0335121154785156</span> <span class="mf">0.6475386023521423</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">1.8065918538923142e-07</span> <span class="mf">1.0170753002166748</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">252</span>
<span class="normal">253</span>
<span class="normal">254</span>
<span class="normal">255</span>
<span class="normal">256</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">))),</span> <span class="n">eps</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.LayerNorm2d" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LayerNorm2d</span>
<a href="#tinygrad.nn.LayerNorm2d" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LayerNorm2d</span><span class="p">(</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
<span class="n">elementwise_affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p class="doc doc-class-bases">
Bases: <code><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;LayerNorm&lt;/span&gt; (&lt;code&gt;tinygrad.nn.LayerNorm&lt;/code&gt;)" href="#tinygrad.nn.LayerNorm">LayerNorm</a></code></p>
<p>Applies Layer Normalization over a mini-batch of 2D inputs.</p>
<p>See: <code class="language-python highlight"><span class="n">LayerNorm</span></code></p>
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm2d</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.9161213636398315</span> <span class="mf">0.5450614094734192</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">1.4681984339404153e-07</span> <span class="mf">1.005200982093811</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">252</span>
<span class="normal">253</span>
<span class="normal">254</span>
<span class="normal">255</span>
<span class="normal">256</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">))),</span> <span class="n">eps</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.RMSNorm" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">RMSNorm</span>
<a href="#tinygrad.nn.RMSNorm" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">RMSNorm</span><span class="p">(</span><span class="n">dim</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Applies Root Mean Square Normalization to input.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1910.07467">https://arxiv.org/abs/1910.07467</a></li>
</ul>
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span> <span class="mf">0.</span> <span class="mf">1.</span> <span class="mf">2.</span> <span class="mf">3.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">4.</span> <span class="mf">5.</span> <span class="mf">6.</span> <span class="mf">7.</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">8.</span> <span class="mf">9.</span> <span class="mf">10.</span> <span class="mf">11.</span><span class="p">]]</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="mf">0.</span> <span class="mf">0.5345</span> <span class="mf">1.069</span> <span class="mf">1.6036</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.7127</span> <span class="mf">0.8909</span> <span class="mf">1.069</span> <span class="mf">1.2472</span><span class="p">]</span>
<span class="p">[</span><span class="mf">0.8363</span> <span class="mf">0.9409</span> <span class="mf">1.0454</span> <span class="mf">1.15</span> <span class="p">]]</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">297</span>
<span class="normal">298</span>
<span class="normal">299</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.Embedding" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Embedding</span>
<a href="#tinygrad.nn.Embedding" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Embedding</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">embed_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>A simple lookup table that stores embeddings of a fixed dictionary and size.</p>
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Embedding">https://pytorch.org/docs/stable/generated/torch.nn.Embedding</a></p>
<div class="language-python highlight"><pre><span></span><code><span class="n">emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">emb</span><span class="p">(</span><span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span> <span class="mf">0.5594</span> <span class="mf">0.3912</span> <span class="mf">0.1444</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.072</span> <span class="mf">0.4649</span> <span class="mf">0.3679</span><span class="p">]</span>
<span class="p">[</span><span class="o">-</span><span class="mf">0.3457</span> <span class="o">-</span><span class="mf">0.2776</span> <span class="o">-</span><span class="mf">0.5988</span><span class="p">]</span>
<span class="p">[</span> <span class="mf">0.5594</span> <span class="mf">0.3912</span> <span class="mf">0.1444</span><span class="p">]]</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">318</span>
<span class="normal">319</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">:</span><span class="nb">int</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_sz</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_sz</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">,</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">glorot_uniform</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.LSTMCell" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LSTMCell</span>
<a href="#tinygrad.nn.LSTMCell" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LSTMCell</span><span class="p">(</span>
<span class="n">input_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>A long short-term memory (LSTM) cell.</p>
<p><span class="doc-section-title">Parameters:</span></p>
<ul>
<li class="doc-section-item field-body">
<b><code>input_size</code></b>
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></code>)
<div class="doc-md-description">
<p>The number of expected features in the input <code class="language-python highlight"><span class="n">x</span></code></p>
</div>
</li>
<li class="doc-section-item field-body">
<b><code>hidden_size</code></b>
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></code>)
<div class="doc-md-description">
<p>The number of features in the hidden state <code class="language-python highlight"><span class="n">h</span></code></p>
</div>
</li>
<li class="doc-section-item field-body">
<b><code>bias</code></b>
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></code>, default:
<code>True</code>
)
<div class="doc-md-description">
<p>If <code class="language-python highlight"><span class="kc">False</span></code>, then the layer does not use bias weights <code class="language-python highlight"><span class="n">b_ih</span></code> and <code class="language-python highlight"><span class="n">b_hh</span></code></p>
</div>
</li>
</ul>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">337</span>
<span class="normal">338</span>
<span class="normal">339</span>
<span class="normal">340</span>
<span class="normal">341</span>
<span class="normal">342</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="n">stdv</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight_ih</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">stdv</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">stdv</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight_hh</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">stdv</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">stdv</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias_hh</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div><h2 id="optimizers">Optimizers<a class="headerlink" href="#optimizers" title="Permanent link">¤</a></h2>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.optim.SGD" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">SGD</span>
<a href="#tinygrad.nn.optim.SGD" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">SGD</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">classic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.</p>
<p><code class="language-python highlight"><span class="n">classic</span></code> is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.</p>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">77</span>
<span class="normal">78</span>
<span class="normal">79</span>
<span class="normal">80</span>
<span class="normal">81</span>
<span class="normal">82</span>
<span class="normal">83</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">SGD</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.</span>
<span class="sd"> `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">LARS</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">nesterov</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="n">classic</span><span class="p">,</span> <span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">tcoef</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.optim.LARS" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LARS</span>
<a href="#tinygrad.nn.optim.LARS" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LARS</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">,</span>
<span class="n">ns_steps</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">ns_coefficients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">classic</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">tcoef</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p class="doc doc-class-bases">
Bases: <code><span title="tinygrad.nn.optim.Optimizer">Optimizer</span></code></p>
<p>Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1708.03888v3">https://arxiv.org/abs/1708.03888v3</a></li>
</ul>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span><span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span> <span class="n">ns_steps</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">ns_coefficients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">tcoef</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">fused</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ns_steps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ns_coefficients</span> <span class="o">=</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">ns_steps</span><span class="p">,</span> <span class="n">ns_coefficients</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nesterov</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classic</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tcoef</span> <span class="o">=</span> <span class="n">nesterov</span><span class="p">,</span> <span class="n">classic</span><span class="p">,</span> <span class="n">pre_wd</span><span class="p">,</span> <span class="n">tcoef</span>
<span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="k">else</span> <span class="p">[]</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.optim.AdamW" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">AdamW</span>
<a href="#tinygrad.nn.optim.AdamW" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">AdamW</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-08</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>AdamW optimizer with optional weight decay.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1711.05101v3">https://arxiv.org/abs/1711.05101v3</a></li>
</ul>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">133</span>
<span class="normal">134</span>
<span class="normal">135</span>
<span class="normal">136</span>
<span class="normal">137</span>
<span class="normal">138</span>
<span class="normal">139</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">AdamW</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> AdamW optimizer with optional weight decay.</span>
<span class="sd"> - Paper: https://arxiv.org/abs/1711.05101v3</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">LAMB</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.optim.Adam" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">Adam</span>
<a href="#tinygrad.nn.optim.Adam" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Adam</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-08</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Adam optimizer.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1412.6980">https://arxiv.org/abs/1412.6980</a></li>
</ul>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">140</span>
<span class="normal">141</span>
<span class="normal">142</span>
<span class="normal">143</span>
<span class="normal">144</span>
<span class="normal">145</span>
<span class="normal">146</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">Adam</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Adam optimizer.</span>
<span class="sd"> - Paper: https://arxiv.org/abs/1412.6980</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">LAMB</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-class">
<h3 id="tinygrad.nn.optim.LAMB" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LAMB</span>
<a href="#tinygrad.nn.optim.LAMB" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LAMB</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">adam</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p class="doc doc-class-bases">
Bases: <code><span title="tinygrad.nn.optim.Optimizer">Optimizer</span></code></p>
<p>LAMB optimizer with optional weight decay.</p>
<ul>
<li>Paper: <a href="https://arxiv.org/abs/1904.00962">https://arxiv.org/abs/1904.00962</a></li>
</ul>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span>
<span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">fused</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">b1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">b2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">adam</span> <span class="o">=</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">adam</span>
<span class="bp">self</span><span class="o">.</span><span class="n">b1_t</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">b2_t</span> <span class="o">=</span> <span class="p">(</span><span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">m</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span>
</code></pre></div></td></tr></table></div>
</details>
<div class="doc doc-children">
</div>
</div>
</div><h2 id="loadsave">Load/Save<a class="headerlink" href="#loadsave" title="Permanent link">¤</a></h2>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.safe_load" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">safe_load</span>
<a href="#tinygrad.nn.state.safe_load" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">safe_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" title="&lt;code&gt;pathlib.Path&lt;/code&gt;" href="https://docs.python.org/3/library/pathlib.html#pathlib.Path">Path</a></span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Loads a .safetensors file, returning the <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">safe_load</span><span class="p">(</span><span class="s2">&quot;test.safetensor&quot;</span><span class="p">)</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">50</span>
<span class="normal">51</span>
<span class="normal">52</span>
<span class="normal">53</span>
<span class="normal">54</span>
<span class="normal">55</span>
<span class="normal">56</span>
<span class="normal">57</span>
<span class="normal">58</span>
<span class="normal">59</span>
<span class="normal">60</span>
<span class="normal">61</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">safe_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span><span class="n">Tensor</span><span class="o">|</span><span class="nb">str</span><span class="o">|</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Loads a .safetensor file, returning the `state_dict`.</span>
<span class="sd"> ```python</span>
<span class="sd"> state_dict = nn.state.safe_load(&quot;test.safetensor&quot;)</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">t</span><span class="p">,</span> <span class="n">data_start</span><span class="p">,</span> <span class="n">metadata</span> <span class="o">=</span> <span class="n">safe_load_metadata</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">data_start</span><span class="p">:]</span>
<span class="k">return</span> <span class="p">{</span> <span class="n">k</span><span class="p">:</span> <span class="n">data</span><span class="p">[</span><span class="n">v</span><span class="p">[</span><span class="s1">&#39;data_offsets&#39;</span><span class="p">][</span><span class="mi">0</span><span class="p">]:</span><span class="n">v</span><span class="p">[</span><span class="s1">&#39;data_offsets&#39;</span><span class="p">][</span><span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">safe_dtypes</span><span class="p">[</span><span class="n">v</span><span class="p">[</span><span class="s1">&#39;dtype&#39;</span><span class="p">]])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="p">[</span><span class="s1">&#39;shape&#39;</span><span class="p">])</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">metadata</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">k</span> <span class="o">!=</span> <span class="s2">&quot;__metadata__&quot;</span> <span class="p">}</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.safe_save" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">safe_save</span>
<a href="#tinygrad.nn.state.safe_save" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">safe_save</span><span class="p">(</span>
<span class="n">tensors</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">fn</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span>
<span class="n">metadata</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-external" title="&lt;code&gt;typing.Any&lt;/code&gt;" href="https://docs.python.org/3/library/typing.html#typing.Any">Any</a></span><span class="p">]</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Saves a <code class="language-python highlight"><span class="n">state_dict</span></code> to disk in a .safetensors file with optional metadata.</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
<span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">safe_save</span><span class="p">({</span><span class="s1">&#39;t&#39;</span><span class="p">:</span><span class="n">t</span><span class="p">},</span> <span class="s2">&quot;test.safetensor&quot;</span><span class="p">)</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">63</span>
<span class="normal">64</span>
<span class="normal">65</span>
<span class="normal">66</span>
<span class="normal">67</span>
<span class="normal">68</span>
<span class="normal">69</span>
<span class="normal">70</span>
<span class="normal">71</span>
<span class="normal">72</span>
<span class="normal">73</span>
<span class="normal">74</span>
<span class="normal">75</span>
<span class="normal">76</span>
<span class="normal">77</span>
<span class="normal">78</span>
<span class="normal">79</span>
<span class="normal">80</span>
<span class="normal">81</span>
<span class="normal">82</span>
<span class="normal">83</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">safe_save</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">fn</span><span class="p">:</span><span class="nb">str</span><span class="p">,</span> <span class="n">metadata</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span><span class="o">|</span><span class="kc">None</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Saves a `state_dict` to disk in a .safetensor file with optional metadata.</span>
<span class="sd"> ```python</span>
<span class="sd"> t = Tensor([1, 2, 3])</span>
<span class="sd"> nn.state.safe_save({&#39;t&#39;:t}, &quot;test.safetensor&quot;)</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">headers</span><span class="p">,</span> <span class="n">offset</span> <span class="o">=</span> <span class="p">{},</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">metadata</span><span class="p">:</span> <span class="n">headers</span><span class="p">[</span><span class="s1">&#39;__metadata__&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">metadata</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">tensors</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">headers</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;dtype&#39;</span><span class="p">:</span> <span class="n">inverse_safe_dtypes</span><span class="p">[</span><span class="n">v</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span> <span class="s1">&#39;shape&#39;</span><span class="p">:</span> <span class="nb">list</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="s1">&#39;data_offsets&#39;</span><span class="p">:[</span><span class="n">offset</span><span class="p">,</span> <span class="n">offset</span><span class="o">+</span><span class="n">v</span><span class="o">.</span><span class="n">nbytes</span><span class="p">()]}</span>
<span class="n">offset</span> <span class="o">+=</span> <span class="n">v</span><span class="o">.</span><span class="n">nbytes</span><span class="p">()</span>
<span class="n">j</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="n">headers</span><span class="p">,</span> <span class="n">separators</span><span class="o">=</span><span class="p">(</span><span class="s1">&#39;,&#39;</span><span class="p">,</span> <span class="s1">&#39;:&#39;</span><span class="p">))</span>
<span class="n">j</span> <span class="o">+=</span> <span class="s2">&quot;</span><span class="se">\x20</span><span class="s2">&quot;</span><span class="o">*</span><span class="p">(</span><span class="n">round_up</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">),</span><span class="mi">8</span><span class="p">)</span><span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">))</span>
<span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span><span class="o">.</span><span class="n">unlink</span><span class="p">(</span><span class="n">missing_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">8</span><span class="o">+</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)</span><span class="o">+</span><span class="n">offset</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;disk:</span><span class="si">{</span><span class="n">fn</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">8</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">dtypes</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">assign</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)])</span>
<span class="n">t</span><span class="p">[</span><span class="mi">8</span><span class="p">:</span><span class="mi">8</span><span class="o">+</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)]</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">j</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">&#39;utf-8&#39;</span><span class="p">)))</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">safe_load</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="n">v</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">tensors</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.get_state_dict" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">get_state_dict</span>
<a href="#tinygrad.nn.state.get_state_dict" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">get_state_dict</span><span class="p">(</span>
<span class="n">obj</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="n">tensor_type</span><span class="o">=</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Returns a <code class="language-python highlight"><span class="n">state_dict</span></code> of the object, with optional prefix.</p>
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="n">dict_keys</span><span class="p">([</span><span class="s1">&#39;l1.weight&#39;</span><span class="p">,</span> <span class="s1">&#39;l1.bias&#39;</span><span class="p">,</span> <span class="s1">&#39;l2.weight&#39;</span><span class="p">,</span> <span class="s1">&#39;l2.bias&#39;</span><span class="p">])</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span><span class="nb">str</span><span class="o">=</span><span class="s1">&#39;&#39;</span><span class="p">,</span> <span class="n">tensor_type</span><span class="o">=</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns a `state_dict` of the object, with optional prefix.</span>
<span class="sd"> ```python exec=&quot;true&quot; source=&quot;above&quot; session=&quot;tensor&quot; result=&quot;python&quot;</span>
<span class="sd"> class Net:</span>
<span class="sd"> def __init__(self):</span>
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
<span class="sd"> net = Net()</span>
<span class="sd"> print(nn.state.get_state_dict(net).keys())</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">):</span> <span class="k">return</span> <span class="p">{</span><span class="n">prefix</span><span class="o">.</span><span class="n">strip</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">):</span><span class="n">obj</span><span class="p">}</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">&#39;_asdict&#39;</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">_asdict</span><span class="p">(),</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span> <span class="c1"># namedtuple</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">OrderedDict</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">obj</span><span class="p">),</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">&#39;__dict__&#39;</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span>
<span class="n">state_dict</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">obj</span><span class="p">):</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">prefix</span><span class="si">}{</span><span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">obj</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">prefix</span><span class="si">}{</span><span class="nb">str</span><span class="p">(</span><span class="n">k</span><span class="p">)</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">))</span>
<span class="k">return</span> <span class="n">state_dict</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.get_parameters" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">get_parameters</span>
<a href="#tinygrad.nn.state.get_parameters" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">get_parameters</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
</code></pre></div>
<div class="doc doc-contents first">
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_parameters</span><span class="p">(</span><span class="n">net</span><span class="p">)))</span>
</code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span class="mi">4</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">get_parameters</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> ```python exec=&quot;true&quot; source=&quot;above&quot; session=&quot;tensor&quot; result=&quot;python&quot;</span>
<span class="sd"> class Net:</span>
<span class="sd"> def __init__(self):</span>
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
<span class="sd"> net = Net()</span>
<span class="sd"> print(len(nn.state.get_parameters(net)))</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.load_state_dict" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">load_state_dict</span>
<a href="#tinygrad.nn.state.load_state_dict" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">load_state_dict</span><span class="p">(</span>
<span class="n">model</span><span class="p">,</span>
<span class="n">state_dict</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">consume</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">realize</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Loads a <code class="language-python highlight"><span class="n">state_dict</span></code> into a model and returns the loaded Tensors.</p>
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
<span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">)</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span>
<span class="normal">129</span>
<span class="normal">130</span>
<span class="normal">131</span>
<span class="normal">132</span>
<span class="normal">133</span>
<span class="normal">134</span>
<span class="normal">135</span>
<span class="normal">136</span>
<span class="normal">137</span>
<span class="normal">138</span>
<span class="normal">139</span>
<span class="normal">140</span>
<span class="normal">141</span>
<span class="normal">142</span>
<span class="normal">143</span>
<span class="normal">144</span>
<span class="normal">145</span>
<span class="normal">146</span>
<span class="normal">147</span>
<span class="normal">148</span>
<span class="normal">149</span>
<span class="normal">150</span>
<span class="normal">151</span>
<span class="normal">152</span>
<span class="normal">153</span>
<span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span>
<span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span>
<span class="normal">160</span>
<span class="normal">161</span>
<span class="normal">162</span>
<span class="normal">163</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">load_state_dict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">consume</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">realize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Loads a `state_dict` into a model. Return the loaded Tensors.</span>
<span class="sd"> ```python</span>
<span class="sd"> class Net:</span>
<span class="sd"> def __init__(self):</span>
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
<span class="sd"> net = Net()</span>
<span class="sd"> state_dict = nn.state.get_state_dict(net)</span>
<span class="sd"> nn.state.load_state_dict(net, state_dict)</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">start_mem_used</span> <span class="o">=</span> <span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">with</span> <span class="n">Timing</span><span class="p">(</span><span class="s2">&quot;loaded weights in &quot;</span><span class="p">,</span>
<span class="k">lambda</span> <span class="n">et_ns</span><span class="p">:</span> <span class="sa">f</span><span class="s2">&quot;, </span><span class="si">{</span><span class="p">(</span><span class="n">B</span><span class="o">:=</span><span class="p">(</span><span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span><span class="o">-</span><span class="n">start_mem_used</span><span class="p">))</span><span class="o">/</span><span class="mf">1e9</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> GB loaded at </span><span class="si">{</span><span class="n">B</span><span class="o">/</span><span class="n">et_ns</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> GB/s&quot;</span><span class="p">,</span> <span class="n">enabled</span><span class="o">=</span><span class="n">verbose</span><span class="p">):</span>
<span class="n">model_state_dict</span> <span class="o">=</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">&gt;=</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;WARNING: unused weights in state_dict&quot;</span><span class="p">,</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="o">-</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())))</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="p">(</span><span class="n">t</span> <span class="o">:=</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">disable</span><span class="o">=</span><span class="n">CI</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">verbose</span><span class="p">)):</span>
<span class="n">t</span><span class="o">.</span><span class="n">desc</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;ram used: </span><span class="si">{</span><span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span><span class="o">/</span><span class="mf">1e9</span><span class="si">:</span><span class="s2">5.2f</span><span class="si">}</span><span class="s2"> GB, </span><span class="si">{</span><span class="n">k</span><span class="si">:</span><span class="s2">50s</span><span class="si">}</span><span class="s2">: &quot;</span>
<span class="k">if</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">state_dict</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">strict</span><span class="p">:</span>
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;WARNING: not loading </span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
<span class="k">if</span> <span class="p">{(),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)}</span> <span class="o">==</span> <span class="p">{</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">}:</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Shape mismatch in layer `</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s1">`: Expected shape </span><span class="si">{</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">, but found </span><span class="si">{</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1"> in state dict.&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shard</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">uop</span><span class="o">.</span><span class="n">axis</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
<span class="k">if</span> <span class="n">realize</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">realize</span><span class="p">()</span>
<span class="k">if</span> <span class="n">consume</span><span class="p">:</span> <span class="k">del</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
<span class="n">ret</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ret</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.tar_extract" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <code class="highlight language-python"><span class="n">tar_extract</span></code>
<a href="#tinygrad.nn.state.tar_extract" class="headerlink" title="Permanent link">¤</a></h3>
<div class="doc doc-contents first">
<div class="language-python highlight"><pre><span></span><code><span class="n">tar_extract</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span>
</code></pre></div>
<p>Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">tensors</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">tar_extract</span><span class="p">(</span><span class="n">Tensor</span><span class="p">(</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="s2">&quot;archive.tar&quot;</span><span class="p">)))</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">186</span>
<span class="normal">187</span>
<span class="normal">188</span>
<span class="normal">189</span>
<span class="normal">190</span>
<span class="normal">191</span>
<span class="normal">192</span>
<span class="normal">193</span>
<span class="normal">194</span>
<span class="normal">195</span>
<span class="normal">196</span>
<span class="normal">197</span>
<span class="normal">198</span>
<span class="normal">199</span>
<span class="normal">200</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
<span class="k">def</span><span class="w"> </span><span class="nf">tar_extract</span><span class="p">(</span><span class="n">t</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> ```python</span>
<span class="sd"> tar_extract(fn: Tensor | str | Path) -&gt; dict[str, Tensor]</span>
<span class="sd"> ```</span>
<span class="sd"> Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).</span>
<span class="sd"> ```python</span>
<span class="sd"> tensors = nn.state.tar_extract(Tensor(pathlib.Path(&quot;archive.tar&quot;)))</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">fileobj</span><span class="o">=</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">t</span><span class="p">),</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">tar</span><span class="p">:</span>
<span class="k">return</span> <span class="p">{</span><span class="n">member</span><span class="o">.</span><span class="n">name</span><span class="p">:</span><span class="n">t</span><span class="p">[</span><span class="n">member</span><span class="o">.</span><span class="n">offset_data</span><span class="p">:</span><span class="n">member</span><span class="o">.</span><span class="n">offset_data</span><span class="o">+</span><span class="n">member</span><span class="o">.</span><span class="n">size</span><span class="p">]</span> <span class="k">for</span> <span class="n">member</span> <span class="ow">in</span> <span class="n">tar</span> <span class="k">if</span> <span class="n">member</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">REGTYPE</span><span class="p">}</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.torch_load" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <code class="highlight language-python"><span class="n">torch_load</span></code>
<a href="#tinygrad.nn.state.torch_load" class="headerlink" title="Permanent link">¤</a></h3>
<div class="doc doc-contents first">
<div class="language-python highlight"><pre><span></span><code><span class="n">torch_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span>
</code></pre></div>
<p>Loads a torch .pth file, returning the <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">torch_load</span><span class="p">(</span><span class="s2">&quot;test.pth&quot;</span><span class="p">)</span>
</code></pre></div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">205</span>
<span class="normal">206</span>
<span class="normal">207</span>
<span class="normal">208</span>
<span class="normal">209</span>
<span class="normal">210</span>
<span class="normal">211</span>
<span class="normal">212</span>
<span class="normal">213</span>
<span class="normal">214</span>
<span class="normal">215</span>
<span class="normal">216</span>
<span class="normal">217</span>
<span class="normal">218</span>
<span class="normal">219</span>
<span class="normal">220</span>
<span class="normal">221</span>
<span class="normal">222</span>
<span class="normal">223</span>
<span class="normal">224</span>
<span class="normal">225</span>
<span class="normal">226</span>
<span class="normal">227</span>
<span class="normal">228</span>
<span class="normal">229</span>
<span class="normal">230</span>
<span class="normal">231</span>
<span class="normal">232</span>
<span class="normal">233</span>
<span class="normal">234</span>
<span class="normal">235</span>
<span class="normal">236</span>
<span class="normal">237</span>
<span class="normal">238</span>
<span class="normal">239</span>
<span class="normal">240</span>
<span class="normal">241</span>
<span class="normal">242</span>
<span class="normal">243</span>
<span class="normal">244</span>
<span class="normal">245</span>
<span class="normal">246</span>
<span class="normal">247</span>
<span class="normal">248</span>
<span class="normal">249</span>
<span class="normal">250</span>
<span class="normal">251</span>
<span class="normal">252</span>
<span class="normal">253</span>
<span class="normal">254</span>
<span class="normal">255</span>
<span class="normal">256</span>
<span class="normal">257</span>
<span class="normal">258</span>
<span class="normal">259</span>
<span class="normal">260</span>
<span class="normal">261</span>
<span class="normal">262</span>
<span class="normal">263</span>
<span class="normal">264</span>
<span class="normal">265</span>
<span class="normal">266</span>
<span class="normal">267</span>
<span class="normal">268</span>
<span class="normal">269</span>
<span class="normal">270</span>
<span class="normal">271</span>
<span class="normal">272</span>
<span class="normal">273</span>
<span class="normal">274</span>
<span class="normal">275</span>
<span class="normal">276</span>
<span class="normal">277</span>
<span class="normal">278</span>
<span class="normal">279</span>
<span class="normal">280</span>
<span class="normal">281</span>
<span class="normal">282</span>
<span class="normal">283</span>
<span class="normal">284</span>
<span class="normal">285</span>
<span class="normal">286</span>
<span class="normal">287</span>
<span class="normal">288</span>
<span class="normal">289</span>
<span class="normal">290</span>
<span class="normal">291</span>
<span class="normal">292</span>
<span class="normal">293</span>
<span class="normal">294</span>
<span class="normal">295</span>
<span class="normal">296</span>
<span class="normal">297</span>
<span class="normal">298</span>
<span class="normal">299</span>
<span class="normal">300</span>
<span class="normal">301</span>
<span class="normal">302</span>
<span class="normal">303</span>
<span class="normal">304</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
<span class="k">def</span><span class="w"> </span><span class="nf">torch_load</span><span class="p">(</span><span class="n">t</span><span class="p">:</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> ```python</span>
<span class="sd"> torch_load(fn: Tensor | str | Path) -&gt; dict[str, Tensor]</span>
<span class="sd"> ```</span>
<span class="sd"> Loads a torch .pth file, returning the `state_dict`.</span>
<span class="sd"> ```python</span>
<span class="sd"> state_dict = nn.state.torch_load(&quot;test.pth&quot;)</span>
<span class="sd"> ```</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">offsets</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="o">|</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">lens</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="o">|</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_rebuild_tensor</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">):</span>
<span class="k">return</span> <span class="n">_rebuild_tensor_v2</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_rebuild_tensor_v2</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">backward_hooks</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metadata</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="c1">#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)</span>
<span class="n">lens</span><span class="p">[</span><span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">storage</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">*</span> <span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span>
<span class="k">if</span> <span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">offsets</span><span class="p">:</span> <span class="k">return</span> <span class="kc">None</span>
<span class="n">byte_offset</span> <span class="o">=</span> <span class="n">offsets</span><span class="p">[</span><span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span><span class="o">+</span><span class="n">storage_offset</span><span class="o">*</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">byte_offset</span><span class="p">:</span><span class="n">byte_offset</span><span class="o">+</span><span class="n">prod</span><span class="p">(</span><span class="n">size</span><span class="p">)</span><span class="o">*</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="c1"># 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk</span>
<span class="n">shape_strides</span> <span class="o">=</span> <span class="p">[(</span><span class="n">s</span><span class="p">,</span> <span class="n">st</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span><span class="p">,</span><span class="n">st</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span> <span class="k">if</span> <span class="n">s</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">permute_indexes</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">shape_strides</span><span class="p">)</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">y</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">([</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">shape_strides</span><span class="p">])]</span>
<span class="k">if</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">))):</span>
<span class="n">intermediate_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">shape_strides</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)])</span>
<span class="k">assert</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">shape_strides</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)])</span> <span class="o">==</span> <span class="n">strides_for_shape</span><span class="p">(</span><span class="n">intermediate_shape</span><span class="p">),</span> <span class="s2">&quot;nonpermutable strides&quot;</span>
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">&gt;=</span> <span class="mi">3</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;WARNING: this torch load is slow. to permute </span><span class="si">{</span><span class="n">intermediate_shape</span><span class="si">}</span><span class="s2"> with </span><span class="si">{</span><span class="n">permute_indexes</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">,</span> <span class="s2">&quot;can&#39;t permute BF16&quot;</span>
<span class="c1"># TODO: find a nice way to support all movement ops on disktensors</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">intermediate_shape</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ret</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">Parameter</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">tensor</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">deserialized_objects</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">intercept</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;HalfStorage&quot;</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s2">&quot;FloatStorage&quot;</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="s2">&quot;BFloat16Storage&quot;</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">,</span>
<span class="s2">&quot;IntStorage&quot;</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="s2">&quot;BoolStorage&quot;</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="s2">&quot;LongStorage&quot;</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="s2">&quot;_rebuild_tensor&quot;</span><span class="p">:</span> <span class="n">_rebuild_tensor</span><span class="p">,</span> <span class="s2">&quot;_rebuild_tensor_v2&quot;</span><span class="p">:</span> <span class="n">_rebuild_tensor_v2</span><span class="p">,</span>
<span class="s2">&quot;FloatTensor&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;Parameter&quot;</span><span class="p">:</span> <span class="n">Parameter</span><span class="p">}</span>
<span class="n">whitelist</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;torch&quot;</span><span class="p">,</span> <span class="s2">&quot;collections&quot;</span><span class="p">,</span> <span class="s2">&quot;numpy&quot;</span><span class="p">,</span> <span class="s2">&quot;_codecs&quot;</span><span class="p">}</span> <span class="c1"># NOTE: this is not for security, only speed</span>
<span class="k">class</span><span class="w"> </span><span class="nc">Dummy</span><span class="p">:</span> <span class="k">pass</span>
<span class="k">class</span><span class="w"> </span><span class="nc">TorchPickle</span><span class="p">(</span><span class="n">pickle</span><span class="o">.</span><span class="n">Unpickler</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">find_class</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="n">module_root</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">module_root</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">whitelist</span><span class="p">:</span>
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">&gt;=</span> <span class="mi">2</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;WARNING: returning Dummy for </span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">Dummy</span>
<span class="k">return</span> <span class="n">intercept</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="k">if</span> <span class="n">module_root</span> <span class="o">==</span> <span class="s2">&quot;torch&quot;</span> <span class="k">else</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">find_class</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">persistent_load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pid</span><span class="p">):</span> <span class="k">return</span> <span class="n">deserialized_objects</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">pid</span><span class="p">,</span> <span class="n">pid</span><span class="p">)</span>
<span class="n">fobj</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">t</span><span class="p">))</span>
<span class="k">def</span><span class="w"> </span><span class="nf">passthrough_reset</span><span class="p">(</span><span class="n">v</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="k">return</span> <span class="n">fobj</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">or</span> <span class="n">v</span>
<span class="k">if</span> <span class="n">passthrough_reset</span><span class="p">(</span><span class="n">zipfile</span><span class="o">.</span><span class="n">is_zipfile</span><span class="p">(</span><span class="n">fobj</span><span class="p">)):</span> <span class="c1"># NOTE: passthrough_reset required to support python &lt; 3.14</span>
<span class="n">myzip</span> <span class="o">=</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">fobj</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span>
<span class="n">base_name</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">header_offsets</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">zi</span> <span class="ow">in</span> <span class="n">myzip</span><span class="o">.</span><span class="n">filelist</span><span class="p">:</span>
<span class="k">if</span> <span class="n">base_name</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">base_name</span> <span class="o">=</span> <span class="n">zi</span><span class="o">.</span><span class="n">filename</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;/&#39;</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">zi</span><span class="o">.</span><span class="n">filename</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">base_name</span><span class="si">}</span><span class="s1">/data/&#39;</span><span class="p">):</span> <span class="n">header_offsets</span><span class="p">[</span><span class="n">zi</span><span class="o">.</span><span class="n">filename</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">zi</span><span class="o">.</span><span class="n">header_offset</span>
<span class="c1"># sadly there&#39;s no way to get the start of the file in the zip without reading the header</span>
<span class="c1"># at least here we read them in parallel</span>
<span class="n">header_contents</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="p">[</span><span class="n">v</span><span class="o">+</span><span class="mi">26</span><span class="p">:</span><span class="n">v</span><span class="o">+</span><span class="mi">30</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">dtypes</span><span class="o">.</span><span class="n">uint16</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;CPU&#39;</span><span class="p">)</span> <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">header_offsets</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span>
<span class="n">Tensor</span><span class="o">.</span><span class="n">realize</span><span class="p">(</span><span class="o">*</span><span class="n">header_contents</span><span class="p">)</span>
<span class="k">for</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span><span class="n">o</span><span class="p">),</span><span class="n">c</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">header_offsets</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">header_contents</span><span class="p">):</span>
<span class="c1"># header_offset + sizeFileHeader + File name length + Extra field length : https://en.wikipedia.org/wiki/ZIP_(file_format)</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">o</span><span class="o">+</span><span class="mi">30</span><span class="o">+</span><span class="nb">sum</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="nb">list</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">c</span><span class="o">.</span><span class="n">tolist</span><span class="p">()))</span>
<span class="k">with</span> <span class="n">myzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">base_name</span><span class="si">}</span><span class="s1">/data.pkl&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">myfile</span><span class="p">:</span>
<span class="k">return</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">myfile</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()</span>
<span class="k">elif</span> <span class="n">passthrough_reset</span><span class="p">(</span><span class="n">tarfile</span><span class="o">.</span><span class="n">is_tarfile</span><span class="p">(</span><span class="n">fobj</span><span class="p">)):</span> <span class="c1"># NOTE: passthrough_reset required to support python &lt; 3.11</span>
<span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">fileobj</span><span class="o">=</span><span class="n">fobj</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">tar</span><span class="p">:</span>
<span class="n">storages_offset</span> <span class="o">=</span> <span class="n">tar</span><span class="o">.</span><span class="n">getmember</span><span class="p">(</span><span class="s1">&#39;storages&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">offset_data</span>
<span class="n">f</span> <span class="o">=</span> <span class="n">unwrap</span><span class="p">(</span><span class="n">tar</span><span class="o">.</span><span class="n">extractfile</span><span class="p">(</span><span class="s1">&#39;storages&#39;</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()):</span> <span class="c1"># num_storages</span>
<span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">storage_type</span><span class="p">),</span> <span class="n">sz</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">&#39;&lt;q&#39;</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">storages_offset</span> <span class="o">+</span> <span class="n">f</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
<span class="n">f</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">sz</span><span class="o">*</span><span class="n">storage_type</span><span class="o">.</span><span class="n">itemsize</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">f</span> <span class="o">=</span> <span class="n">unwrap</span><span class="p">(</span><span class="n">tar</span><span class="o">.</span><span class="n">extractfile</span><span class="p">(</span><span class="s1">&#39;tensors&#39;</span><span class="p">))</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()):</span> <span class="c1"># num_tensors</span>
<span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">storage_id</span><span class="p">,</span> <span class="n">_</span><span class="p">),</span> <span class="n">ndim</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">&#39;&lt;i&#39;</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">))[</span><span class="mi">0</span><span class="p">],</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
<span class="n">size</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;&lt;</span><span class="si">{</span><span class="n">ndim</span><span class="si">}</span><span class="s1">q&#39;</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">)),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;&lt;</span><span class="si">{</span><span class="n">ndim</span><span class="si">}</span><span class="s1">q&#39;</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">))</span>
<span class="n">storage_offset</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">&#39;&lt;q&#39;</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">deserialized_objects</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">key</span><span class="p">)]</span> <span class="o">=</span> <span class="n">_rebuild_tensor_v2</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="n">storage_type</span><span class="p">,</span> <span class="n">storage_id</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
<span class="k">return</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">v</span><span class="o">.</span><span class="n">tensor</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Parameter</span><span class="p">)</span> <span class="k">else</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">unwrap</span><span class="p">(</span><span class="n">tar</span><span class="o">.</span><span class="n">extractfile</span><span class="p">(</span><span class="s1">&#39;pickle&#39;</span><span class="p">)))</span><span class="o">.</span><span class="n">load</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">pkl</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">fobj</span><span class="p">)</span>
<span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">rwd</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">ids</span><span class="p">,</span> <span class="n">base_offset</span> <span class="o">=</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">fobj</span><span class="o">.</span><span class="n">tell</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">fobj</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ids</span><span class="p">:</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">base_offset</span> <span class="o">+</span> <span class="mi">8</span>
<span class="n">base_offset</span> <span class="o">+=</span> <span class="mi">8</span> <span class="o">+</span> <span class="n">lens</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">fobj</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">rwd</span><span class="p">)</span>
<span class="k">return</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">fobj</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
<div class="doc doc-object doc-function">
<h3 id="tinygrad.nn.state.gguf_load" class="doc doc-heading">
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">gguf_load</span>
<a href="#tinygrad.nn.state.gguf_load" class="headerlink" title="Permanent link">¤</a></h3>
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">gguf_load</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]]</span>
</code></pre></div>
<div class="doc doc-contents first">
<p>Loads a .gguf file, returning the <code class="language-python highlight"><span class="n">kv_data</span></code> and <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
<div class="language-python highlight"><pre><span></span><code><span class="n">gguf_tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="s2">&quot;Meta-Llama-3-8B-Instruct.Q4_0.gguf&quot;</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">Device</span><span class="o">.</span><span class="n">DEFAULT</span><span class="p">)</span>
<span class="n">kv_data</span><span class="p">,</span> <span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">gguf_load</span><span class="p">(</span><span class="n">gguf_tensor</span><span class="p">)</span>
</code></pre></div>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>The provided tensor must be on a device that supports execution.</p>
</div>
<details class="mkdocstrings-source">
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">357</span>
<span class="normal">358</span>
<span class="normal">359</span>
<span class="normal">360</span>
<span class="normal">361</span>
<span class="normal">362</span>
<span class="normal">363</span>
<span class="normal">364</span>
<span class="normal">365</span>
<span class="normal">366</span>
<span class="normal">367</span>
<span class="normal">368</span>
<span class="normal">369</span>
<span class="normal">370</span>
<span class="normal">371</span>
<span class="normal">372</span>
<span class="normal">373</span>
<span class="normal">374</span>
<span class="normal">375</span>
<span class="normal">376</span>
<span class="normal">377</span>
<span class="normal">378</span>
<span class="normal">379</span>
<span class="normal">380</span>
<span class="normal">381</span>
<span class="normal">382</span>
<span class="normal">383</span>
<span class="normal">384</span>
<span class="normal">385</span>
<span class="normal">386</span>
<span class="normal">387</span>
<span class="normal">388</span>
<span class="normal">389</span>
<span class="normal">390</span>
<span class="normal">391</span>
<span class="normal">392</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
<span class="k">def</span><span class="w"> </span><span class="nf">gguf_load</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Loads a .gguf file, returning the `kv_data` and `state_dict`.</span>
<span class="sd"> ```python</span>
<span class="sd"> gguf_tensor = Tensor(pathlib.Path(&quot;Meta-Llama-3-8B-Instruct.Q4_0.gguf&quot;)).to(Device.DEFAULT)</span>
<span class="sd"> kv_data, state_dict = nn.state.gguf_load(gguf_tensor)</span>
<span class="sd"> ```</span>
<span class="sd"> NOTE: The provided tensor must be on a device that supports execution.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">reader</span><span class="p">,</span> <span class="n">kv_data</span><span class="p">,</span> <span class="n">state_dict</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">tensor</span><span class="p">),</span> <span class="mi">1_000_000</span><span class="p">),</span> <span class="p">{},</span> <span class="p">{}</span>
<span class="k">def</span><span class="w"> </span><span class="nf">read_unpack</span><span class="p">(</span><span class="n">fmt</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> <span class="k">return</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="n">fmt</span><span class="p">,</span> <span class="n">reader</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">n</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">def</span><span class="w"> </span><span class="nf">read_str</span><span class="p">():</span> <span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="n">reader</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">read_uint64</span><span class="p">()),</span> <span class="s2">&quot;utf-8&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">read_arr</span><span class="p">():</span>
<span class="n">reader</span><span class="p">,</span> <span class="n">n</span> <span class="o">=</span> <span class="n">readers</span><span class="p">[</span><span class="n">read_int32</span><span class="p">()],</span> <span class="n">read_uint64</span><span class="p">()</span>
<span class="k">return</span> <span class="p">[</span> <span class="n">reader</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="p">]</span>
<span class="n">readers</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[[],</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="p">{</span> <span class="mi">8</span><span class="p">:</span> <span class="n">read_str</span><span class="p">,</span> <span class="mi">9</span><span class="p">:</span> <span class="n">read_arr</span><span class="p">,</span> <span class="o">**</span><span class="p">{</span> <span class="n">t</span><span class="p">:</span> <span class="n">functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="n">read_unpack</span><span class="p">,</span> <span class="s2">&quot;&lt;&quot;</span><span class="o">+</span><span class="n">f</span><span class="p">,</span> <span class="n">nb</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span><span class="p">,</span><span class="n">f</span><span class="p">,</span><span class="n">nb</span> <span class="ow">in</span> \
<span class="p">[</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="s2">&quot;c&quot;</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="s2">&quot;b&quot;</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="s2">&quot;H&quot;</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="s2">&quot;h&quot;</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="s2">&quot;I&quot;</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span><span class="s2">&quot;i&quot;</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">6</span><span class="p">,</span><span class="s2">&quot;f&quot;</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">7</span><span class="p">,</span><span class="s2">&quot;?&quot;</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">10</span><span class="p">,</span><span class="s2">&quot;Q&quot;</span><span class="p">,</span><span class="mi">8</span><span class="p">),</span> <span class="p">(</span><span class="mi">11</span><span class="p">,</span><span class="s2">&quot;q&quot;</span><span class="p">,</span><span class="mi">8</span><span class="p">),</span> <span class="p">(</span><span class="mi">12</span><span class="p">,</span><span class="s2">&quot;d&quot;</span><span 
class="p">,</span><span class="mi">8</span><span class="p">)</span> <span class="p">]</span> <span class="p">}</span> <span class="p">}</span>
<span class="n">read_uint32</span><span class="p">,</span> <span class="n">read_int32</span><span class="p">,</span> <span class="n">read_uint64</span><span class="p">,</span> <span class="n">read_int64</span> <span class="o">=</span> <span class="n">readers</span><span class="p">[</span><span class="mi">4</span><span class="p">],</span> <span class="n">readers</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="n">readers</span><span class="p">[</span><span class="mi">10</span><span class="p">],</span> <span class="n">readers</span><span class="p">[</span><span class="mi">11</span><span class="p">]</span>
<span class="n">magic</span><span class="p">,</span> <span class="n">version</span><span class="p">,</span> <span class="n">n_tensors</span><span class="p">,</span> <span class="n">n_kv</span> <span class="o">=</span> <span class="n">reader</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">),</span> <span class="n">read_int32</span><span class="p">(),</span> <span class="n">read_int64</span><span class="p">(),</span> <span class="n">read_int64</span><span class="p">()</span>
<span class="k">if</span> <span class="n">magic</span> <span class="o">!=</span> <span class="sa">b</span><span class="s2">&quot;GGUF&quot;</span> <span class="ow">or</span> <span class="n">version</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid GGUF format!&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_kv</span><span class="p">):</span>
<span class="n">k</span><span class="p">,</span> <span class="n">typ</span> <span class="o">=</span> <span class="n">read_str</span><span class="p">(),</span> <span class="n">read_int32</span><span class="p">()</span>
<span class="n">kv_data</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">readers</span><span class="p">[</span><span class="n">typ</span><span class="p">]()</span>
<span class="n">t_infos</span> <span class="o">=</span> <span class="p">[</span> <span class="p">(</span><span class="n">read_str</span><span class="p">(),</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">read_uint64</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">read_uint32</span><span class="p">())),</span> <span class="n">read_int32</span><span class="p">(),</span> <span class="n">read_uint64</span><span class="p">())</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_tensors</span><span class="p">)</span> <span class="p">]</span>
<span class="n">alignment</span><span class="p">,</span> <span class="n">pos</span> <span class="o">=</span> <span class="n">kv_data</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;general.alignment&quot;</span><span class="p">,</span> <span class="mi">32</span><span class="p">),</span> <span class="n">reader</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
<span class="n">data_start</span> <span class="o">=</span> <span class="n">round_up</span><span class="p">(</span><span class="n">pos</span><span class="p">,</span> <span class="n">alignment</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">typ</span><span class="p">,</span> <span class="n">off</span> <span class="ow">in</span> <span class="n">t_infos</span><span class="p">:</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">ggml_data_to_tensor</span><span class="p">(</span><span class="n">tensor</span><span class="p">[</span><span class="n">data_start</span> <span class="o">+</span> <span class="n">off</span><span class="p">:],</span> <span class="n">prod</span><span class="p">(</span><span class="n">dims</span><span class="p">),</span> <span class="n">typ</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="nb">reversed</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
<span class="k">return</span> <span class="n">kv_data</span><span class="p">,</span> <span class="n">state_dict</span>
</code></pre></div></td></tr></table></div>
</details>
</div>
</div>
</article>
</div>
<script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
</div>
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
Back to top
</button>
</main>
<footer class="md-footer">
<nav class="md-footer__inner md-grid" aria-label="Footer" >
<a href="../dtypes/" class="md-footer__link md-footer__link--prev" aria-label="Previous: dtypes">
<div class="md-footer__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
</div>
<div class="md-footer__title">
<span class="md-footer__direction">
Previous
</span>
<div class="md-ellipsis">
dtypes
</div>
</div>
</a>
<a href="../env_vars/" class="md-footer__link md-footer__link--next" aria-label="Next: Environment Variables">
<div class="md-footer__title">
<span class="md-footer__direction">
Next
</span>
<div class="md-ellipsis">
Environment Variables
</div>
</div>
<div class="md-footer__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
</div>
</a>
</nav>
<div class="md-footer-meta md-typeset">
<div class="md-footer-meta__inner md-grid">
<div class="md-copyright">
Made with
<a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
Material for MkDocs
</a>
</div>
</div>
</div>
</footer>
</div>
<div class="md-dialog" data-md-component="dialog">
<div class="md-dialog__inner md-typeset"></div>
</div>
<script id="__config" type="application/json">{"annotate": null, "base": "..", "features": ["announce.dismiss", "content.action.edit", "content.action.view", "content.code.annotate", "content.code.copy", "content.tooltips", "navigation.footer", "navigation.indexes", "navigation.sections", "navigation.expand", "navigation.top", "navigation.path", "search.highlight", "search.suggest", "toc.follow", "toc.integrate"], "search": "../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}, "version": null}</script>
<script src="../assets/javascripts/bundle.79ae519e.min.js"></script>
<script src="../assets/_markdown_exec_pyodide.js"></script>
</body>
</html>