mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
4031 lines
258 KiB
HTML
4031 lines
258 KiB
HTML
|
||
<!doctype html>
|
||
<html lang="en" class="no-js">
|
||
<head>
|
||
|
||
<meta charset="utf-8">
|
||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||
|
||
|
||
|
||
<link rel="canonical" href="https://docs.tinygrad.org/nn/">
|
||
|
||
|
||
<link rel="prev" href="../dtypes/">
|
||
|
||
|
||
<link rel="next" href="../env_vars/">
|
||
|
||
|
||
|
||
|
||
|
||
<link rel="icon" href="../favicon.svg">
|
||
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.6">
|
||
|
||
|
||
|
||
<title>nn (Neural Networks) - tinygrad docs</title>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../assets/stylesheets/main.484c7ddc.min.css">
|
||
|
||
|
||
<link rel="stylesheet" href="../assets/stylesheets/palette.ab4e12ef.min.css">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
|
||
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../assets/_markdown_exec_pyodide.css">
|
||
|
||
<link rel="stylesheet" href="../assets/_markdown_exec_ansi.css">
|
||
|
||
<link rel="stylesheet" href="../assets/_mkdocstrings.css">
|
||
|
||
<script>
  // Site root: ".." resolved against the URL of the current page.
  __md_scope = new URL("..", location);

  // Fold a string into a number: for each character, shift the running value
  // left by 5, subtract the running value, and add the character's first
  // UTF-16 code unit. An empty string yields 0.
  __md_hash = (text) => {
    let acc = 0;
    for (const ch of text) {
      acc = (acc << 5) - acc + ch.charCodeAt(0);
    }
    return acc;
  };

  // Read and JSON-decode the entry stored under "<scope pathname>.<key>".
  // Storage defaults to localStorage and scope to __md_scope.
  __md_get = (key, storage = localStorage, scope = __md_scope) => {
    const raw = storage.getItem(scope.pathname + "." + key);
    return JSON.parse(raw);
  };

  // JSON-encode `value` and store it under "<scope pathname>.<key>".
  // Any exception raised while encoding or writing is swallowed on purpose.
  __md_set = (key, value, storage = localStorage, scope = __md_scope) => {
    try {
      storage.setItem(scope.pathname + "." + key, JSON.stringify(value));
    } catch (err) {
      // intentionally ignored
    }
  };
</script>
|
||
|
||
|
||
|
||
|
||
|
||
</head>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime">
|
||
|
||
|
||
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
|
||
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
|
||
<label class="md-overlay" for="__drawer"></label>
|
||
<div data-md-component="skip">
|
||
|
||
|
||
<a href="#neural-network-classes" class="md-skip">
|
||
Skip to content
|
||
</a>
|
||
|
||
</div>
|
||
<div data-md-component="announce">
|
||
|
||
</div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<header class="md-header md-header--shadow" data-md-component="header">
|
||
<nav class="md-header__inner md-grid" aria-label="Header">
|
||
<a href=".." title="tinygrad docs" class="md-header__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
||
|
||
<img src="../logo_tiny_dark.svg" alt="logo">
|
||
|
||
</a>
|
||
<label class="md-header__button md-icon" for="__drawer">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
|
||
</label>
|
||
<div class="md-header__title" data-md-component="header-title">
|
||
<div class="md-header__ellipsis">
|
||
<div class="md-header__topic">
|
||
<span class="md-ellipsis">
|
||
tinygrad docs
|
||
</span>
|
||
</div>
|
||
<div class="md-header__topic" data-md-component="header-topic">
|
||
<span class="md-ellipsis">
|
||
|
||
nn (Neural Networks)
|
||
|
||
</span>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
<form class="md-header__option" data-md-component="palette">
|
||
|
||
|
||
|
||
|
||
<input class="md-option" data-md-color-media="(prefers-color-scheme)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to light mode" type="radio" name="__palette" id="__palette_0">
|
||
|
||
<label class="md-header__button md-icon" title="Switch to light mode" for="__palette_1" hidden>
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m14.3 16-.7-2h-3.2l-.7 2H7.8L11 7h2l3.2 9zM20 8.69V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12zm-9.15 3.96h2.3L12 9z"/></svg>
|
||
</label>
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-option" data-md-color-media="(prefers-color-scheme: light)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to dark mode" type="radio" name="__palette" id="__palette_1">
|
||
|
||
<label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_2" hidden>
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
||
</label>
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-option" data-md-color-media="(prefers-color-scheme: dark)" data-md-color-scheme="slate" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to system preference" type="radio" name="__palette" id="__palette_2">
|
||
|
||
<label class="md-header__button md-icon" title="Switch to system preference" for="__palette_0" hidden>
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
||
</label>
|
||
|
||
|
||
</form>
|
||
|
||
|
||
|
||
<script>
  // Re-apply the colour palette stored under "__palette" by copying its
  // entries onto <body> as data-md-color-* attributes.
  var palette = __md_get("__palette");

  if (palette && palette.color) {
    // A stored media of "(prefers-color-scheme)" is replaced by the values of
    // the palette input whose media query matches the current light/dark
    // preference.
    if (palette.color.media === "(prefers-color-scheme)") {
      var media = matchMedia("(prefers-color-scheme: light)");
      var input = document.querySelector(
        media.matches
          ? "[data-md-color-media='(prefers-color-scheme: light)']"
          : "[data-md-color-media='(prefers-color-scheme: dark)']"
      );

      palette.color.media = input.getAttribute("data-md-color-media");
      palette.color.scheme = input.getAttribute("data-md-color-scheme");
      palette.color.primary = input.getAttribute("data-md-color-primary");
      palette.color.accent = input.getAttribute("data-md-color-accent");
    }

    // Mirror every palette entry onto <body>.
    for (var [key, value] of Object.entries(palette.color)) {
      document.body.setAttribute("data-md-color-" + key, value);
    }
  }
</script>
|
||
|
||
|
||
|
||
|
||
|
||
<label class="md-header__button md-icon" for="__search">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
||
</label>
|
||
<div class="md-search" data-md-component="search" role="dialog">
|
||
<label class="md-search__overlay" for="__search"></label>
|
||
<div class="md-search__inner" role="search">
|
||
<form class="md-search__form" name="search">
|
||
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
|
||
<label class="md-search__icon md-icon" for="__search">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
|
||
</label>
|
||
<nav class="md-search__options" aria-label="Search">
|
||
|
||
<button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
|
||
</button>
|
||
</nav>
|
||
|
||
<div class="md-search__suggest" data-md-component="search-suggest"></div>
|
||
|
||
</form>
|
||
<div class="md-search__output">
|
||
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
|
||
<div class="md-search-result" data-md-component="search-result">
|
||
<div class="md-search-result__meta">
|
||
Initializing search
|
||
</div>
|
||
<ol class="md-search-result__list" role="presentation"></ol>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="md-header__source">
|
||
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
||
<div class="md-source__icon md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
||
</div>
|
||
<div class="md-source__repository">
|
||
GitHub
|
||
</div>
|
||
</a>
|
||
</div>
|
||
|
||
</nav>
|
||
|
||
</header>
|
||
|
||
<div class="md-container" data-md-component="container">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<main class="md-main" data-md-component="main">
|
||
<div class="md-main__inner md-grid">
|
||
|
||
|
||
|
||
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
|
||
<div class="md-sidebar__scrollwrap">
|
||
<div class="md-sidebar__inner">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<nav class="md-nav md-nav--primary md-nav--integrated" aria-label="Navigation" data-md-level="0">
|
||
<label class="md-nav__title" for="__drawer">
|
||
<a href=".." title="tinygrad docs" class="md-nav__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
||
|
||
<img src="../logo_tiny_dark.svg" alt="logo">
|
||
|
||
</a>
|
||
tinygrad docs
|
||
</label>
|
||
|
||
<div class="md-nav__source">
|
||
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
||
<div class="md-source__icon md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
||
</div>
|
||
<div class="md-source__repository">
|
||
GitHub
|
||
</div>
|
||
</a>
|
||
</div>
|
||
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1" checked>
|
||
|
||
|
||
<div class="md-nav__link md-nav__container">
|
||
<a href=".." class="md-nav__link ">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Home
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
<!-- Icon-only label for the "__nav_1" checkbox; its id is the target of the
     aria-labelledby on the nested nav that follows. It carries no tabindex:
     the previous empty value (tabindex="") is not a valid integer. -->
<label class="md-nav__link " for="__nav_1" id="__nav_1_label">
  <span class="md-nav__icon md-icon"></span>
</label>
|
||
|
||
</div>
|
||
|
||
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_1_label" aria-expanded="true">
|
||
<label class="md-nav__title" for="__nav_1">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Home
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../quickstart/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Quickstart
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../showcase/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Showcase
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../mnist/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
MNIST Tutorial
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--active md-nav__item--nested">
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1_5" checked>
|
||
|
||
|
||
<label class="md-nav__link" for="__nav_1_5" id="__nav_1_5_label" tabindex="0">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
API Reference
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_5_label" aria-expanded="true">
|
||
<label class="md-nav__title" for="__nav_1_5">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
API Reference
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--nested">
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5_1" >
|
||
|
||
|
||
<div class="md-nav__link md-nav__container">
|
||
<a href="../tensor/" class="md-nav__link ">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Tensor
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
<label class="md-nav__link " for="__nav_1_5_1" id="__nav_1_5_1_label" tabindex="0">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
</div>
|
||
|
||
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_5_1_label" aria-expanded="false">
|
||
<label class="md-nav__title" for="__nav_1_5_1">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Tensor
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/properties/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Properties
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/creation/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Creation
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/movement/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Movement
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/elementwise/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Elementwise
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/ops/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Complex Ops
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../dtypes/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
dtypes
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--active">
|
||
|
||
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
|
||
|
||
|
||
|
||
<label class="md-nav__link md-nav__link--active" for="__toc">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
nn (Neural Networks)
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<a href="./" class="md-nav__link md-nav__link--active">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
nn (Neural Networks)
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
|
||
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
|
||
|
||
|
||
|
||
|
||
<label class="md-nav__title" for="__toc">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
Table of contents
|
||
</label>
|
||
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#neural-network-classes" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
Neural Network classes
|
||
|
||
</span>
|
||
</a>
|
||
|
||
<nav class="md-nav" aria-label="Neural Network classes">
|
||
<ul class="md-nav__list">
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.BatchNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> BatchNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Conv1d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> Conv1d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Conv2d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> Conv2d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.ConvTranspose1d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> ConvTranspose1d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.ConvTranspose2d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> ConvTranspose2d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Linear" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> Linear
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.GroupNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> GroupNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.InstanceNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> InstanceNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.LayerNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LayerNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.LayerNorm2d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LayerNorm2d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.RMSNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> RMSNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Embedding" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> Embedding
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.LSTMCell" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LSTMCell
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#optimizers" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
Optimizers
|
||
|
||
</span>
|
||
</a>
|
||
|
||
<nav class="md-nav" aria-label="Optimizers">
|
||
<ul class="md-nav__list">
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.SGD" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> SGD
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.LARS" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LARS
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.AdamW" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> AdamW
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.Adam" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> Adam
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.LAMB" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LAMB
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#loadsave" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
Load/Save
|
||
|
||
</span>
|
||
</a>
|
||
|
||
<nav class="md-nav" aria-label="Load/Save">
|
||
<ul class="md-nav__list">
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.safe_load" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> safe_load
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.safe_save" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> safe_save
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.get_state_dict" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> get_state_dict
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.get_parameters" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> get_parameters
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.load_state_dict" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> load_state_dict
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.tar_extract" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> tar_extract
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.torch_load" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> torch_load
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.llm.gguf.gguf_load" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> gguf_load
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../env_vars/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Environment Variables
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--nested">
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5_5" >
|
||
|
||
|
||
<label class="md-nav__link" for="__nav_1_5_5" id="__nav_1_5_5_label" tabindex="0">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Runtime
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_5_5_label" aria-expanded="false">
|
||
<label class="md-nav__title" for="__nav_1_5_5">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Runtime
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../runtime/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Runtimes
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tinygpu/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
egpu for mac
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--nested">
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6" >
|
||
|
||
|
||
<label class="md-nav__link" for="__nav_1_6" id="__nav_1_6_label" tabindex="0">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Developer
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_6_label" aria-expanded="false">
|
||
<label class="md-nav__title" for="__nav_1_6">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Developer
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/developer/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Intro
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/layout/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Layout
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/speed/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Speed
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/uop/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
UOp
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--nested">
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6_5" >
|
||
|
||
|
||
<label class="md-nav__link" for="__nav_1_6_5" id="__nav_1_6_5_label" tabindex="0">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Runtime
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_6_5_label" aria-expanded="false">
|
||
<label class="md-nav__title" for="__nav_1_6_5">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Runtime
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/runtime/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Runtime Overview
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/hcq/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
HCQ
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/am/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
AM Driver
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tinybox/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
tinybox
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
|
||
<div class="md-content" data-md-component="content">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<!-- Breadcrumb trail for this page. It has its own landmark name so it is not
     confused with the primary sidebar nav, which is labelled "Navigation". -->
<nav class="md-path" aria-label="Breadcrumb">
  <ol class="md-path__list">
    <li class="md-path__item">
      <a href=".." class="md-path__link">
        <span class="md-ellipsis">
          Home
        </span>
      </a>
    </li>
    <li class="md-path__item">
      <a href="../tensor/" class="md-path__link">
        <span class="md-ellipsis">
          API Reference
        </span>
      </a>
    </li>
  </ol>
</nav>
|
||
|
||
|
||
<article class="md-content__inner md-typeset">
|
||
|
||
|
||
|
||
|
||
|
||
<a href="https://github.com/tinygrad/tinygrad/edit/master/docs/nn.md" title="Edit this page" class="md-content__button md-icon" rel="edit">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M10 20H6V4h7v5h5v3.1l2-2V8l-6-6H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h4zm10.2-7c.1 0 .3.1.4.2l1.3 1.3c.2.2.2.6 0 .8l-1 1-2.1-2.1 1-1c.1-.1.2-.2.4-.2m0 3.9L14.1 23H12v-2.1l6.1-6.1z"/></svg>
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
<a href="https://github.com/tinygrad/tinygrad/raw/master/docs/nn.md" title="View source of this page" class="md-content__button md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M17 18c.56 0 1 .44 1 1s-.44 1-1 1-1-.44-1-1 .44-1 1-1m0-3c-2.73 0-5.06 1.66-6 4 .94 2.34 3.27 4 6 4s5.06-1.66 6-4c-.94-2.34-3.27-4-6-4m0 6.5a2.5 2.5 0 0 1-2.5-2.5 2.5 2.5 0 0 1 2.5-2.5 2.5 2.5 0 0 1 2.5 2.5 2.5 2.5 0 0 1-2.5 2.5M9.27 20H6V4h7v5h5v4.07c.7.08 1.36.25 2 .49V8l-6-6H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h4.5a8.2 8.2 0 0 1-1.23-2"/></svg>
|
||
</a>
|
||
|
||
|
||
|
||
<h1>nn (Neural Networks)</h1>
|
||
|
||
<h2 id="neural-network-classes">Neural Network classes<a class="headerlink" href="#neural-network-classes" title="Permanent link">¤</a></h2>
|
||
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.BatchNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">BatchNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.BatchNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">BatchNorm</span><span class="p">(</span>
|
||
<span class="n">sz</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">track_running_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Batch Normalization over a 2D or 3D input; higher-rank inputs, such as the 4D tensor in the example below, are also accepted.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1502.03167v3">https://arxiv.org/abs/1502.03167v3</a></li>
|
||
</ul>
|
||
<p>See: <code class="language-python highlight"><span class="n">Tensor</span><span class="o">.</span><span class="n">batchnorm</span></code></p>
|
||
<p></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">0.48187804222106934</span> <span class="mf">0.31661197543144226</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">0.4818756580352783</span> <span class="mf">0.3166103959083557</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">32</span>
|
||
<span class="normal">33</span>
|
||
<span class="normal">34</span>
|
||
<span class="normal">35</span>
|
||
<span class="normal">36</span>
|
||
<span class="normal">37</span>
|
||
<span class="normal">38</span>
|
||
<span class="normal">39</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sz</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">eps</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">,</span> <span class="n">momentum</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">sz</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">sz</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_batches_tracked</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="s1">'long'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">sz</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">sz</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.Conv1d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">Conv1d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Conv1d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Conv1d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Conv2d</span> (<code>tinygrad.nn.Conv2d</code>)" href="#tinygrad.nn.Conv2d">Conv2d</a></span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Applies a 1D convolution over an input signal composed of several input planes.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv1d">https://pytorch.org/docs/stable/generated/torch.nn.Conv1d</a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv1d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.9444</span> <span class="mf">0.3987</span> <span class="mf">0.3087</span> <span class="mf">0.0496</span><span class="p">]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.2584</span> <span class="mf">0.1335</span><span class="p">]]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">62</span>
|
||
<span class="normal">63</span>
|
||
<span class="normal">64</span>
|
||
<span class="normal">65</span>
|
||
<span class="normal">66</span>
|
||
<span class="normal">67</span>
|
||
<span class="normal">68</span>
|
||
<span class="normal">69</span>
|
||
<span class="normal">70</span>
|
||
<span class="normal">71</span>
|
||
<span class="normal">72</span>
|
||
<span class="normal">73</span>
|
||
<span class="normal">74</span>
|
||
<span class="normal">75</span>
|
||
<span class="normal">76</span>
|
||
<span class="normal">77</span>
|
||
<span class="normal">78</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">Conv1d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">str</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Conv2d</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Applies a 1D convolution over an input signal composed of several input planes.</span>
|
||
|
||
<span class="sd"> See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d</span>
|
||
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> conv = nn.Conv1d(1, 1, 3)</span>
|
||
<span class="sd"> t = Tensor.rand(1, 1, 4)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> t = conv(t)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="p">(</span><span class="n">kernel_size</span><span class="p">,),</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.Conv2d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Conv2d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Conv2d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Conv2d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies a 2D convolution over an input signal composed of several input planes.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv2d">https://pytorch.org/docs/stable/generated/torch.nn.Conv2d</a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="mf">0.9785</span> <span class="mf">0.4781</span> <span class="mf">0.9548</span> <span class="mf">0.6781</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.2831</span> <span class="mf">0.524</span> <span class="mf">0.4712</span> <span class="mf">0.8209</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.7255</span> <span class="mf">0.4627</span> <span class="mf">0.025</span> <span class="mf">0.064</span> <span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.7955</span> <span class="mf">0.2812</span> <span class="mf">0.492</span> <span class="mf">0.3316</span><span class="p">]]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="o">-</span><span class="mf">0.1055</span> <span class="mf">0.1569</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.0542</span> <span class="mf">0.1154</span><span class="p">]]]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal"> 96</span>
|
||
<span class="normal"> 97</span>
|
||
<span class="normal"> 98</span>
|
||
<span class="normal"> 99</span>
|
||
<span class="normal">100</span>
|
||
<span class="normal">101</span>
|
||
<span class="normal">102</span>
|
||
<span class="normal">103</span>
|
||
<span class="normal">104</span>
|
||
<span class="normal">105</span>
|
||
<span class="normal">106</span>
|
||
<span class="normal">107</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span><span class="o">|</span><span class="nb">str</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">padding</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="o">!=</span> <span class="s1">'same'</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid padding string </span><span class="si">{</span><span class="n">padding</span><span class="si">!r}</span><span class="s2">, only 'same' is supported"</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"padding='same' is not supported for strided convolutions"</span><span class="p">)</span>
|
||
<span class="n">pad</span> <span class="o">=</span> <span class="p">[(</span><span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">//</span><span class="mi">2</span><span class="p">,</span> <span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">//</span><span class="mi">2</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span><span class="p">,</span><span class="n">k</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">make_tuple</span><span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">])]</span>
|
||
<span class="n">padding</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">flatten</span><span class="p">(</span><span class="n">pad</span><span class="p">))</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">=</span> <span class="n">stride</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">padding</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">*</span> <span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">))</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">in_channels</span><span class="o">//</span><span class="n">groups</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.ConvTranspose1d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">ConvTranspose1d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.ConvTranspose1d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">ConvTranspose1d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;ConvTranspose2d&lt;/span&gt; (&lt;code&gt;tinygrad.nn.ConvTranspose2d&lt;/code&gt;)" href="#tinygrad.nn.ConvTranspose2d">ConvTranspose2d</a></span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Applies a 1D transposed convolution operator over an input signal composed of several input planes.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d">https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d</a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose1d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.8341</span> <span class="mf">0.8731</span> <span class="mf">0.7338</span> <span class="mf">0.5785</span><span class="p">]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span> <span class="mf">0.0298</span> <span class="mf">0.1353</span> <span class="mf">0.1504</span> <span class="mf">0.1228</span> <span class="mf">0.0576</span> <span class="o">-</span><span class="mf">0.0174</span><span class="p">]]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">111</span>
|
||
<span class="normal">112</span>
|
||
<span class="normal">113</span>
|
||
<span class="normal">114</span>
|
||
<span class="normal">115</span>
|
||
<span class="normal">116</span>
|
||
<span class="normal">117</span>
|
||
<span class="normal">118</span>
|
||
<span class="normal">119</span>
|
||
<span class="normal">120</span>
|
||
<span class="normal">121</span>
|
||
<span class="normal">122</span>
|
||
<span class="normal">123</span>
|
||
<span class="normal">124</span>
|
||
<span class="normal">125</span>
|
||
<span class="normal">126</span>
|
||
<span class="normal">127</span>
|
||
<span class="normal">128</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">ConvTranspose1d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">ConvTranspose2d</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Applies a 1D transposed convolution operator over an input signal composed of several input planes.</span>
|
||
|
||
<span class="sd"> See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d</span>
|
||
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> conv = nn.ConvTranspose1d(1, 1, 3)</span>
|
||
<span class="sd"> t = Tensor.rand(1, 1, 4)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> t = conv(t)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="p">(</span><span class="n">kernel_size</span><span class="p">,),</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">output_padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.ConvTranspose2d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">ConvTranspose2d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.ConvTranspose2d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">ConvTranspose2d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Conv2d&lt;/span&gt; (&lt;code&gt;tinygrad.nn.Conv2d&lt;/code&gt;)" href="#tinygrad.nn.Conv2d">Conv2d</a></code></p>
|
||
|
||
|
||
|
||
<p>Applies a 2D transposed convolution operator over an input image.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d">https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d</a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="mf">0.5424</span> <span class="mf">0.6224</span> <span class="mf">0.2948</span> <span class="mf">0.2591</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.3831</span> <span class="mf">0.2783</span> <span class="mf">0.4932</span> <span class="mf">0.7726</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.7356</span> <span class="mf">0.0175</span> <span class="mf">0.3564</span> <span class="mf">0.9384</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.9356</span> <span class="mf">0.363</span> <span class="mf">0.3902</span> <span class="mf">0.3312</span><span class="p">]]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="o">-</span><span class="mf">0.342</span> <span class="o">-</span><span class="mf">0.2702</span> <span class="o">-</span><span class="mf">0.2942</span> <span class="o">-</span><span class="mf">0.3583</span> <span class="o">-</span><span class="mf">0.2273</span> <span class="o">-</span><span class="mf">0.2633</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.4465</span> <span class="o">-</span><span class="mf">0.3751</span> <span class="o">-</span><span class="mf">0.6162</span> <span class="o">-</span><span class="mf">0.6435</span> <span class="o">-</span><span class="mf">0.2771</span> <span class="o">-</span><span class="mf">0.4588</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.564</span> <span class="o">-</span><span class="mf">0.2991</span> <span class="o">-</span><span class="mf">0.9048</span> <span class="o">-</span><span class="mf">0.8494</span> <span class="o">-</span><span class="mf">0.3585</span> <span class="o">-</span><span class="mf">0.7028</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.6868</span> <span class="o">-</span><span class="mf">0.219</span> <span class="o">-</span><span class="mf">0.9375</span> <span class="o">-</span><span class="mf">0.7595</span> <span class="o">-</span><span class="mf">0.5135</span> <span class="o">-</span><span class="mf">0.7095</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.5523</span> <span class="o">-</span><span class="mf">0.381</span> <span class="o">-</span><span class="mf">0.7811</span> <span class="o">-</span><span class="mf">0.5736</span> <span class="o">-</span><span class="mf">0.4904</span> <span class="o">-</span><span class="mf">0.4718</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.3298</span> <span class="o">-</span><span class="mf">0.354</span> <span class="o">-</span><span class="mf">0.4562</span> <span class="o">-</span><span class="mf">0.3536</span> <span class="o">-</span><span class="mf">0.3082</span> <span class="o">-</span><span class="mf">0.2628</span><span class="p">]]]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">146</span>
|
||
<span class="normal">147</span>
|
||
<span class="normal">148</span>
|
||
<span class="normal">149</span>
|
||
<span class="normal">150</span>
|
||
<span class="normal">151</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">*</span> <span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">))</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">//</span><span class="n">groups</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">output_padding</span> <span class="o">=</span> <span class="n">output_padding</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.Linear" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Linear</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Linear" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies a linear transformation to the incoming data.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Linear">https://pytorch.org/docs/stable/generated/torch.nn.Linear</a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="mf">0.2814</span> <span class="mf">0.4195</span> <span class="mf">0.6016</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.7782</span> <span class="mf">0.4539</span> <span class="mf">0.9588</span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">lin</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="o">-</span><span class="mf">0.1343</span> <span class="mf">0.5919</span> <span class="o">-</span><span class="mf">0.1691</span> <span class="o">-</span><span class="mf">0.4668</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.0287</span> <span class="mf">0.7865</span> <span class="o">-</span><span class="mf">0.2679</span> <span class="o">-</span><span class="mf">0.8907</span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">172</span>
|
||
<span class="normal">173</span>
|
||
<span class="normal">174</span>
|
||
<span class="normal">175</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="n">bound</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_features</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.GroupNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">GroupNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.GroupNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">GroupNorm</span><span class="p">(</span>
|
||
<span class="n">num_groups</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">num_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Group Normalization over a mini-batch of inputs.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1803.08494v3">https://arxiv.org/abs/1803.08494v3</a></li>
|
||
</ul>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">12</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">2.0116562843322754</span> <span class="mf">0.5751477479934692</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">3.1152862334238307e-07</span> <span class="mf">1.0012893676757812</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">195</span>
|
||
<span class="normal">196</span>
|
||
<span class="normal">197</span>
|
||
<span class="normal">198</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_groups</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">num_groups</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">,</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_channels</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_channels</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.InstanceNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">InstanceNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.InstanceNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">InstanceNorm</span><span class="p">(</span>
|
||
<span class="n">num_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Instance Normalization over a mini-batch of inputs.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1607.08022v3">https://arxiv.org/abs/1607.08022v3</a></li>
|
||
</ul>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">2.0323781967163086</span> <span class="mf">0.5779063105583191</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.389967536624681e-07</span> <span class="mf">1.0052329301834106</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">225</span>
|
||
<span class="normal">226</span>
|
||
<span class="normal">227</span>
|
||
<span class="normal">228</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.LayerNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LayerNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.LayerNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LayerNorm</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">elementwise_affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Layer Normalization over a mini-batch of inputs.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1607.06450v1">https://arxiv.org/abs/1607.06450v1</a></li>
|
||
</ul>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.8958020210266113</span> <span class="mf">0.5616000294685364</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">1.4226594657884561e-07</span> <span class="mf">1.0170583724975586</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">251</span>
|
||
<span class="normal">252</span>
|
||
<span class="normal">253</span>
|
||
<span class="normal">254</span>
|
||
<span class="normal">255</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">))),</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.LayerNorm2d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LayerNorm2d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.LayerNorm2d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LayerNorm2d</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">elementwise_affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt; &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;LayerNorm&lt;/span&gt; (&lt;code&gt;tinygrad.nn.LayerNorm&lt;/code&gt;)" href="#tinygrad.nn.LayerNorm">LayerNorm</a></code></p>
|
||
|
||
|
||
|
||
<p>Applies Layer Normalization over a mini-batch of 2D inputs.</p>
|
||
<p>See: <code class="language-python highlight"><span class="n">LayerNorm</span></code></p>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm2d</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">2.0446934700012207</span> <span class="mf">0.5648013353347778</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">2.532809730837471e-07</span> <span class="mf">1.0050703287124634</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">251</span>
|
||
<span class="normal">252</span>
|
||
<span class="normal">253</span>
|
||
<span class="normal">254</span>
|
||
<span class="normal">255</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">))),</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.RMSNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">RMSNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.RMSNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">RMSNorm</span><span class="p">(</span><span class="n">dim</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Root Mean Square Normalization to input.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1910.07467">https://arxiv.org/abs/1910.07467</a></li>
|
||
</ul>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span> <span class="mf">0.</span> <span class="mf">1.</span> <span class="mf">2.</span> <span class="mf">3.</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">4.</span> <span class="mf">5.</span> <span class="mf">6.</span> <span class="mf">7.</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">8.</span> <span class="mf">9.</span> <span class="mf">10.</span> <span class="mf">11.</span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="mf">0.</span> <span class="mf">0.5345</span> <span class="mf">1.069</span> <span class="mf">1.6036</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.7127</span> <span class="mf">0.8909</span> <span class="mf">1.069</span> <span class="mf">1.2472</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.8363</span> <span class="mf">0.9409</span> <span class="mf">1.0454</span> <span class="mf">1.15</span> <span class="p">]]</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">296</span>
|
||
<span class="normal">297</span>
|
||
<span class="normal">298</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.Embedding" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Embedding</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Embedding" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Embedding</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">embed_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>A simple lookup table that stores embeddings of a fixed dictionary and size.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Embedding">https://pytorch.org/docs/stable/generated/torch.nn.Embedding</a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">emb</span><span class="p">(</span><span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="o">-</span><span class="mf">0.0409</span> <span class="o">-</span><span class="mf">0.6594</span> <span class="mf">0.3081</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.3218</span> <span class="o">-</span><span class="mf">0.2385</span> <span class="mf">0.3235</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.2679</span> <span class="mf">0.5131</span> <span class="o">-</span><span class="mf">0.0909</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.0409</span> <span class="o">-</span><span class="mf">0.6594</span> <span class="mf">0.3081</span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">387</span>
|
||
<span class="normal">388</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">:</span><span class="nb">int</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">glorot_uniform</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.LSTMCell" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LSTMCell</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.LSTMCell" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LSTMCell</span><span class="p">(</span>
|
||
<span class="n">input_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>A long short-term memory (LSTM) cell.</p>
|
||
|
||
|
||
<p><span class="doc-section-title">Parameters:</span></p>
|
||
<ul>
|
||
<li class="doc-section-item field-body">
|
||
<b><code>input_size</code></b>
|
||
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></code>)
|
||
–
|
||
<div class="doc-md-description">
|
||
<p>The number of expected features in the input <code class="language-python highlight"><span class="n">x</span></code></p>
|
||
</div>
|
||
</li>
|
||
<li class="doc-section-item field-body">
|
||
<b><code>hidden_size</code></b>
|
||
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></code>)
|
||
–
|
||
<div class="doc-md-description">
|
||
<p>The number of features in the hidden state <code class="language-python highlight"><span class="n">h</span></code></p>
|
||
</div>
|
||
</li>
|
||
<li class="doc-section-item field-body">
|
||
<b><code>bias</code></b>
|
||
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></code>, default:
|
||
<code>True</code>
|
||
)
|
||
–
|
||
<div class="doc-md-description">
|
||
<p>If <code class="language-python highlight"><span class="kc">False</span></code>, then the layer does not use the bias weights <code class="language-python highlight"><span class="n">bias_ih</span></code> and <code class="language-python highlight"><span class="n">bias_hh</span></code>.</p>
|
||
</div>
|
||
</li>
|
||
</ul>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">406</span>
|
||
<span class="normal">407</span>
|
||
<span class="normal">408</span>
|
||
<span class="normal">409</span>
|
||
<span class="normal">410</span>
|
||
<span class="normal">411</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="n">stdv</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight_ih</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">stdv</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">stdv</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight_hh</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">stdv</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">stdv</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias_hh</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div><h2 id="optimizers">Optimizers<a class="headerlink" href="#optimizers" title="Permanent link">¤</a></h2>
|
||
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.SGD" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">SGD</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.SGD" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">SGD</span><span class="p">(</span>
|
||
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
|
||
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
|
||
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">classic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.</p>
|
||
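<p><code class="language-python highlight"><span class="n">classic</span></code> is a boolean flag that selects the momentum update rule: <code class="language-python highlight"><span class="kc">True</span></code> uses the classic momentum update rule, while <code class="language-python highlight"><span class="kc">False</span></code> (the default) uses the popular momentum update rule.</p>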
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">78</span>
|
||
<span class="normal">79</span>
|
||
<span class="normal">80</span>
|
||
<span class="normal">81</span>
|
||
<span class="normal">82</span>
|
||
<span class="normal">83</span>
|
||
<span class="normal">84</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">SGD</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.</span>
|
||
|
||
<span class="sd"> `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="n">LARS</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">nesterov</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="n">classic</span><span class="p">,</span> <span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">tcoef</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.LARS" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LARS</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.LARS" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LARS</span><span class="p">(</span>
|
||
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
|
||
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">,</span>
|
||
<span class="n">ns_steps</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">ns_coefficients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">classic</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">tcoef</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><span title="tinygrad.nn.optim.Optimizer">Optimizer</span></code></p>
|
||
|
||
|
||
|
||
<p>Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1708.03888v3">https://arxiv.org/abs/1708.03888v3</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">105</span>
|
||
<span class="normal">106</span>
|
||
<span class="normal">107</span>
|
||
<span class="normal">108</span>
|
||
<span class="normal">109</span>
|
||
<span class="normal">110</span>
|
||
<span class="normal">111</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span><span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span> <span class="n">ns_steps</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">ns_coefficients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">tcoef</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">momentum</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid momentum value: </span><span class="si">{</span><span class="n">momentum</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">fused</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ns_steps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ns_coefficients</span> <span class="o">=</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">ns_steps</span><span class="p">,</span> <span class="n">ns_coefficients</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">nesterov</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classic</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tcoef</span> <span class="o">=</span> <span class="n">nesterov</span><span class="p">,</span> <span class="n">classic</span><span class="p">,</span> <span class="n">pre_wd</span><span class="p">,</span> <span class="n">tcoef</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="k">else</span> <span class="p">[]</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.AdamW" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">AdamW</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.AdamW" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">AdamW</span><span class="p">(</span>
|
||
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
|
||
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-08</span><span class="p">,</span>
|
||
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>AdamW optimizer with optional weight decay.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1711.05101v3">https://arxiv.org/abs/1711.05101v3</a></li>
|
||
</ul>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">136</span>
|
||
<span class="normal">137</span>
|
||
<span class="normal">138</span>
|
||
<span class="normal">139</span>
|
||
<span class="normal">140</span>
|
||
<span class="normal">141</span>
|
||
<span class="normal">142</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">AdamW</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> AdamW optimizer with optional weight decay.</span>
|
||
|
||
<span class="sd"> - Paper: https://arxiv.org/abs/1711.05101v3</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="n">LAMB</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.Adam" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">Adam</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.Adam" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Adam</span><span class="p">(</span>
|
||
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
|
||
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-08</span><span class="p">,</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Adam optimizer.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1412.6980">https://arxiv.org/abs/1412.6980</a></li>
|
||
</ul>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">143</span>
|
||
<span class="normal">144</span>
|
||
<span class="normal">145</span>
|
||
<span class="normal">146</span>
|
||
<span class="normal">147</span>
|
||
<span class="normal">148</span>
|
||
<span class="normal">149</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">Adam</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Adam optimizer.</span>
|
||
|
||
<span class="sd"> - Paper: https://arxiv.org/abs/1412.6980</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="n">LAMB</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.LAMB" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LAMB</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.LAMB" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LAMB</span><span class="p">(</span>
|
||
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
|
||
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span>
|
||
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
|
||
<span class="n">adam</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><span title="tinygrad.nn.optim.Optimizer">Optimizer</span></code></p>
|
||
|
||
|
||
|
||
<p>LAMB optimizer with optional weight decay.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1904.00962">https://arxiv.org/abs/1904.00962</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">157</span>
|
||
<span class="normal">158</span>
|
||
<span class="normal">159</span>
|
||
<span class="normal">160</span>
|
||
<span class="normal">161</span>
|
||
<span class="normal">162</span>
|
||
<span class="normal">163</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">weight_decay</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid weight_decay value: </span><span class="si">{</span><span class="n">weight_decay</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">fused</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">b1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">b2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">adam</span> <span class="o">=</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">adam</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">b1_t</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">b2_t</span> <span class="o">=</span> <span class="p">(</span><span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">])</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">m</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div><h2 id="loadsave">Load/Save<a class="headerlink" href="#loadsave" title="Permanent link">¤</a></h2>
|
||
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.safe_load" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">safe_load</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.safe_load" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">safe_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" title="<code>pathlib.Path</code>" href="https://docs.python.org/3/library/pathlib.html#pathlib.Path">Path</a></span><span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Loads a <code>.safetensors</code> file, given either as a file path or as a <code>Tensor</code>, and returns its contents as a <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">safe_load</span><span class="p">(</span><span class="s2">"test.safetensor"</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">50</span>
|
||
<span class="normal">51</span>
|
||
<span class="normal">52</span>
|
||
<span class="normal">53</span>
|
||
<span class="normal">54</span>
|
||
<span class="normal">55</span>
|
||
<span class="normal">56</span>
|
||
<span class="normal">57</span>
|
||
<span class="normal">58</span>
|
||
<span class="normal">59</span>
|
||
<span class="normal">60</span>
|
||
<span class="normal">61</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">safe_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span><span class="n">Tensor</span><span class="o">|</span><span class="nb">str</span><span class="o">|</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Loads a .safetensor file, returning the `state_dict`.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> state_dict = nn.state.safe_load("test.safetensor")</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">t</span><span class="p">,</span> <span class="n">data_start</span><span class="p">,</span> <span class="n">metadata</span> <span class="o">=</span> <span class="n">safe_load_metadata</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span>
|
||
<span class="n">data</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">data_start</span><span class="p">:]</span>
|
||
<span class="k">return</span> <span class="p">{</span> <span class="n">k</span><span class="p">:</span> <span class="n">data</span><span class="p">[</span><span class="n">v</span><span class="p">[</span><span class="s1">'data_offsets'</span><span class="p">][</span><span class="mi">0</span><span class="p">]:</span><span class="n">v</span><span class="p">[</span><span class="s1">'data_offsets'</span><span class="p">][</span><span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">safe_dtypes</span><span class="p">[</span><span class="n">v</span><span class="p">[</span><span class="s1">'dtype'</span><span class="p">]])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="p">[</span><span class="s1">'shape'</span><span class="p">])</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">metadata</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">k</span> <span class="o">!=</span> <span class="s2">"__metadata__"</span> <span class="p">}</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.safe_save" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">safe_save</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.safe_save" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">safe_save</span><span class="p">(</span>
|
||
<span class="n">tensors</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">fn</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span>
|
||
<span class="n">metadata</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-external" title="<code>typing.Any</code>" href="https://docs.python.org/3/library/typing.html#typing.Any">Any</a></span><span class="p">]</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Saves a <code class="language-python highlight"><span class="n">state_dict</span></code> to disk as a <code>.safetensors</code> file, with optional metadata.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
|
||
<span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">safe_save</span><span class="p">({</span><span class="s1">'t'</span><span class="p">:</span><span class="n">t</span><span class="p">},</span> <span class="s2">"test.safetensor"</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">63</span>
|
||
<span class="normal">64</span>
|
||
<span class="normal">65</span>
|
||
<span class="normal">66</span>
|
||
<span class="normal">67</span>
|
||
<span class="normal">68</span>
|
||
<span class="normal">69</span>
|
||
<span class="normal">70</span>
|
||
<span class="normal">71</span>
|
||
<span class="normal">72</span>
|
||
<span class="normal">73</span>
|
||
<span class="normal">74</span>
|
||
<span class="normal">75</span>
|
||
<span class="normal">76</span>
|
||
<span class="normal">77</span>
|
||
<span class="normal">78</span>
|
||
<span class="normal">79</span>
|
||
<span class="normal">80</span>
|
||
<span class="normal">81</span>
|
||
<span class="normal">82</span>
|
||
<span class="normal">83</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">safe_save</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">fn</span><span class="p">:</span><span class="nb">str</span><span class="p">,</span> <span class="n">metadata</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span><span class="o">|</span><span class="kc">None</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Saves a `state_dict` to disk in a .safetensor file with optional metadata.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> t = Tensor([1, 2, 3])</span>
|
||
<span class="sd"> nn.state.safe_save({'t':t}, "test.safetensor")</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">headers</span><span class="p">,</span> <span class="n">offset</span> <span class="o">=</span> <span class="p">{},</span> <span class="mi">0</span>
|
||
<span class="k">if</span> <span class="n">metadata</span><span class="p">:</span> <span class="n">headers</span><span class="p">[</span><span class="s1">'__metadata__'</span><span class="p">]</span> <span class="o">=</span> <span class="n">metadata</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">tensors</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="n">headers</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'dtype'</span><span class="p">:</span> <span class="n">inverse_safe_dtypes</span><span class="p">[</span><span class="n">v</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span> <span class="s1">'shape'</span><span class="p">:</span> <span class="nb">list</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="s1">'data_offsets'</span><span class="p">:[</span><span class="n">offset</span><span class="p">,</span> <span class="n">offset</span><span class="o">+</span><span class="n">v</span><span class="o">.</span><span class="n">nbytes</span><span class="p">()]}</span>
|
||
<span class="n">offset</span> <span class="o">+=</span> <span class="n">v</span><span class="o">.</span><span class="n">nbytes</span><span class="p">()</span>
|
||
<span class="n">j</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="n">headers</span><span class="p">,</span> <span class="n">separators</span><span class="o">=</span><span class="p">(</span><span class="s1">','</span><span class="p">,</span> <span class="s1">':'</span><span class="p">))</span>
|
||
<span class="n">j</span> <span class="o">+=</span> <span class="s2">"</span><span class="se">\x20</span><span class="s2">"</span><span class="o">*</span><span class="p">(</span><span class="n">round_up</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">),</span><span class="mi">8</span><span class="p">)</span><span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">))</span>
|
||
<span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span><span class="o">.</span><span class="n">unlink</span><span class="p">(</span><span class="n">missing_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">8</span><span class="o">+</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)</span><span class="o">+</span><span class="n">offset</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="sa">f</span><span class="s2">"disk:</span><span class="si">{</span><span class="n">fn</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">8</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">dtypes</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">assign</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)])</span>
|
||
<span class="n">t</span><span class="p">[</span><span class="mi">8</span><span class="p">:</span><span class="mi">8</span><span class="o">+</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)]</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">j</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">'utf-8'</span><span class="p">)))</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">safe_load</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="n">v</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">tensors</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.get_state_dict" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">get_state_dict</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.get_state_dict" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">get_state_dict</span><span class="p">(</span>
|
||
<span class="n">obj</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> <span class="n">tensor_type</span><span class="o">=</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Returns a <code class="language-python highlight"><span class="n">state_dict</span></code> of the object, with optional prefix.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||
|
||
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">dict_keys</span><span class="p">([</span><span class="s1">'l1.weight'</span><span class="p">,</span> <span class="s1">'l1.bias'</span><span class="p">,</span> <span class="s1">'l2.weight'</span><span class="p">,</span> <span class="s1">'l2.bias'</span><span class="p">])</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal"> 87</span>
|
||
<span class="normal"> 88</span>
|
||
<span class="normal"> 89</span>
|
||
<span class="normal"> 90</span>
|
||
<span class="normal"> 91</span>
|
||
<span class="normal"> 92</span>
|
||
<span class="normal"> 93</span>
|
||
<span class="normal"> 94</span>
|
||
<span class="normal"> 95</span>
|
||
<span class="normal"> 96</span>
|
||
<span class="normal"> 97</span>
|
||
<span class="normal"> 98</span>
|
||
<span class="normal"> 99</span>
|
||
<span class="normal">100</span>
|
||
<span class="normal">101</span>
|
||
<span class="normal">102</span>
|
||
<span class="normal">103</span>
|
||
<span class="normal">104</span>
|
||
<span class="normal">105</span>
|
||
<span class="normal">106</span>
|
||
<span class="normal">107</span>
|
||
<span class="normal">108</span>
|
||
<span class="normal">109</span>
|
||
<span class="normal">110</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span><span class="nb">str</span><span class="o">=</span><span class="s1">''</span><span class="p">,</span> <span class="n">tensor_type</span><span class="o">=</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Returns a `state_dict` of the object, with optional prefix.</span>
|
||
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> class Net:</span>
|
||
<span class="sd"> def __init__(self):</span>
|
||
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
|
||
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
|
||
|
||
<span class="sd"> net = Net()</span>
|
||
<span class="sd"> print(nn.state.get_state_dict(net).keys())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">):</span> <span class="k">return</span> <span class="p">{</span><span class="n">prefix</span><span class="o">.</span><span class="n">strip</span><span class="p">(</span><span class="s1">'.'</span><span class="p">):</span><span class="n">obj</span><span class="p">}</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">'_asdict'</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">_asdict</span><span class="p">(),</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span> <span class="c1"># namedtuple</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">OrderedDict</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">obj</span><span class="p">),</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">'__dict__'</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span>
|
||
<span class="n">state_dict</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">obj</span><span class="p">):</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">prefix</span><span class="si">}{</span><span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)</span><span class="si">}</span><span class="s2">."</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">obj</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">prefix</span><span class="si">}{</span><span class="nb">str</span><span class="p">(</span><span class="n">k</span><span class="p">)</span><span class="si">}</span><span class="s2">."</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">state_dict</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.get_parameters" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">get_parameters</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.get_parameters" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">get_parameters</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||
|
||
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_parameters</span><span class="p">(</span><span class="n">net</span><span class="p">)))</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mi">4</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">112</span>
|
||
<span class="normal">113</span>
|
||
<span class="normal">114</span>
|
||
<span class="normal">115</span>
|
||
<span class="normal">116</span>
|
||
<span class="normal">117</span>
|
||
<span class="normal">118</span>
|
||
<span class="normal">119</span>
|
||
<span class="normal">120</span>
|
||
<span class="normal">121</span>
|
||
<span class="normal">122</span>
|
||
<span class="normal">123</span>
|
||
<span class="normal">124</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">get_parameters</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> <span class="o">-></span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> class Net:</span>
|
||
<span class="sd"> def __init__(self):</span>
|
||
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
|
||
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
|
||
|
||
<span class="sd"> net = Net()</span>
|
||
<span class="sd"> print(len(nn.state.get_parameters(net)))</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.load_state_dict" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">load_state_dict</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.load_state_dict" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">load_state_dict</span><span class="p">(</span>
|
||
<span class="n">model</span><span class="p">,</span>
|
||
<span class="n">state_dict</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">consume</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">realize</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Loads a <code class="language-python highlight"><span class="n">state_dict</span></code> into a model. Returns the loaded Tensors.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||
|
||
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
|
||
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
|
||
<span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">126</span>
|
||
<span class="normal">127</span>
|
||
<span class="normal">128</span>
|
||
<span class="normal">129</span>
|
||
<span class="normal">130</span>
|
||
<span class="normal">131</span>
|
||
<span class="normal">132</span>
|
||
<span class="normal">133</span>
|
||
<span class="normal">134</span>
|
||
<span class="normal">135</span>
|
||
<span class="normal">136</span>
|
||
<span class="normal">137</span>
|
||
<span class="normal">138</span>
|
||
<span class="normal">139</span>
|
||
<span class="normal">140</span>
|
||
<span class="normal">141</span>
|
||
<span class="normal">142</span>
|
||
<span class="normal">143</span>
|
||
<span class="normal">144</span>
|
||
<span class="normal">145</span>
|
||
<span class="normal">146</span>
|
||
<span class="normal">147</span>
|
||
<span class="normal">148</span>
|
||
<span class="normal">149</span>
|
||
<span class="normal">150</span>
|
||
<span class="normal">151</span>
|
||
<span class="normal">152</span>
|
||
<span class="normal">153</span>
|
||
<span class="normal">154</span>
|
||
<span class="normal">155</span>
|
||
<span class="normal">156</span>
|
||
<span class="normal">157</span>
|
||
<span class="normal">158</span>
|
||
<span class="normal">159</span>
|
||
<span class="normal">160</span>
|
||
<span class="normal">161</span>
|
||
<span class="normal">162</span>
|
||
<span class="normal">163</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">load_state_dict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">consume</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">realize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Loads a `state_dict` into a model. Return the loaded Tensors.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> class Net:</span>
|
||
<span class="sd"> def __init__(self):</span>
|
||
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
|
||
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
|
||
|
||
<span class="sd"> net = Net()</span>
|
||
<span class="sd"> state_dict = nn.state.get_state_dict(net)</span>
|
||
<span class="sd"> nn.state.load_state_dict(net, state_dict)</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">start_mem_used</span> <span class="o">=</span> <span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span>
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">with</span> <span class="n">Timing</span><span class="p">(</span><span class="s2">"loaded weights in "</span><span class="p">,</span>
|
||
<span class="k">lambda</span> <span class="n">et_ns</span><span class="p">:</span> <span class="sa">f</span><span class="s2">", </span><span class="si">{</span><span class="p">(</span><span class="n">B</span><span class="o">:=</span><span class="p">(</span><span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span><span class="o">-</span><span class="n">start_mem_used</span><span class="p">))</span><span class="o">/</span><span class="mf">1e9</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> GB loaded at </span><span class="si">{</span><span class="n">B</span><span class="o">/</span><span class="n">et_ns</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> GB/s"</span><span class="p">,</span> <span class="n">enabled</span><span class="o">=</span><span class="n">verbose</span><span class="p">):</span>
|
||
<span class="n">model_state_dict</span> <span class="o">=</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span> <span class="o">></span> <span class="nb">len</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">):</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="s2">"WARNING: unused weights in state_dict"</span><span class="p">,</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="o">-</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())))</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="p">(</span><span class="n">t</span> <span class="o">:=</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">disable</span><span class="o">=</span><span class="n">CI</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">verbose</span><span class="p">)):</span>
|
||
<span class="n">t</span><span class="o">.</span><span class="n">desc</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"ram used: </span><span class="si">{</span><span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span><span class="o">/</span><span class="mf">1e9</span><span class="si">:</span><span class="s2">5.2f</span><span class="si">}</span><span class="s2"> GB, </span><span class="si">{</span><span class="n">k</span><span class="si">:</span><span class="s2">50s</span><span class="si">}</span><span class="s2">: "</span>
|
||
<span class="k">if</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">state_dict</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">strict</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"WARNING: not loading </span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">continue</span>
|
||
<span class="k">if</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="p">{(),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)}</span> <span class="o">==</span> <span class="p">{</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">}:</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Shape mismatch in layer `</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s1">`: Expected shape </span><span class="si">{</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">, but found </span><span class="si">{</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1"> in state dict.'</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
|
||
<span class="k">else</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shard</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">uop</span><span class="o">.</span><span class="n">axis</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">realize</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">realize</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">consume</span><span class="p">:</span> <span class="k">del</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
|
||
<span class="n">ret</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">ret</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.tar_extract" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <code class="highlight language-python"><span class="n">tar_extract</span></code>
|
||
|
||
<a href="#tinygrad.nn.state.tar_extract" class="headerlink" title="Permanent link">¤</a></h3>
|
||
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">tar_extract</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span>
|
||
</code></pre></div>
|
||
<p>Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">tensors</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">tar_extract</span><span class="p">(</span><span class="n">Tensor</span><span class="p">(</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="s2">"archive.tar"</span><span class="p">)))</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">184</span>
|
||
<span class="normal">185</span>
|
||
<span class="normal">186</span>
|
||
<span class="normal">187</span>
|
||
<span class="normal">188</span>
|
||
<span class="normal">189</span>
|
||
<span class="normal">190</span>
|
||
<span class="normal">191</span>
|
||
<span class="normal">192</span>
|
||
<span class="normal">193</span>
|
||
<span class="normal">194</span>
|
||
<span class="normal">195</span>
|
||
<span class="normal">196</span>
|
||
<span class="normal">197</span>
|
||
<span class="normal">198</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">tar_extract</span><span class="p">(</span><span class="n">t</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]</span>
|
||
<span class="sd"> ```</span>
|
||
|
||
<span class="sd"> Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">fileobj</span><span class="o">=</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">t</span><span class="p">),</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">tar</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="p">{</span><span class="n">member</span><span class="o">.</span><span class="n">name</span><span class="p">:</span><span class="n">t</span><span class="p">[</span><span class="n">member</span><span class="o">.</span><span class="n">offset_data</span><span class="p">:</span><span class="n">member</span><span class="o">.</span><span class="n">offset_data</span><span class="o">+</span><span class="n">member</span><span class="o">.</span><span class="n">size</span><span class="p">]</span> <span class="k">for</span> <span class="n">member</span> <span class="ow">in</span> <span class="n">tar</span> <span class="k">if</span> <span class="n">member</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">REGTYPE</span><span class="p">}</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.torch_load" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <code class="highlight language-python"><span class="n">torch_load</span></code>
|
||
|
||
<a href="#tinygrad.nn.state.torch_load" class="headerlink" title="Permanent link">¤</a></h3>
|
||
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">torch_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span>
|
||
</code></pre></div>
|
||
<p>Loads a torch .pth file, returning the <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">torch_load</span><span class="p">(</span><span class="s2">"test.pth"</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">202</span>
|
||
<span class="normal">203</span>
|
||
<span class="normal">204</span>
|
||
<span class="normal">205</span>
|
||
<span class="normal">206</span>
|
||
<span class="normal">207</span>
|
||
<span class="normal">208</span>
|
||
<span class="normal">209</span>
|
||
<span class="normal">210</span>
|
||
<span class="normal">211</span>
|
||
<span class="normal">212</span>
|
||
<span class="normal">213</span>
|
||
<span class="normal">214</span>
|
||
<span class="normal">215</span>
|
||
<span class="normal">216</span>
|
||
<span class="normal">217</span>
|
||
<span class="normal">218</span>
|
||
<span class="normal">219</span>
|
||
<span class="normal">220</span>
|
||
<span class="normal">221</span>
|
||
<span class="normal">222</span>
|
||
<span class="normal">223</span>
|
||
<span class="normal">224</span>
|
||
<span class="normal">225</span>
|
||
<span class="normal">226</span>
|
||
<span class="normal">227</span>
|
||
<span class="normal">228</span>
|
||
<span class="normal">229</span>
|
||
<span class="normal">230</span>
|
||
<span class="normal">231</span>
|
||
<span class="normal">232</span>
|
||
<span class="normal">233</span>
|
||
<span class="normal">234</span>
|
||
<span class="normal">235</span>
|
||
<span class="normal">236</span>
|
||
<span class="normal">237</span>
|
||
<span class="normal">238</span>
|
||
<span class="normal">239</span>
|
||
<span class="normal">240</span>
|
||
<span class="normal">241</span>
|
||
<span class="normal">242</span>
|
||
<span class="normal">243</span>
|
||
<span class="normal">244</span>
|
||
<span class="normal">245</span>
|
||
<span class="normal">246</span>
|
||
<span class="normal">247</span>
|
||
<span class="normal">248</span>
|
||
<span class="normal">249</span>
|
||
<span class="normal">250</span>
|
||
<span class="normal">251</span>
|
||
<span class="normal">252</span>
|
||
<span class="normal">253</span>
|
||
<span class="normal">254</span>
|
||
<span class="normal">255</span>
|
||
<span class="normal">256</span>
|
||
<span class="normal">257</span>
|
||
<span class="normal">258</span>
|
||
<span class="normal">259</span>
|
||
<span class="normal">260</span>
|
||
<span class="normal">261</span>
|
||
<span class="normal">262</span>
|
||
<span class="normal">263</span>
|
||
<span class="normal">264</span>
|
||
<span class="normal">265</span>
|
||
<span class="normal">266</span>
|
||
<span class="normal">267</span>
|
||
<span class="normal">268</span>
|
||
<span class="normal">269</span>
|
||
<span class="normal">270</span>
|
||
<span class="normal">271</span>
|
||
<span class="normal">272</span>
|
||
<span class="normal">273</span>
|
||
<span class="normal">274</span>
|
||
<span class="normal">275</span>
|
||
<span class="normal">276</span>
|
||
<span class="normal">277</span>
|
||
<span class="normal">278</span>
|
||
<span class="normal">279</span>
|
||
<span class="normal">280</span>
|
||
<span class="normal">281</span>
|
||
<span class="normal">282</span>
|
||
<span class="normal">283</span>
|
||
<span class="normal">284</span>
|
||
<span class="normal">285</span>
|
||
<span class="normal">286</span>
|
||
<span class="normal">287</span>
|
||
<span class="normal">288</span>
|
||
<span class="normal">289</span>
|
||
<span class="normal">290</span>
|
||
<span class="normal">291</span>
|
||
<span class="normal">292</span>
|
||
<span class="normal">293</span>
|
||
<span class="normal">294</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">torch_load</span><span class="p">(</span><span class="n">t</span><span class="p">:</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]</span>
|
||
<span class="sd"> ```</span>
|
||
|
||
<span class="sd"> Loads a torch .pth file, returning the `state_dict`.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> state_dict = nn.state.torch_load("test.pth")</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">storage_source</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="o">|</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="n">lens</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="o">|</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_rebuild_tensor</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_rebuild_tensor_v2</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_rebuild_tensor_v2</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">backward_hooks</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metadata</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="c1">#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)</span>
|
||
<span class="n">lens</span><span class="p">[</span><span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">storage</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">*</span> <span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span>
|
||
<span class="k">if</span> <span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">storage_source</span><span class="p">:</span> <span class="k">return</span> <span class="kc">None</span>
|
||
<span class="n">byte_start</span><span class="p">,</span> <span class="n">byte_end</span> <span class="o">=</span> <span class="n">storage_offset</span><span class="o">*</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span><span class="p">,</span> <span class="p">(</span><span class="n">storage_offset</span> <span class="o">+</span> <span class="n">prod</span><span class="p">(</span><span class="n">size</span><span class="p">))</span><span class="o">*</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span>
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="n">storage_source</span><span class="p">[</span><span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]][</span><span class="n">byte_start</span><span class="p">:</span><span class="n">byte_end</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="c1"># 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk</span>
|
||
<span class="n">shape_strides</span> <span class="o">=</span> <span class="p">[(</span><span class="n">s</span><span class="p">,</span> <span class="n">st</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span><span class="p">,</span><span class="n">st</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span> <span class="k">if</span> <span class="n">s</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">]</span>
|
||
<span class="n">permute_indexes</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">shape_strides</span><span class="p">)</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">y</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">([</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">shape_strides</span><span class="p">])]</span>
|
||
<span class="k">if</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">))):</span>
|
||
<span class="n">intermediate_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">shape_strides</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)])</span>
|
||
<span class="k">assert</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">shape_strides</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)])</span> <span class="o">==</span> <span class="n">strides_for_shape</span><span class="p">(</span><span class="n">intermediate_shape</span><span class="p">),</span> <span class="s2">"nonpermutable strides"</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">3</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"WARNING: this torch load is slow. to permute </span><span class="si">{</span><span class="n">intermediate_shape</span><span class="si">}</span><span class="s2"> with </span><span class="si">{</span><span class="n">permute_indexes</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">,</span> <span class="s2">"can't permute BF16"</span>
|
||
<span class="c1"># TODO: find a nice way to support all movement ops on disktensors</span>
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">intermediate_shape</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">ret</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">Parameter</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">tensor</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||
|
||
<span class="n">deserialized_objects</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="n">intercept</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"HalfStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s2">"FloatStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="s2">"BFloat16Storage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">,</span>
|
||
<span class="s2">"IntStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="s2">"BoolStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
|
||
<span class="s2">"LongStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="s2">"_rebuild_tensor"</span><span class="p">:</span> <span class="n">_rebuild_tensor</span><span class="p">,</span> <span class="s2">"_rebuild_tensor_v2"</span><span class="p">:</span> <span class="n">_rebuild_tensor_v2</span><span class="p">,</span>
|
||
<span class="s2">"FloatTensor"</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Parameter"</span><span class="p">:</span> <span class="n">Parameter</span><span class="p">}</span>
|
||
<span class="n">whitelist</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"torch"</span><span class="p">,</span> <span class="s2">"collections"</span><span class="p">,</span> <span class="s2">"numpy"</span><span class="p">,</span> <span class="s2">"_codecs"</span><span class="p">}</span> <span class="c1"># NOTE: this is not for security, only speed</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">Dummy</span><span class="p">:</span> <span class="k">pass</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">TorchPickle</span><span class="p">(</span><span class="n">pickle</span><span class="o">.</span><span class="n">Unpickler</span><span class="p">):</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">find_class</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||
<span class="n">module_root</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"."</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">module_root</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">whitelist</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">2</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"WARNING: returning Dummy for </span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">Dummy</span>
|
||
<span class="k">return</span> <span class="n">intercept</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="k">if</span> <span class="n">module_root</span> <span class="o">==</span> <span class="s2">"torch"</span> <span class="k">else</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">find_class</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">persistent_load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pid</span><span class="p">):</span> <span class="k">return</span> <span class="n">deserialized_objects</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">pid</span><span class="p">,</span> <span class="n">pid</span><span class="p">)</span>
|
||
|
||
<span class="n">fobj</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">t</span><span class="p">))</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">passthrough_reset</span><span class="p">(</span><span class="n">v</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="k">return</span> <span class="n">fobj</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">or</span> <span class="n">v</span>
|
||
<span class="k">if</span> <span class="n">passthrough_reset</span><span class="p">(</span><span class="n">zipfile</span><span class="o">.</span><span class="n">is_zipfile</span><span class="p">(</span><span class="n">fobj</span><span class="p">)):</span> <span class="c1"># NOTE: passthrough_reset required to support python < 3.14</span>
|
||
<span class="n">files</span> <span class="o">=</span> <span class="n">zip_extract</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="n">base_name</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">files</span><span class="p">))</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'/'</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="c1"># keyed by persistent_id in pickle file</span>
|
||
<span class="n">storage_source</span> <span class="o">=</span> <span class="p">{</span><span class="n">fn</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span> <span class="n">data</span> <span class="k">for</span> <span class="n">fn</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">files</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">fn</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">base_name</span><span class="si">}</span><span class="s2">/data/"</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">fn</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="s2">".pkl"</span><span class="p">)}</span>
|
||
<span class="k">return</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">files</span><span class="p">[</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">base_name</span><span class="si">}</span><span class="s2">/data.pkl"</span><span class="p">]),</span> <span class="mi">1_000_000</span><span class="p">))</span><span class="o">.</span><span class="n">load</span><span class="p">()</span>
|
||
<span class="k">elif</span> <span class="n">passthrough_reset</span><span class="p">(</span><span class="n">tarfile</span><span class="o">.</span><span class="n">is_tarfile</span><span class="p">(</span><span class="n">fobj</span><span class="p">)):</span> <span class="c1"># NOTE: passthrough_reset required to support python < 3.11</span>
|
||
<span class="n">files</span> <span class="o">=</span> <span class="n">tar_extract</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="n">f</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">files</span><span class="p">[</span><span class="s2">"storages"</span><span class="p">]),</span> <span class="mi">1_000_000</span><span class="p">)</span>
|
||
<span class="c1"># slice source tensor t</span>
|
||
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()):</span>
|
||
      <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">storage_type</span><span class="p">),</span> <span class="n">sz</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">'&lt;q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">byte_offset</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
|
||
<span class="n">storage_source</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">files</span><span class="p">[</span><span class="s2">"storages"</span><span class="p">][</span><span class="n">byte_offset</span><span class="p">:</span><span class="n">byte_offset</span> <span class="o">+</span> <span class="n">sz</span> <span class="o">*</span> <span class="n">storage_type</span><span class="o">.</span><span class="n">itemsize</span><span class="p">]</span>
|
||
<span class="n">f</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">sz</span> <span class="o">*</span> <span class="n">storage_type</span><span class="o">.</span><span class="n">itemsize</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">f</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">files</span><span class="p">[</span><span class="s2">"tensors"</span><span class="p">]),</span> <span class="mi">1_000_000</span><span class="p">)</span>
|
||
<span class="c1"># get tensor metadata</span>
|
||
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()):</span>
|
||
      <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">storage_id</span><span class="p">,</span> <span class="n">_</span><span class="p">),</span> <span class="n">ndim</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">'&lt;i'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">))[</span><span class="mi">0</span><span class="p">],</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
|
||
      <span class="n">size</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="sa">f</span><span class="s1">'&lt;</span><span class="si">{</span><span class="n">ndim</span><span class="si">}</span><span class="s1">q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">)),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="sa">f</span><span class="s1">'&lt;</span><span class="si">{</span><span class="n">ndim</span><span class="si">}</span><span class="s1">q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">))</span>
|
||
      <span class="n">storage_offset</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">'&lt;q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">deserialized_objects</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">key</span><span class="p">)]</span> <span class="o">=</span> <span class="n">_rebuild_tensor_v2</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="n">storage_type</span><span class="p">,</span> <span class="n">storage_id</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
|
||
<span class="n">pkl_data</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">files</span><span class="p">[</span><span class="s2">"pickle"</span><span class="p">]),</span> <span class="mi">1_000_000</span><span class="p">))</span><span class="o">.</span><span class="n">load</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">tensor</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Parameter</span><span class="p">)</span> <span class="k">else</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">pkl_data</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">pkl</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">fobj</span><span class="p">)</span>
|
||
<span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">rwd</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">ids</span><span class="p">,</span> <span class="n">base_offset</span> <span class="o">=</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">fobj</span><span class="o">.</span><span class="n">tell</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">fobj</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
|
||
<span class="c1"># slice source tensor t</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ids</span><span class="p">:</span>
|
||
<span class="n">storage_source</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">base_offset</span> <span class="o">+</span> <span class="mi">8</span><span class="p">:</span><span class="n">base_offset</span> <span class="o">+</span> <span class="mi">8</span> <span class="o">+</span> <span class="n">lens</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
|
||
<span class="n">base_offset</span> <span class="o">+=</span> <span class="mi">8</span> <span class="o">+</span> <span class="n">lens</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="n">fobj</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">rwd</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">fobj</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.llm.gguf.gguf_load" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">gguf_load</span>
|
||
|
||
|
||
<a href="#tinygrad.llm.gguf.gguf_load" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">gguf_load</span><span class="p">(</span>
|
||
    <span class="n">fn</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt;            &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" title="&lt;code&gt;pathlib.Path&lt;/code&gt;" href="https://docs.python.org/3/library/pathlib.html#pathlib.Path">Path</a></span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="&lt;code class=&quot;doc-symbol doc-symbol-heading doc-symbol-class&quot;&gt;&lt;/code&gt;            &lt;span class=&quot;doc doc-object-name doc-class-name&quot;&gt;Tensor&lt;/span&gt; (&lt;code&gt;tinygrad.tensor.Tensor&lt;/code&gt;)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Loads a .gguf file, returning the <code class="language-python highlight"><span class="n">kv_data</span></code> and <code class="language-python highlight"><span class="n">state_dict</span></code>. Multi-part splits are auto-merged when loaded by path.</p>
|
||
<!-- Highlighted usage example for gguf_load. Content inside <pre> is whitespace-sensitive: keep one source line per line. -->
<div class="language-python highlight"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">pathlib</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad</span><span class="w"> </span><span class="kn">import</span> <span class="n">Device</span><span class="p">,</span> <span class="n">Tensor</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tinygrad.llm.gguf</span><span class="w"> </span><span class="kn">import</span> <span class="n">gguf_load</span>

<span class="n">gguf_tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="s2">"Meta-Llama-3-8B-Instruct.Q4_0.gguf"</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">Device</span><span class="o">.</span><span class="n">DEFAULT</span><span class="p">)</span>
<span class="n">kv_data</span><span class="p">,</span> <span class="n">state_dict</span> <span class="o">=</span> <span class="n">gguf_load</span><span class="p">(</span><span class="n">gguf_tensor</span><span class="p">)</span>
</code></pre></div>
|
||
<!-- Note callout attached to the gguf_load description. -->
<div class="admonition note">
  <p class="admonition-title">Note</p>
  <p>The provided tensor must be on a device that supports execution.</p>
</div>
|
||
|
||
|
||
<!-- Collapsible source listing of gguf_load (tinygrad/llm/gguf.py, source lines 157-178).
     Both <pre> elements are whitespace-sensitive: the gutter and the code must each carry exactly 22 lines so they stay aligned. -->
<details class="mkdocstrings-source">
  <summary>Source code in <code>tinygrad/llm/gguf.py</code></summary>
  <div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span>
<span class="normal">160</span>
<span class="normal">161</span>
<span class="normal">162</span>
<span class="normal">163</span>
<span class="normal">164</span>
<span class="normal">165</span>
<span class="normal">166</span>
<span class="normal">167</span>
<span class="normal">168</span>
<span class="normal">169</span>
<span class="normal">170</span>
<span class="normal">171</span>
<span class="normal">172</span>
<span class="normal">173</span>
<span class="normal">174</span>
<span class="normal">175</span>
<span class="normal">176</span>
<span class="normal">177</span>
<span class="normal">178</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">gguf_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="nb">str</span><span class="o">|</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]]:</span>
<span class="w">  </span><span class="sd">"""</span>
<span class="sd">  Loads a .gguf file, returning the `kv_data` and `state_dict`. Multi-part splits are auto-merged when loaded by path.</span>

<span class="sd">  ```python</span>
<span class="sd">  import pathlib</span>
<span class="sd">  from tinygrad import Device, Tensor</span>
<span class="sd">  from tinygrad.llm.gguf import gguf_load</span>

<span class="sd">  gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)</span>
<span class="sd">  kv_data, state_dict = gguf_load(gguf_tensor)</span>
<span class="sd">  ```</span>

<span class="sd">  NOTE: The provided tensor must be on a device that supports execution.</span>
<span class="sd">  """</span>
  <span class="c1"># TODO: remove the need for copy to default device</span>
  <span class="k">def</span><span class="w"> </span><span class="nf">load</span><span class="p">(</span><span class="n">p</span><span class="p">):</span> <span class="k">return</span> <span class="n">_gguf_parse</span><span class="p">(</span><span class="n">p</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">p</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span><span class="o">.</span><span class="n">realize</span><span class="p">())</span>
  <span class="n">kv</span><span class="p">,</span> <span class="n">sd</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span>
  <span class="k">if</span> <span class="n">kv</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'split.count'</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span> <span class="k">return</span> <span class="n">kv</span><span class="p">,</span> <span class="n">sd</span>
  <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"multi-part GGUF requires a path argument (got Tensor)"</span><span class="p">)</span>
  <span class="k">for</span> <span class="n">pp</span> <span class="ow">in</span> <span class="n">_gguf_split_paths</span><span class="p">(</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">fn</span><span class="p">),</span> <span class="n">kv</span><span class="p">)[</span><span class="mi">1</span><span class="p">:]:</span> <span class="n">sd</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">load</span><span class="p">(</span><span class="n">pp</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
  <span class="k">return</span> <span class="n">kv</span><span class="p">,</span> <span class="n">sd</span>
</code></pre></div></td></tr></table></div>
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</article>
|
||
</div>
|
||
|
||
|
||
<script>
  // Look up the element whose id equals the URL fragment (without the leading "#").
  // If it exists and has a name, set its `checked` state to whether that name
  // starts with "__tabbed_".
  var target = document.getElementById(location.hash.slice(1));
  if (target && target.name) {
    target.checked = target.name.startsWith("__tabbed_");
  }
</script>
|
||
</div>
|
||
|
||
<!-- "Back to top" control. It starts hidden; presumably the theme script toggles it via data-md-component="top" (verify against the bundle).
     The arrow icon is decorative, so it is hidden from assistive technology; the visible text is the accessible name. -->
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
  Back to top
</button>
|
||
|
||
</main>
|
||
|
||
<!-- Page footer: previous/next page links followed by the theme attribution.
     The arrow icons are decorative (each link is named by its aria-label), so they carry aria-hidden="true". -->
<footer class="md-footer">
  <nav class="md-footer__inner md-grid" aria-label="Footer">
    <a href="../dtypes/" class="md-footer__link md-footer__link--prev" aria-label="Previous: dtypes">
      <div class="md-footer__button md-icon">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
      </div>
      <div class="md-footer__title">
        <span class="md-footer__direction">
          Previous
        </span>
        <div class="md-ellipsis">
          dtypes
        </div>
      </div>
    </a>
    <a href="../env_vars/" class="md-footer__link md-footer__link--next" aria-label="Next: Environment Variables">
      <div class="md-footer__title">
        <span class="md-footer__direction">
          Next
        </span>
        <div class="md-ellipsis">
          Environment Variables
        </div>
      </div>
      <div class="md-footer__button md-icon">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
      </div>
    </a>
  </nav>
  <div class="md-footer-meta md-typeset">
    <div class="md-footer-meta__inner md-grid">
      <div class="md-copyright">
        Made with
        <a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
          Material for MkDocs
        </a>
      </div>
    </div>
  </div>
</footer>
|
||
|
||
</div>
|
||
<!-- Empty dialog container; presumably populated at runtime by the theme script via data-md-component="dialog" (verify against the bundle). -->
<div class="md-dialog" data-md-component="dialog">
  <div class="md-dialog__inner md-typeset"></div>
</div>
|
||
|
||
|
||
|
||
|
||
|
||
<script id="__config" type="application/json">{"annotate": null, "base": "..", "features": ["announce.dismiss", "content.action.edit", "content.action.view", "content.code.annotate", "content.code.copy", "content.tooltips", "navigation.footer", "navigation.indexes", "navigation.sections", "navigation.expand", "navigation.top", "navigation.path", "search.highlight", "search.suggest", "toc.follow", "toc.integrate"], "search": "../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}, "version": null}</script>
|
||
|
||
|
||
<script src="../assets/javascripts/bundle.79ae519e.min.js"></script>
|
||
|
||
<script src="../assets/_markdown_exec_pyodide.js"></script>
|
||
|
||
|
||
</body>
|
||
</html> |