mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
3965 lines
264 KiB
HTML
3965 lines
264 KiB
HTML
|
||
<!doctype html>
|
||
<html lang="en" class="no-js">
|
||
<head>
|
||
|
||
<meta charset="utf-8">
|
||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||
|
||
|
||
|
||
<link rel="canonical" href="https://docs.tinygrad.org/nn/">
|
||
|
||
|
||
<link rel="prev" href="../dtypes/">
|
||
|
||
|
||
<link rel="next" href="../env_vars/">
|
||
|
||
|
||
|
||
|
||
|
||
<link rel="icon" href="../favicon.svg">
|
||
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.1">
|
||
|
||
|
||
|
||
<title>nn (Neural Networks) - tinygrad docs</title>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../assets/stylesheets/main.484c7ddc.min.css">
|
||
|
||
|
||
<link rel="stylesheet" href="../assets/stylesheets/palette.ab4e12ef.min.css">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
|
||
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
|
||
|
||
|
||
|
||
<link rel="stylesheet" href="../assets/_markdown_exec_pyodide.css">
|
||
|
||
<link rel="stylesheet" href="../assets/_markdown_exec_ansi.css">
|
||
|
||
<link rel="stylesheet" href="../assets/_mkdocstrings.css">
|
||
|
||
<script>__md_scope=new URL("..",location),__md_hash=e=>[...e].reduce(((e,_)=>(e<<5)-e+_.charCodeAt(0)),0),__md_get=(e,_=localStorage,t=__md_scope)=>JSON.parse(_.getItem(t.pathname+"."+e)),__md_set=(e,_,t=localStorage,a=__md_scope)=>{try{t.setItem(a.pathname+"."+e,JSON.stringify(_))}catch(e){}}</script>
|
||
|
||
|
||
|
||
|
||
|
||
</head>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime">
|
||
|
||
|
||
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
|
||
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
|
||
<label class="md-overlay" for="__drawer"></label>
|
||
<div data-md-component="skip">
|
||
|
||
|
||
<a href="#neural-network-classes" class="md-skip">
|
||
Skip to content
|
||
</a>
|
||
|
||
</div>
|
||
<div data-md-component="announce">
|
||
|
||
</div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<header class="md-header md-header--shadow" data-md-component="header">
|
||
<nav class="md-header__inner md-grid" aria-label="Header">
|
||
<a href=".." title="tinygrad docs" class="md-header__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
||
|
||
<img src="../logo_tiny_dark.svg" alt="logo">
|
||
|
||
</a>
|
||
<label class="md-header__button md-icon" for="__drawer">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
|
||
</label>
|
||
<div class="md-header__title" data-md-component="header-title">
|
||
<div class="md-header__ellipsis">
|
||
<div class="md-header__topic">
|
||
<span class="md-ellipsis">
|
||
tinygrad docs
|
||
</span>
|
||
</div>
|
||
<div class="md-header__topic" data-md-component="header-topic">
|
||
<span class="md-ellipsis">
|
||
|
||
nn (Neural Networks)
|
||
|
||
</span>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
<form class="md-header__option" data-md-component="palette">
|
||
|
||
|
||
|
||
|
||
<input class="md-option" data-md-color-media="(prefers-color-scheme)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to light mode" type="radio" name="__palette" id="__palette_0">
|
||
|
||
<label class="md-header__button md-icon" title="Switch to light mode" for="__palette_1" hidden>
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m14.3 16-.7-2h-3.2l-.7 2H7.8L11 7h2l3.2 9zM20 8.69V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12zm-9.15 3.96h2.3L12 9z"/></svg>
|
||
</label>
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-option" data-md-color-media="(prefers-color-scheme: light)" data-md-color-scheme="default" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to dark mode" type="radio" name="__palette" id="__palette_1">
|
||
|
||
<label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_2" hidden>
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
||
</label>
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-option" data-md-color-media="(prefers-color-scheme: dark)" data-md-color-scheme="slate" data-md-color-primary="black" data-md-color-accent="lime" aria-label="Switch to system preference" type="radio" name="__palette" id="__palette_2">
|
||
|
||
<label class="md-header__button md-icon" title="Switch to system preference" for="__palette_0" hidden>
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
|
||
</label>
|
||
|
||
|
||
</form>
|
||
|
||
|
||
|
||
<script>var palette=__md_get("__palette");if(palette&&palette.color){if("(prefers-color-scheme)"===palette.color.media){var media=matchMedia("(prefers-color-scheme: light)"),input=document.querySelector(media.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");palette.color.media=input.getAttribute("data-md-color-media"),palette.color.scheme=input.getAttribute("data-md-color-scheme"),palette.color.primary=input.getAttribute("data-md-color-primary"),palette.color.accent=input.getAttribute("data-md-color-accent")}for(var[key,value]of Object.entries(palette.color))document.body.setAttribute("data-md-color-"+key,value)}</script>
|
||
|
||
|
||
|
||
|
||
|
||
<label class="md-header__button md-icon" for="__search">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
||
</label>
|
||
<div class="md-search" data-md-component="search" role="dialog">
|
||
<label class="md-search__overlay" for="__search"></label>
|
||
<div class="md-search__inner" role="search">
|
||
<form class="md-search__form" name="search">
|
||
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
|
||
<label class="md-search__icon md-icon" for="__search">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
|
||
</label>
|
||
<nav class="md-search__options" aria-label="Search">
|
||
|
||
<button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
|
||
</button>
|
||
</nav>
|
||
|
||
<div class="md-search__suggest" data-md-component="search-suggest"></div>
|
||
|
||
</form>
|
||
<div class="md-search__output">
|
||
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
|
||
<div class="md-search-result" data-md-component="search-result">
|
||
<div class="md-search-result__meta">
|
||
Initializing search
|
||
</div>
|
||
<ol class="md-search-result__list" role="presentation"></ol>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="md-header__source">
|
||
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
||
<div class="md-source__icon md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
||
</div>
|
||
<div class="md-source__repository">
|
||
GitHub
|
||
</div>
|
||
</a>
|
||
</div>
|
||
|
||
</nav>
|
||
|
||
</header>
|
||
|
||
<div class="md-container" data-md-component="container">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<main class="md-main" data-md-component="main">
|
||
<div class="md-main__inner md-grid">
|
||
|
||
|
||
|
||
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
|
||
<div class="md-sidebar__scrollwrap">
|
||
<div class="md-sidebar__inner">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<nav class="md-nav md-nav--primary md-nav--integrated" aria-label="Navigation" data-md-level="0">
|
||
<label class="md-nav__title" for="__drawer">
|
||
<a href=".." title="tinygrad docs" class="md-nav__button md-logo" aria-label="tinygrad docs" data-md-component="logo">
|
||
|
||
<img src="../logo_tiny_dark.svg" alt="logo">
|
||
|
||
</a>
|
||
tinygrad docs
|
||
</label>
|
||
|
||
<div class="md-nav__source">
|
||
<a href="https://github.com/tinygrad/tinygrad/" title="Go to repository" class="md-source" data-md-component="source">
|
||
<div class="md-source__icon md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
|
||
</div>
|
||
<div class="md-source__repository">
|
||
GitHub
|
||
</div>
|
||
</a>
|
||
</div>
|
||
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1" checked>
|
||
|
||
|
||
<div class="md-nav__link md-nav__container">
|
||
<a href=".." class="md-nav__link ">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Home
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
<label class="md-nav__link " for="__nav_1" id="__nav_1_label" tabindex="">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
</div>
|
||
|
||
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_1_label" aria-expanded="true">
|
||
<label class="md-nav__title" for="__nav_1">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Home
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../quickstart/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Quickstart
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../showcase/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Showcase
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../mnist/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
MNIST Tutorial
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--active md-nav__item--nested">
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_1_5" checked>
|
||
|
||
|
||
<label class="md-nav__link" for="__nav_1_5" id="__nav_1_5_label" tabindex="0">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
API Reference
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_5_label" aria-expanded="true">
|
||
<label class="md-nav__title" for="__nav_1_5">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
API Reference
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--nested">
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_5_1" >
|
||
|
||
|
||
<div class="md-nav__link md-nav__container">
|
||
<a href="../tensor/" class="md-nav__link ">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Tensor
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
<label class="md-nav__link " for="__nav_1_5_1" id="__nav_1_5_1_label" tabindex="0">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
</div>
|
||
|
||
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_5_1_label" aria-expanded="false">
|
||
<label class="md-nav__title" for="__nav_1_5_1">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Tensor
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/properties/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Properties
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/creation/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Creation
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/movement/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Movement
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/elementwise/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Elementwise
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tensor/ops/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Complex Ops
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../dtypes/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
dtypes
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--active">
|
||
|
||
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
|
||
|
||
|
||
|
||
<label class="md-nav__link md-nav__link--active" for="__toc">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
nn (Neural Networks)
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<a href="./" class="md-nav__link md-nav__link--active">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
nn (Neural Networks)
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
|
||
|
||
|
||
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
|
||
|
||
|
||
|
||
|
||
<label class="md-nav__title" for="__toc">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
Table of contents
|
||
</label>
|
||
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#neural-network-classes" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
Neural Network classes
|
||
|
||
</span>
|
||
</a>
|
||
|
||
<nav class="md-nav" aria-label="Neural Network classes">
|
||
<ul class="md-nav__list">
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.BatchNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> BatchNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Conv1d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> Conv1d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Conv2d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> Conv2d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.ConvTranspose1d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> ConvTranspose1d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.ConvTranspose2d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> ConvTranspose2d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Linear" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> Linear
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.GroupNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> GroupNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.InstanceNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> InstanceNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.LayerNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LayerNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.LayerNorm2d" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LayerNorm2d
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.RMSNorm" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> RMSNorm
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.Embedding" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> Embedding
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.LSTMCell" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LSTMCell
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#optimizers" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
Optimizers
|
||
|
||
</span>
|
||
</a>
|
||
|
||
<nav class="md-nav" aria-label="Optimizers">
|
||
<ul class="md-nav__list">
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.SGD" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> SGD
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.LARS" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LARS
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.AdamW" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> AdamW
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.Adam" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> Adam
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.optim.LAMB" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-class"></code> LAMB
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#loadsave" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
Load/Save
|
||
|
||
</span>
|
||
</a>
|
||
|
||
<nav class="md-nav" aria-label="Load/Save">
|
||
<ul class="md-nav__list">
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.safe_load" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> safe_load
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.safe_save" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> safe_save
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.get_state_dict" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> get_state_dict
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.get_parameters" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> get_parameters
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.load_state_dict" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> load_state_dict
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.tar_extract" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> tar_extract
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.torch_load" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> torch_load
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
<li class="md-nav__item">
|
||
<a href="#tinygrad.nn.state.gguf_load" class="md-nav__link">
|
||
<span class="md-ellipsis">
|
||
|
||
<code class="doc-symbol doc-symbol-toc doc-symbol-function"></code> gguf_load
|
||
|
||
</span>
|
||
</a>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
</ul>
|
||
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../env_vars/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Environment Variables
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../runtime/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Runtime
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--nested">
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6" >
|
||
|
||
|
||
<label class="md-nav__link" for="__nav_1_6" id="__nav_1_6_label" tabindex="0">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Developer
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_1_6_label" aria-expanded="false">
|
||
<label class="md-nav__title" for="__nav_1_6">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Developer
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/developer/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Intro
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/layout/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Layout
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/speed/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Speed
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/uop/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
UOp
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item md-nav__item--nested">
|
||
|
||
|
||
|
||
|
||
|
||
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1_6_5" >
|
||
|
||
|
||
<label class="md-nav__link" for="__nav_1_6_5" id="__nav_1_6_5_label" tabindex="0">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Runtime
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
<span class="md-nav__icon md-icon"></span>
|
||
</label>
|
||
|
||
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_1_6_5_label" aria-expanded="false">
|
||
<label class="md-nav__title" for="__nav_1_6_5">
|
||
<span class="md-nav__icon md-icon"></span>
|
||
|
||
|
||
Runtime
|
||
|
||
|
||
</label>
|
||
<ul class="md-nav__list" data-md-scrollfix>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/runtime/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
Runtime Overview
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/hcq/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
HCQ
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../developer/am/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
AM Driver
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-nav__item">
|
||
<a href="../tinybox/" class="md-nav__link">
|
||
|
||
|
||
|
||
<span class="md-ellipsis">
|
||
|
||
|
||
tinybox
|
||
|
||
|
||
|
||
</span>
|
||
|
||
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
|
||
</li>
|
||
|
||
|
||
|
||
</ul>
|
||
</nav>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
|
||
<div class="md-content" data-md-component="content">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<nav class="md-path" aria-label="Navigation" >
|
||
<ol class="md-path__list">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-path__item">
|
||
<a href=".." class="md-path__link">
|
||
|
||
<span class="md-ellipsis">
|
||
Home
|
||
</span>
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<li class="md-path__item">
|
||
<a href="../tensor/" class="md-path__link">
|
||
|
||
<span class="md-ellipsis">
|
||
API Reference
|
||
</span>
|
||
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</ol>
|
||
</nav>
|
||
|
||
|
||
<article class="md-content__inner md-typeset">
|
||
|
||
|
||
|
||
|
||
|
||
<a href="https://github.com/tinygrad/tinygrad/edit/master/docs/nn.md" title="Edit this page" class="md-content__button md-icon" rel="edit">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M10 20H6V4h7v5h5v3.1l2-2V8l-6-6H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h4zm10.2-7c.1 0 .3.1.4.2l1.3 1.3c.2.2.2.6 0 .8l-1 1-2.1-2.1 1-1c.1-.1.2-.2.4-.2m0 3.9L14.1 23H12v-2.1l6.1-6.1z"/></svg>
|
||
</a>
|
||
|
||
|
||
|
||
|
||
|
||
<a href="https://github.com/tinygrad/tinygrad/raw/master/docs/nn.md" title="View source of this page" class="md-content__button md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M17 18c.56 0 1 .44 1 1s-.44 1-1 1-1-.44-1-1 .44-1 1-1m0-3c-2.73 0-5.06 1.66-6 4 .94 2.34 3.27 4 6 4s5.06-1.66 6-4c-.94-2.34-3.27-4-6-4m0 6.5a2.5 2.5 0 0 1-2.5-2.5 2.5 2.5 0 0 1 2.5-2.5 2.5 2.5 0 0 1 2.5 2.5 2.5 2.5 0 0 1-2.5 2.5M9.27 20H6V4h7v5h5v4.07c.7.08 1.36.25 2 .49V8l-6-6H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h4.5a8.2 8.2 0 0 1-1.23-2"/></svg>
|
||
</a>
|
||
|
||
|
||
|
||
<h1>nn (Neural Networks)</h1>
|
||
|
||
<h2 id="neural-network-classes">Neural Network classes<a class="headerlink" href="#neural-network-classes" title="Permanent link">¤</a></h2>
|
||
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.BatchNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">BatchNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.BatchNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">BatchNorm</span><span class="p">(</span>
|
||
<span class="n">sz</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">track_running_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Batch Normalization over a 2D or 3D input.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1502.03167v3">https://arxiv.org/abs/1502.03167v3</a></li>
|
||
</ul>
|
||
<p>See: <code class="language-python highlight"><span class="n">Tensor</span><span class="o">.</span><span class="n">batchnorm</span></code></p>
|
||
<p></p>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">0.5005428791046143</span> <span class="mf">0.31660139560699463</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">0.5005403757095337</span> <span class="mf">0.3165997266769409</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">33</span>
|
||
<span class="normal">34</span>
|
||
<span class="normal">35</span>
|
||
<span class="normal">36</span>
|
||
<span class="normal">37</span>
|
||
<span class="normal">38</span>
|
||
<span class="normal">39</span>
|
||
<span class="normal">40</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sz</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">track_running_stats</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">eps</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">,</span> <span class="n">momentum</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">sz</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">sz</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_batches_tracked</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="s1">'long'</span> <span class="k">if</span> <span class="n">is_dtype_supported</span><span class="p">(</span><span class="n">dtypes</span><span class="o">.</span><span class="n">long</span><span class="p">)</span> <span class="k">else</span> <span class="s1">'int'</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">track_running_stats</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_mean</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">sz</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">sz</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.Conv1d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">Conv1d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Conv1d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Conv1d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Conv2d</span> (<code>tinygrad.nn.Conv2d</code>)" href="#tinygrad.nn.Conv2d">Conv2d</a></span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Applies a 1D convolution over an input signal composed of several input planes.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv1d">https://pytorch.org/docs/stable/generated/torch.nn.Conv1d</a></p>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv1d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.8887</span> <span class="mf">0.9379</span> <span class="mf">0.6705</span> <span class="mf">0.2281</span><span class="p">]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.4804</span> <span class="mf">0.141</span> <span class="p">]]]</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">63</span>
|
||
<span class="normal">64</span>
|
||
<span class="normal">65</span>
|
||
<span class="normal">66</span>
|
||
<span class="normal">67</span>
|
||
<span class="normal">68</span>
|
||
<span class="normal">69</span>
|
||
<span class="normal">70</span>
|
||
<span class="normal">71</span>
|
||
<span class="normal">72</span>
|
||
<span class="normal">73</span>
|
||
<span class="normal">74</span>
|
||
<span class="normal">75</span>
|
||
<span class="normal">76</span>
|
||
<span class="normal">77</span>
|
||
<span class="normal">78</span>
|
||
<span class="normal">79</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">Conv1d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">str</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Conv2d</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Applies a 1D convolution over an input signal composed of several input planes.</span>
|
||
|
||
<span class="sd"> See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d</span>
|
||
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> conv = nn.Conv1d(1, 1, 3)</span>
|
||
<span class="sd"> t = Tensor.rand(1, 1, 4)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> t = conv(t)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="p">(</span><span class="n">kernel_size</span><span class="p">,),</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.Conv2d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Conv2d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Conv2d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Conv2d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies a 2D convolution over an input signal composed of several input planes.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Conv2d">https://pytorch.org/docs/stable/generated/torch.nn.Conv2d</a></p>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="mf">0.4906</span> <span class="mf">0.2963</span> <span class="mf">0.0639</span> <span class="mf">0.0127</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.8699</span> <span class="mf">0.8575</span> <span class="mf">0.5628</span> <span class="mf">0.6926</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.7058</span> <span class="mf">0.1117</span> <span class="mf">0.634</span> <span class="mf">0.614</span> <span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.4943</span> <span class="mf">0.7191</span> <span class="mf">0.0912</span> <span class="mf">0.9734</span><span class="p">]]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="o">-</span><span class="mf">0.5142</span> <span class="o">-</span><span class="mf">0.5431</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.5355</span> <span class="o">-</span><span class="mf">0.1517</span><span class="p">]]]]</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal"> 97</span>
|
||
<span class="normal"> 98</span>
|
||
<span class="normal"> 99</span>
|
||
<span class="normal">100</span>
|
||
<span class="normal">101</span>
|
||
<span class="normal">102</span>
|
||
<span class="normal">103</span>
|
||
<span class="normal">104</span>
|
||
<span class="normal">105</span>
|
||
<span class="normal">106</span>
|
||
<span class="normal">107</span>
|
||
<span class="normal">108</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span><span class="o">|</span><span class="nb">str</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">padding</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="o">!=</span> <span class="s1">'same'</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid padding string </span><span class="si">{</span><span class="n">padding</span><span class="si">!r}</span><span class="s2">, only 'same' is supported"</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"padding='same' is not supported for strided convolutions"</span><span class="p">)</span>
|
||
<span class="n">pad</span> <span class="o">=</span> <span class="p">[(</span><span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">//</span><span class="mi">2</span><span class="p">,</span> <span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="n">d</span><span class="o">*</span><span class="p">(</span><span class="n">k</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">//</span><span class="mi">2</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span><span class="p">,</span><span class="n">k</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">make_tuple</span><span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">])]</span>
|
||
<span class="n">padding</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">flatten</span><span class="p">(</span><span class="n">pad</span><span class="p">))</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding</span> <span class="o">=</span> <span class="n">stride</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">padding</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">*</span> <span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">))</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">in_channels</span><span class="o">//</span><span class="n">groups</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.ConvTranspose1d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">ConvTranspose1d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.ConvTranspose1d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">ConvTranspose1d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">ConvTranspose2d</span> (<code>tinygrad.nn.ConvTranspose2d</code>)" href="#tinygrad.nn.ConvTranspose2d">ConvTranspose2d</a></span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Applies a 1D transposed convolution operator over an input signal composed of several input planes.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d">https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d</a></p>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose1d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="mf">0.832</span> <span class="mf">0.3402</span> <span class="mf">0.9718</span> <span class="mf">0.5209</span><span class="p">]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[</span><span class="o">-</span><span class="mf">0.7385</span> <span class="o">-</span><span class="mf">0.0549</span> <span class="o">-</span><span class="mf">0.1784</span> <span class="mf">0.1086</span> <span class="mf">0.5088</span> <span class="o">-</span><span class="mf">0.0056</span><span class="p">]]]</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">112</span>
|
||
<span class="normal">113</span>
|
||
<span class="normal">114</span>
|
||
<span class="normal">115</span>
|
||
<span class="normal">116</span>
|
||
<span class="normal">117</span>
|
||
<span class="normal">118</span>
|
||
<span class="normal">119</span>
|
||
<span class="normal">120</span>
|
||
<span class="normal">121</span>
|
||
<span class="normal">122</span>
|
||
<span class="normal">123</span>
|
||
<span class="normal">124</span>
|
||
<span class="normal">125</span>
|
||
<span class="normal">126</span>
|
||
<span class="normal">127</span>
|
||
<span class="normal">128</span>
|
||
<span class="normal">129</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">ConvTranspose1d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">ConvTranspose2d</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Applies a 1D transposed convolution operator over an input signal composed of several input planes.</span>
|
||
|
||
<span class="sd"> See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d</span>
|
||
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> conv = nn.ConvTranspose1d(1, 1, 3)</span>
|
||
<span class="sd"> t = Tensor.rand(1, 1, 4)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> t = conv(t)</span>
|
||
<span class="sd"> print(t.numpy())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="p">(</span><span class="n">kernel_size</span><span class="p">,),</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">output_padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.ConvTranspose2d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">ConvTranspose2d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.ConvTranspose2d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">ConvTranspose2d</span><span class="p">(</span>
|
||
<span class="n">in_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">out_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Conv2d</span> (<code>tinygrad.nn.Conv2d</code>)" href="#tinygrad.nn.Conv2d">Conv2d</a></code></p>
|
||
|
||
|
||
|
||
<p>Applies a 2D transposed convolution operator over an input image.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d">https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d</a></p>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span><span class="mf">0.1347</span> <span class="mf">0.9967</span> <span class="mf">0.0782</span> <span class="mf">0.1288</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.649</span> <span class="mf">0.9984</span> <span class="mf">0.827</span> <span class="mf">0.7346</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.5111</span> <span class="mf">0.4726</span> <span class="mf">0.7987</span> <span class="mf">0.8159</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.9239</span> <span class="mf">0.5857</span> <span class="mf">0.8851</span> <span class="mf">0.3674</span><span class="p">]]]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[[[</span> <span class="mf">0.281</span> <span class="mf">0.4304</span> <span class="mf">0.0086</span> <span class="mf">0.5784</span> <span class="mf">0.2373</span> <span class="mf">0.2929</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.3606</span> <span class="mf">0.0571</span> <span class="mf">0.5456</span> <span class="mf">0.2427</span> <span class="mf">0.3103</span> <span class="mf">0.4593</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.2271</span> <span class="mf">0.2836</span> <span class="mf">0.1235</span> <span class="mf">0.032</span> <span class="mf">0.2217</span> <span class="mf">0.3312</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.4558</span> <span class="mf">0.1647</span> <span class="mf">0.2354</span> <span class="o">-</span><span class="mf">0.0112</span> <span class="mf">0.1913</span> <span class="mf">0.0977</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.12</span> <span class="mf">0.3204</span> <span class="o">-</span><span class="mf">0.0401</span> <span class="mf">0.1801</span> <span class="o">-</span><span class="mf">0.1336</span> <span class="mf">0.0728</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.4357</span> <span class="mf">0.1603</span> <span class="mf">0.19</span> <span class="mf">0.0581</span> <span class="mf">0.0668</span> <span class="mf">0.209</span> <span class="p">]]]]</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">147</span>
|
||
<span class="normal">148</span>
|
||
<span class="normal">149</span>
|
||
<span class="normal">150</span>
|
||
<span class="normal">151</span>
|
||
<span class="normal">152</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">output_padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="p">,</span> <span class="n">dilation</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_channels</span> <span class="o">*</span> <span class="n">prod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">))</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">//</span><span class="n">groups</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">scale</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">output_padding</span> <span class="o">=</span> <span class="n">output_padding</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.Linear" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Linear</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Linear" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies a linear transformation to the incoming data.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Linear">https://pytorch.org/docs/stable/generated/torch.nn.Linear</a></p>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">lin</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="mf">0.6496</span> <span class="mf">0.6111</span> <span class="mf">0.5894</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.0162</span> <span class="mf">0.9957</span> <span class="mf">0.7491</span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">lin</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="o">-</span><span class="mf">0.7194</span> <span class="mf">0.1132</span> <span class="mf">0.416</span> <span class="o">-</span><span class="mf">0.1307</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.8916</span> <span class="mf">0.5362</span> <span class="mf">0.1431</span> <span class="mf">0.0856</span><span class="p">]]</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">173</span>
|
||
<span class="normal">174</span>
|
||
<span class="normal">175</span>
|
||
<span class="normal">176</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="n">bound</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">in_features</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">out_features</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">bound</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">bound</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.GroupNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">GroupNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.GroupNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">GroupNorm</span><span class="p">(</span>
|
||
<span class="n">num_groups</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">num_channels</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Group Normalization over a mini-batch of inputs.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1803.08494v3">https://arxiv.org/abs/1803.08494v3</a></li>
|
||
</ul>
|
||
<p><div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">12</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.9572356939315796</span> <span class="mf">0.5820862054824829</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">3.4862459585838224e-08</span> <span class="mf">1.0012898445129395</span>
|
||
</code></pre></div></p>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">196</span>
|
||
<span class="normal">197</span>
|
||
<span class="normal">198</span>
|
||
<span class="normal">199</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_groups</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">num_groups</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">,</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_channels</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_channels</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.InstanceNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">InstanceNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.InstanceNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">InstanceNorm</span><span class="p">(</span>
|
||
<span class="n">num_features</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Instance Normalization over a mini-batch of inputs.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1607.08022v3">https://arxiv.org/abs/1607.08022v3</a></li>
|
||
</ul>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">2.0937702655792236</span> <span class="mf">0.5715298056602478</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.1038510194794071e-07</span> <span class="mf">1.0052294731140137</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">226</span>
|
||
<span class="normal">227</span>
|
||
<span class="normal">228</span>
|
||
<span class="normal">229</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="k">if</span> <span class="n">affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.LayerNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LayerNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.LayerNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LayerNorm</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">elementwise_affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Layer Normalization over a mini-batch of inputs.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1607.06450v1">https://arxiv.org/abs/1607.06450v1</a></li>
|
||
</ul>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">2.0335121154785156</span> <span class="mf">0.6475386023521423</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">1.8065918538923142e-07</span> <span class="mf">1.0170753002166748</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">252</span>
|
||
<span class="normal">253</span>
|
||
<span class="normal">254</span>
|
||
<span class="normal">255</span>
|
||
<span class="normal">256</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">))),</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.LayerNorm2d" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LayerNorm2d</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.LayerNorm2d" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LayerNorm2d</span><span class="p">(</span>
|
||
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="o">...</span><span class="p">],</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#float">float</a></span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
|
||
<span class="n">elementwise_affine</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><a class="autorefs autorefs-internal" title="LayerNorm (tinygrad.nn.LayerNorm)" href="#tinygrad.nn.LayerNorm">LayerNorm</a></code></p>
|
||
|
||
|
||
|
||
<p>Applies Layer Normalization over a mini-batch of 2D inputs.</p>
|
||
<p>See: <a href="#tinygrad.nn.LayerNorm"><code class="language-python highlight"><span class="n">LayerNorm</span></code></a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm2d</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mf">1.9161213636398315</span> <span class="mf">0.5450614094734192</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="o">-</span><span class="mf">1.4681984339404153e-07</span> <span class="mf">1.005200982093811</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">252</span>
|
||
<span class="normal">253</span>
|
||
<span class="normal">254</span>
|
||
<span class="normal">255</span>
|
||
<span class="normal">256</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalized_shape</span><span class="p">:</span><span class="nb">int</span><span class="o">|</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">eps</span><span class="p">:</span><span class="nb">float</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">make_tuple</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">))),</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.RMSNorm" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">RMSNorm</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.RMSNorm" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">RMSNorm</span><span class="p">(</span><span class="n">dim</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>Applies Root Mean Square Normalization to the input.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1910.07467">https://arxiv.org/abs/1910.07467</a></li>
|
||
</ul>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span> <span class="mf">0.</span> <span class="mf">1.</span> <span class="mf">2.</span> <span class="mf">3.</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">4.</span> <span class="mf">5.</span> <span class="mf">6.</span> <span class="mf">7.</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">8.</span> <span class="mf">9.</span> <span class="mf">10.</span> <span class="mf">11.</span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">norm</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span><span class="mf">0.</span> <span class="mf">0.5345</span> <span class="mf">1.069</span> <span class="mf">1.6036</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.7127</span> <span class="mf">0.8909</span> <span class="mf">1.069</span> <span class="mf">1.2472</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="mf">0.8363</span> <span class="mf">0.9409</span> <span class="mf">1.0454</span> <span class="mf">1.15</span> <span class="p">]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">297</span>
|
||
<span class="normal">298</span>
|
||
<span class="normal">299</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">elementwise_affine</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">elementwise_affine</span> <span class="k">else</span> <span class="kc">None</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.Embedding" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Embedding</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.Embedding" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Embedding</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">embed_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>A simple lookup table that stores embeddings of a fixed dictionary and size.</p>
|
||
<p>See: <a href="https://pytorch.org/docs/stable/generated/torch.nn.Embedding">https://pytorch.org/docs/stable/generated/torch.nn.Embedding</a></p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">emb</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">emb</span><span class="p">(</span><span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="p">[[</span> <span class="mf">0.5594</span> <span class="mf">0.3912</span> <span class="mf">0.1444</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.072</span> <span class="mf">0.4649</span> <span class="mf">0.3679</span><span class="p">]</span>
|
||
<span class="p">[</span><span class="o">-</span><span class="mf">0.3457</span> <span class="o">-</span><span class="mf">0.2776</span> <span class="o">-</span><span class="mf">0.5988</span><span class="p">]</span>
|
||
<span class="p">[</span> <span class="mf">0.5594</span> <span class="mf">0.3912</span> <span class="mf">0.1444</span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">318</span>
|
||
<span class="normal">319</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">:</span><span class="nb">int</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_sz</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_sz</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">,</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">glorot_uniform</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_size</span><span class="p">)</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.LSTMCell" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LSTMCell</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.LSTMCell" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LSTMCell</span><span class="p">(</span>
|
||
<span class="n">input_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
|
||
|
||
<p>A long short-term memory (LSTM) cell.</p>
|
||
|
||
|
||
<p><span class="doc-section-title">Parameters:</span></p>
|
||
<ul>
|
||
<li class="doc-section-item field-body">
|
||
<b><code>input_size</code></b>
|
||
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></code>)
|
||
–
|
||
<div class="doc-md-description">
|
||
<p>The number of expected features in the input <code class="language-python highlight"><span class="n">x</span></code></p>
|
||
</div>
|
||
</li>
|
||
<li class="doc-section-item field-body">
|
||
<b><code>hidden_size</code></b>
|
||
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#int">int</a></code>)
|
||
–
|
||
<div class="doc-md-description">
|
||
<p>The number of features in the hidden state <code class="language-python highlight"><span class="n">h</span></code></p>
|
||
</div>
|
||
</li>
|
||
<li class="doc-section-item field-body">
|
||
<b><code>bias</code></b>
|
||
(<code><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/functions.html#bool">bool</a></code>, default:
|
||
<code>True</code>
|
||
)
|
||
–
|
||
<div class="doc-md-description">
|
||
<p>If <code class="language-python highlight"><span class="kc">False</span></code>, then the layer does not use bias weights <code class="language-python highlight"><span class="n">b_ih</span></code> and <code class="language-python highlight"><span class="n">b_hh</span></code></p>
|
||
</div>
|
||
</li>
|
||
</ul>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/__init__.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">337</span>
<span class="normal">338</span>
<span class="normal">339</span>
<span class="normal">340</span>
<span class="normal">341</span>
<span class="normal">342</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span><span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="p">:</span><span class="nb">bool</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="n">stdv</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight_ih</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">stdv</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">stdv</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight_hh</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">low</span><span class="o">=-</span><span class="n">stdv</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">stdv</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias_ih</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias_hh</span><span class="p">:</span> <span class="n">Tensor</span><span class="o">|</span><span class="kc">None</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">*</span><span class="mi">4</span><span class="p">)</span> <span class="k">if</span> <span class="n">bias</span> <span class="k">else</span> <span class="kc">None</span>
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div><h2 id="optimizers">Optimizers<a class="headerlink" href="#optimizers" title="Permanent link">¤</a></h2>
|
||
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.SGD" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">SGD</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.SGD" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">SGD</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">classic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.</p>
|
||
<p><code class="language-python highlight"><span class="n">classic</span></code> is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.</p>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">77</span>
<span class="normal">78</span>
<span class="normal">79</span>
<span class="normal">80</span>
<span class="normal">81</span>
<span class="normal">82</span>
<span class="normal">83</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">SGD</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="w"> </span><span class="sd">"""</span>
<span class="sd"> Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.</span>

<span class="sd"> `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="n">LARS</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">nesterov</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="n">classic</span><span class="p">,</span> <span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">tcoef</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.LARS" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LARS</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.LARS" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LARS</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">,</span>
<span class="n">ns_steps</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">ns_coefficients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">classic</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">tcoef</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><span title="tinygrad.nn.optim.Optimizer">Optimizer</span></code></p>
|
||
|
||
|
||
|
||
<p>Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1708.03888v3">https://arxiv.org/abs/1708.03888v3</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span><span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span> <span class="n">ns_steps</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">ns_coefficients</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">nesterov</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">classic</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pre_wd</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">tcoef</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">fused</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">momentum</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ns_steps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ns_coefficients</span> <span class="o">=</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">ns_steps</span><span class="p">,</span> <span class="n">ns_coefficients</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nesterov</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classic</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tcoef</span> <span class="o">=</span> <span class="n">nesterov</span><span class="p">,</span> <span class="n">classic</span><span class="p">,</span> <span class="n">pre_wd</span><span class="p">,</span> <span class="n">tcoef</span>
<span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="k">else</span> <span class="p">[]</span>
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.AdamW" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">AdamW</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.AdamW" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">AdamW</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-08</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>AdamW optimizer with optional weight decay.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1711.05101v3">https://arxiv.org/abs/1711.05101v3</a></li>
|
||
</ul>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">133</span>
<span class="normal">134</span>
<span class="normal">135</span>
<span class="normal">136</span>
<span class="normal">137</span>
<span class="normal">138</span>
<span class="normal">139</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">AdamW</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="w"> </span><span class="sd">"""</span>
<span class="sd"> AdamW optimizer with optional weight decay.</span>

<span class="sd"> - Paper: https://arxiv.org/abs/1711.05101v3</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="n">LAMB</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.Adam" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">Adam</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.Adam" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">Adam</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-08</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Adam optimizer.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1412.6980">https://arxiv.org/abs/1412.6980</a></li>
|
||
</ul>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">140</span>
<span class="normal">141</span>
<span class="normal">142</span>
<span class="normal">143</span>
<span class="normal">144</span>
<span class="normal">145</span>
<span class="normal">146</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">Adam</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="w"> </span><span class="sd">"""</span>
<span class="sd"> Adam optimizer.</span>

<span class="sd"> - Paper: https://arxiv.org/abs/1412.6980</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="n">LAMB</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">fused</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-class">
|
||
|
||
|
||
|
||
<h3 id="tinygrad.nn.optim.LAMB" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">LAMB</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.optim.LAMB" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">LAMB</span><span class="p">(</span>
<span class="n">params</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
<span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span>
<span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="mf">1e-06</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">adam</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">fused</span><span class="o">=</span><span class="n"><span title="tinygrad.helpers.FUSE_OPTIM">FUSE_OPTIM</span></span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
<p class="doc doc-class-bases">
|
||
Bases: <code><span title="tinygrad.nn.optim.Optimizer">Optimizer</span></code></p>
|
||
|
||
|
||
|
||
<p>LAMB optimizer with optional weight decay.</p>
|
||
<ul>
|
||
<li>Paper: <a href="https://arxiv.org/abs/1904.00962">https://arxiv.org/abs/1904.00962</a></li>
|
||
</ul>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/optim.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span>
<span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">b1</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">b2</span><span class="o">=</span><span class="mf">0.999</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">adam</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">FUSE_OPTIM</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">fused</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">b1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">b2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">wd</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">adam</span> <span class="o">=</span> <span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">adam</span>
<span class="bp">self</span><span class="o">.</span><span class="n">b1_t</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">b2_t</span> <span class="o">=</span> <span class="p">(</span><span class="n">Tensor</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="p">[</span><span class="n">b1</span><span class="p">,</span> <span class="n">b2</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">m</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_optim_param</span><span class="p">()</span>
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
|
||
|
||
|
||
<div class="doc doc-children">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div><h2 id="loadsave">Load/Save<a class="headerlink" href="#loadsave" title="Permanent link">¤</a></h2>
|
||
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.safe_load" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">safe_load</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.safe_load" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">safe_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">|</span> <span class="n"><a class="autorefs autorefs-external" title="pathlib.Path" href="https://docs.python.org/3/library/pathlib.html#pathlib.Path">Path</a></span><span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Loads a .safetensor file, returning the <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">safe_load</span><span class="p">(</span><span class="s2">"test.safetensor"</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">50</span>
|
||
<span class="normal">51</span>
|
||
<span class="normal">52</span>
|
||
<span class="normal">53</span>
|
||
<span class="normal">54</span>
|
||
<span class="normal">55</span>
|
||
<span class="normal">56</span>
|
||
<span class="normal">57</span>
|
||
<span class="normal">58</span>
|
||
<span class="normal">59</span>
|
||
<span class="normal">60</span>
|
||
<span class="normal">61</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">safe_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span><span class="n">Tensor</span><span class="o">|</span><span class="nb">str</span><span class="o">|</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Loads a .safetensor file, returning the `state_dict`.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> state_dict = nn.state.safe_load("test.safetensor")</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">t</span><span class="p">,</span> <span class="n">data_start</span><span class="p">,</span> <span class="n">metadata</span> <span class="o">=</span> <span class="n">safe_load_metadata</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span>
|
||
<span class="n">data</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">data_start</span><span class="p">:]</span>
|
||
<span class="k">return</span> <span class="p">{</span> <span class="n">k</span><span class="p">:</span> <span class="n">data</span><span class="p">[</span><span class="n">v</span><span class="p">[</span><span class="s1">'data_offsets'</span><span class="p">][</span><span class="mi">0</span><span class="p">]:</span><span class="n">v</span><span class="p">[</span><span class="s1">'data_offsets'</span><span class="p">][</span><span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">safe_dtypes</span><span class="p">[</span><span class="n">v</span><span class="p">[</span><span class="s1">'dtype'</span><span class="p">]])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="p">[</span><span class="s1">'shape'</span><span class="p">])</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">metadata</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">k</span> <span class="o">!=</span> <span class="s2">"__metadata__"</span> <span class="p">}</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.safe_save" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">safe_save</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.safe_save" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">safe_save</span><span class="p">(</span>
|
||
<span class="n">tensors</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">fn</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span>
|
||
<span class="n">metadata</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-external" title="typing.Any" href="https://docs.python.org/3/library/typing.html#typing.Any">Any</a></span><span class="p">]</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Saves a <code class="language-python highlight"><span class="n">state_dict</span></code> to disk in a .safetensor file with optional metadata.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
|
||
<span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">safe_save</span><span class="p">({</span><span class="s1">'t'</span><span class="p">:</span><span class="n">t</span><span class="p">},</span> <span class="s2">"test.safetensor"</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">63</span>
|
||
<span class="normal">64</span>
|
||
<span class="normal">65</span>
|
||
<span class="normal">66</span>
|
||
<span class="normal">67</span>
|
||
<span class="normal">68</span>
|
||
<span class="normal">69</span>
|
||
<span class="normal">70</span>
|
||
<span class="normal">71</span>
|
||
<span class="normal">72</span>
|
||
<span class="normal">73</span>
|
||
<span class="normal">74</span>
|
||
<span class="normal">75</span>
|
||
<span class="normal">76</span>
|
||
<span class="normal">77</span>
|
||
<span class="normal">78</span>
|
||
<span class="normal">79</span>
|
||
<span class="normal">80</span>
|
||
<span class="normal">81</span>
|
||
<span class="normal">82</span>
|
||
<span class="normal">83</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">safe_save</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">fn</span><span class="p">:</span><span class="nb">str</span><span class="p">,</span> <span class="n">metadata</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span><span class="o">|</span><span class="kc">None</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Saves a `state_dict` to disk in a .safetensor file with optional metadata.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> t = Tensor([1, 2, 3])</span>
|
||
<span class="sd"> nn.state.safe_save({'t':t}, "test.safetensor")</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">headers</span><span class="p">,</span> <span class="n">offset</span> <span class="o">=</span> <span class="p">{},</span> <span class="mi">0</span>
|
||
<span class="k">if</span> <span class="n">metadata</span><span class="p">:</span> <span class="n">headers</span><span class="p">[</span><span class="s1">'__metadata__'</span><span class="p">]</span> <span class="o">=</span> <span class="n">metadata</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">tensors</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="n">headers</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'dtype'</span><span class="p">:</span> <span class="n">inverse_safe_dtypes</span><span class="p">[</span><span class="n">v</span><span class="o">.</span><span class="n">dtype</span><span class="p">],</span> <span class="s1">'shape'</span><span class="p">:</span> <span class="nb">list</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="s1">'data_offsets'</span><span class="p">:[</span><span class="n">offset</span><span class="p">,</span> <span class="n">offset</span><span class="o">+</span><span class="n">v</span><span class="o">.</span><span class="n">nbytes</span><span class="p">()]}</span>
|
||
<span class="n">offset</span> <span class="o">+=</span> <span class="n">v</span><span class="o">.</span><span class="n">nbytes</span><span class="p">()</span>
|
||
<span class="n">j</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="n">headers</span><span class="p">,</span> <span class="n">separators</span><span class="o">=</span><span class="p">(</span><span class="s1">','</span><span class="p">,</span> <span class="s1">':'</span><span class="p">))</span>
|
||
<span class="n">j</span> <span class="o">+=</span> <span class="s2">"</span><span class="se">\x20</span><span class="s2">"</span><span class="o">*</span><span class="p">(</span><span class="n">round_up</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">),</span><span class="mi">8</span><span class="p">)</span><span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">))</span>
|
||
<span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">fn</span><span class="p">)</span><span class="o">.</span><span class="n">unlink</span><span class="p">(</span><span class="n">missing_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">t</span> <span class="o">=</span> <span class="n">Tensor</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">8</span><span class="o">+</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)</span><span class="o">+</span><span class="n">offset</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtypes</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="sa">f</span><span class="s2">"disk:</span><span class="si">{</span><span class="n">fn</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">8</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">dtypes</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span><span class="o">.</span><span class="n">assign</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)])</span>
|
||
<span class="n">t</span><span class="p">[</span><span class="mi">8</span><span class="p">:</span><span class="mi">8</span><span class="o">+</span><span class="nb">len</span><span class="p">(</span><span class="n">j</span><span class="p">)]</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">j</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">'utf-8'</span><span class="p">)))</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">safe_load</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="n">v</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">tensors</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.get_state_dict" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">get_state_dict</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.get_state_dict" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">get_state_dict</span><span class="p">(</span>
|
||
<span class="n">obj</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span> <span class="o">=</span> <span class="s2">""</span><span class="p">,</span> <span class="n">tensor_type</span><span class="o">=</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Returns a <code class="language-python highlight"><span class="n">state_dict</span></code> of the object, with optional prefix.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||
|
||
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">dict_keys</span><span class="p">([</span><span class="s1">'l1.weight'</span><span class="p">,</span> <span class="s1">'l1.bias'</span><span class="p">,</span> <span class="s1">'l2.weight'</span><span class="p">,</span> <span class="s1">'l2.bias'</span><span class="p">])</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal"> 87</span>
|
||
<span class="normal"> 88</span>
|
||
<span class="normal"> 89</span>
|
||
<span class="normal"> 90</span>
|
||
<span class="normal"> 91</span>
|
||
<span class="normal"> 92</span>
|
||
<span class="normal"> 93</span>
|
||
<span class="normal"> 94</span>
|
||
<span class="normal"> 95</span>
|
||
<span class="normal"> 96</span>
|
||
<span class="normal"> 97</span>
|
||
<span class="normal"> 98</span>
|
||
<span class="normal"> 99</span>
|
||
<span class="normal">100</span>
|
||
<span class="normal">101</span>
|
||
<span class="normal">102</span>
|
||
<span class="normal">103</span>
|
||
<span class="normal">104</span>
|
||
<span class="normal">105</span>
|
||
<span class="normal">106</span>
|
||
<span class="normal">107</span>
|
||
<span class="normal">108</span>
|
||
<span class="normal">109</span>
|
||
<span class="normal">110</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">prefix</span><span class="p">:</span><span class="nb">str</span><span class="o">=</span><span class="s1">''</span><span class="p">,</span> <span class="n">tensor_type</span><span class="o">=</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Returns a `state_dict` of the object, with optional prefix.</span>
|
||
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> class Net:</span>
|
||
<span class="sd"> def __init__(self):</span>
|
||
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
|
||
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
|
||
|
||
<span class="sd"> net = Net()</span>
|
||
<span class="sd"> print(nn.state.get_state_dict(net).keys())</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">):</span> <span class="k">return</span> <span class="p">{</span><span class="n">prefix</span><span class="o">.</span><span class="n">strip</span><span class="p">(</span><span class="s1">'.'</span><span class="p">):</span><span class="n">obj</span><span class="p">}</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">'_asdict'</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">_asdict</span><span class="p">(),</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span> <span class="c1"># namedtuple</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">OrderedDict</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">obj</span><span class="p">),</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">'__dict__'</span><span class="p">):</span> <span class="k">return</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">prefix</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">)</span>
|
||
<span class="n">state_dict</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span><span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">obj</span><span class="p">):</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">prefix</span><span class="si">}{</span><span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)</span><span class="si">}</span><span class="s2">."</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">obj</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">prefix</span><span class="si">}{</span><span class="nb">str</span><span class="p">(</span><span class="n">k</span><span class="p">)</span><span class="si">}</span><span class="s2">."</span><span class="p">,</span> <span class="n">tensor_type</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">state_dict</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.get_parameters" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">get_parameters</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.get_parameters" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">get_parameters</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||
|
||
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_parameters</span><span class="p">(</span><span class="n">net</span><span class="p">)))</span>
|
||
</code></pre></div>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="mi">4</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">112</span>
|
||
<span class="normal">113</span>
|
||
<span class="normal">114</span>
|
||
<span class="normal">115</span>
|
||
<span class="normal">116</span>
|
||
<span class="normal">117</span>
|
||
<span class="normal">118</span>
|
||
<span class="normal">119</span>
|
||
<span class="normal">120</span>
|
||
<span class="normal">121</span>
|
||
<span class="normal">122</span>
|
||
<span class="normal">123</span>
|
||
<span class="normal">124</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">get_parameters</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span> <span class="o">-></span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> ```python exec="true" source="above" session="tensor" result="python"</span>
|
||
<span class="sd"> class Net:</span>
|
||
<span class="sd"> def __init__(self):</span>
|
||
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
|
||
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
|
||
|
||
<span class="sd"> net = Net()</span>
|
||
<span class="sd"> print(len(nn.state.get_parameters(net)))</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.load_state_dict" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">load_state_dict</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.load_state_dict" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">load_state_dict</span><span class="p">(</span>
|
||
<span class="n">model</span><span class="p">,</span>
|
||
<span class="n">state_dict</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">],</span>
|
||
<span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">consume</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">realize</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#list">list</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-internal" title="Tensor (tinygrad.tensor.Tensor)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Loads a <code class="language-python highlight"><span class="n">state_dict</span></code> into a model. Returns the loaded Tensors.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Net</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
|
||
|
||
<span class="n">net</span> <span class="o">=</span> <span class="n">Net</span><span class="p">()</span>
|
||
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">get_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">)</span>
|
||
<span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">126</span>
|
||
<span class="normal">127</span>
|
||
<span class="normal">128</span>
|
||
<span class="normal">129</span>
|
||
<span class="normal">130</span>
|
||
<span class="normal">131</span>
|
||
<span class="normal">132</span>
|
||
<span class="normal">133</span>
|
||
<span class="normal">134</span>
|
||
<span class="normal">135</span>
|
||
<span class="normal">136</span>
|
||
<span class="normal">137</span>
|
||
<span class="normal">138</span>
|
||
<span class="normal">139</span>
|
||
<span class="normal">140</span>
|
||
<span class="normal">141</span>
|
||
<span class="normal">142</span>
|
||
<span class="normal">143</span>
|
||
<span class="normal">144</span>
|
||
<span class="normal">145</span>
|
||
<span class="normal">146</span>
|
||
<span class="normal">147</span>
|
||
<span class="normal">148</span>
|
||
<span class="normal">149</span>
|
||
<span class="normal">150</span>
|
||
<span class="normal">151</span>
|
||
<span class="normal">152</span>
|
||
<span class="normal">153</span>
|
||
<span class="normal">154</span>
|
||
<span class="normal">155</span>
|
||
<span class="normal">156</span>
|
||
<span class="normal">157</span>
|
||
<span class="normal">158</span>
|
||
<span class="normal">159</span>
|
||
<span class="normal">160</span>
|
||
<span class="normal">161</span>
|
||
<span class="normal">162</span>
|
||
<span class="normal">163</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">load_state_dict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">:</span><span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">consume</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">realize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Loads a `state_dict` into a model. Return the loaded Tensors.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> class Net:</span>
|
||
<span class="sd"> def __init__(self):</span>
|
||
<span class="sd"> self.l1 = nn.Linear(4, 5)</span>
|
||
<span class="sd"> self.l2 = nn.Linear(5, 6)</span>
|
||
|
||
<span class="sd"> net = Net()</span>
|
||
<span class="sd"> state_dict = nn.state.get_state_dict(net)</span>
|
||
<span class="sd"> nn.state.load_state_dict(net, state_dict)</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">start_mem_used</span> <span class="o">=</span> <span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span>
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">with</span> <span class="n">Timing</span><span class="p">(</span><span class="s2">"loaded weights in "</span><span class="p">,</span>
|
||
<span class="k">lambda</span> <span class="n">et_ns</span><span class="p">:</span> <span class="sa">f</span><span class="s2">", </span><span class="si">{</span><span class="p">(</span><span class="n">B</span><span class="o">:=</span><span class="p">(</span><span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span><span class="o">-</span><span class="n">start_mem_used</span><span class="p">))</span><span class="o">/</span><span class="mf">1e9</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> GB loaded at </span><span class="si">{</span><span class="n">B</span><span class="o">/</span><span class="n">et_ns</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> GB/s"</span><span class="p">,</span> <span class="n">enabled</span><span class="o">=</span><span class="n">verbose</span><span class="p">):</span>
|
||
<span class="n">model_state_dict</span> <span class="o">=</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span> <span class="o">></span> <span class="nb">len</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">):</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="s2">"WARNING: unused weights in state_dict"</span><span class="p">,</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="o">-</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())))</span>
|
||
<span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="p">(</span><span class="n">t</span> <span class="o">:=</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">disable</span><span class="o">=</span><span class="n">CI</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">verbose</span><span class="p">)):</span>
|
||
<span class="n">t</span><span class="o">.</span><span class="n">desc</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"ram used: </span><span class="si">{</span><span class="n">GlobalCounters</span><span class="o">.</span><span class="n">mem_used</span><span class="o">/</span><span class="mf">1e9</span><span class="si">:</span><span class="s2">5.2f</span><span class="si">}</span><span class="s2"> GB, </span><span class="si">{</span><span class="n">k</span><span class="si">:</span><span class="s2">50s</span><span class="si">}</span><span class="s2">: "</span>
|
||
<span class="k">if</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">state_dict</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">strict</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"WARNING: not loading </span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">continue</span>
|
||
<span class="k">if</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="p">{(),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)}</span> <span class="o">==</span> <span class="p">{</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">}:</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Shape mismatch in layer `</span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s1">`: Expected shape </span><span class="si">{</span><span class="n">v</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">, but found </span><span class="si">{</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1"> in state dict.'</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
|
||
<span class="k">else</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shard</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">v</span><span class="o">.</span><span class="n">uop</span><span class="o">.</span><span class="n">axis</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">realize</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">realize</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">consume</span><span class="p">:</span> <span class="k">del</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
|
||
<span class="n">ret</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">ret</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.tar_extract" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <code class="highlight language-python"><span class="n">tar_extract</span></code>
|
||
|
||
<a href="#tinygrad.nn.state.tar_extract" class="headerlink" title="Permanent link">¤</a></h3>
|
||
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">tar_extract</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span>
|
||
</code></pre></div>
|
||
<p>Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">tensors</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">tar_extract</span><span class="p">(</span><span class="n">Tensor</span><span class="p">(</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="s2">"archive.tar"</span><span class="p">)))</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">186</span>
|
||
<span class="normal">187</span>
|
||
<span class="normal">188</span>
|
||
<span class="normal">189</span>
|
||
<span class="normal">190</span>
|
||
<span class="normal">191</span>
|
||
<span class="normal">192</span>
|
||
<span class="normal">193</span>
|
||
<span class="normal">194</span>
|
||
<span class="normal">195</span>
|
||
<span class="normal">196</span>
|
||
<span class="normal">197</span>
|
||
<span class="normal">198</span>
|
||
<span class="normal">199</span>
|
||
<span class="normal">200</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">tar_extract</span><span class="p">(</span><span class="n">t</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]</span>
|
||
<span class="sd"> ```</span>
|
||
|
||
<span class="sd"> Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">fileobj</span><span class="o">=</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">t</span><span class="p">),</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">tar</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="p">{</span><span class="n">member</span><span class="o">.</span><span class="n">name</span><span class="p">:</span><span class="n">t</span><span class="p">[</span><span class="n">member</span><span class="o">.</span><span class="n">offset_data</span><span class="p">:</span><span class="n">member</span><span class="o">.</span><span class="n">offset_data</span><span class="o">+</span><span class="n">member</span><span class="o">.</span><span class="n">size</span><span class="p">]</span> <span class="k">for</span> <span class="n">member</span> <span class="ow">in</span> <span class="n">tar</span> <span class="k">if</span> <span class="n">member</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">REGTYPE</span><span class="p">}</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.torch_load" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <code class="highlight language-python"><span class="n">torch_load</span></code>
|
||
|
||
<a href="#tinygrad.nn.state.torch_load" class="headerlink" title="Permanent link">¤</a></h3>
|
||
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">torch_load</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span>
|
||
</code></pre></div>
|
||
<p>Loads a torch .pth file, returning the <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">torch_load</span><span class="p">(</span><span class="s2">"test.pth"</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">205</span>
|
||
<span class="normal">206</span>
|
||
<span class="normal">207</span>
|
||
<span class="normal">208</span>
|
||
<span class="normal">209</span>
|
||
<span class="normal">210</span>
|
||
<span class="normal">211</span>
|
||
<span class="normal">212</span>
|
||
<span class="normal">213</span>
|
||
<span class="normal">214</span>
|
||
<span class="normal">215</span>
|
||
<span class="normal">216</span>
|
||
<span class="normal">217</span>
|
||
<span class="normal">218</span>
|
||
<span class="normal">219</span>
|
||
<span class="normal">220</span>
|
||
<span class="normal">221</span>
|
||
<span class="normal">222</span>
|
||
<span class="normal">223</span>
|
||
<span class="normal">224</span>
|
||
<span class="normal">225</span>
|
||
<span class="normal">226</span>
|
||
<span class="normal">227</span>
|
||
<span class="normal">228</span>
|
||
<span class="normal">229</span>
|
||
<span class="normal">230</span>
|
||
<span class="normal">231</span>
|
||
<span class="normal">232</span>
|
||
<span class="normal">233</span>
|
||
<span class="normal">234</span>
|
||
<span class="normal">235</span>
|
||
<span class="normal">236</span>
|
||
<span class="normal">237</span>
|
||
<span class="normal">238</span>
|
||
<span class="normal">239</span>
|
||
<span class="normal">240</span>
|
||
<span class="normal">241</span>
|
||
<span class="normal">242</span>
|
||
<span class="normal">243</span>
|
||
<span class="normal">244</span>
|
||
<span class="normal">245</span>
|
||
<span class="normal">246</span>
|
||
<span class="normal">247</span>
|
||
<span class="normal">248</span>
|
||
<span class="normal">249</span>
|
||
<span class="normal">250</span>
|
||
<span class="normal">251</span>
|
||
<span class="normal">252</span>
|
||
<span class="normal">253</span>
|
||
<span class="normal">254</span>
|
||
<span class="normal">255</span>
|
||
<span class="normal">256</span>
|
||
<span class="normal">257</span>
|
||
<span class="normal">258</span>
|
||
<span class="normal">259</span>
|
||
<span class="normal">260</span>
|
||
<span class="normal">261</span>
|
||
<span class="normal">262</span>
|
||
<span class="normal">263</span>
|
||
<span class="normal">264</span>
|
||
<span class="normal">265</span>
|
||
<span class="normal">266</span>
|
||
<span class="normal">267</span>
|
||
<span class="normal">268</span>
|
||
<span class="normal">269</span>
|
||
<span class="normal">270</span>
|
||
<span class="normal">271</span>
|
||
<span class="normal">272</span>
|
||
<span class="normal">273</span>
|
||
<span class="normal">274</span>
|
||
<span class="normal">275</span>
|
||
<span class="normal">276</span>
|
||
<span class="normal">277</span>
|
||
<span class="normal">278</span>
|
||
<span class="normal">279</span>
|
||
<span class="normal">280</span>
|
||
<span class="normal">281</span>
|
||
<span class="normal">282</span>
|
||
<span class="normal">283</span>
|
||
<span class="normal">284</span>
|
||
<span class="normal">285</span>
|
||
<span class="normal">286</span>
|
||
<span class="normal">287</span>
|
||
<span class="normal">288</span>
|
||
<span class="normal">289</span>
|
||
<span class="normal">290</span>
|
||
<span class="normal">291</span>
|
||
<span class="normal">292</span>
|
||
<span class="normal">293</span>
|
||
<span class="normal">294</span>
|
||
<span class="normal">295</span>
|
||
<span class="normal">296</span>
|
||
<span class="normal">297</span>
|
||
<span class="normal">298</span>
|
||
<span class="normal">299</span>
|
||
<span class="normal">300</span>
|
||
<span class="normal">301</span>
|
||
<span class="normal">302</span>
|
||
<span class="normal">303</span>
|
||
<span class="normal">304</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">torch_load</span><span class="p">(</span><span class="n">t</span><span class="p">:</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]</span>
|
||
<span class="sd"> ```</span>
|
||
|
||
<span class="sd"> Loads a torch .pth file, returning the `state_dict`.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> state_dict = nn.state.torch_load("test.pth")</span>
|
||
<span class="sd"> ```</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">offsets</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="o">|</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="n">lens</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="o">|</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_rebuild_tensor</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_rebuild_tensor_v2</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_rebuild_tensor_v2</span><span class="p">(</span><span class="n">storage</span><span class="p">,</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">backward_hooks</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metadata</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="c1">#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)</span>
|
||
<span class="n">lens</span><span class="p">[</span><span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">storage</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">*</span> <span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span>
|
||
<span class="k">if</span> <span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">offsets</span><span class="p">:</span> <span class="k">return</span> <span class="kc">None</span>
|
||
<span class="n">byte_offset</span> <span class="o">=</span> <span class="n">offsets</span><span class="p">[</span><span class="n">storage</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span><span class="o">+</span><span class="n">storage_offset</span><span class="o">*</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span>
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">byte_offset</span><span class="p">:</span><span class="n">byte_offset</span><span class="o">+</span><span class="n">prod</span><span class="p">(</span><span class="n">size</span><span class="p">)</span><span class="o">*</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">itemsize</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="c1"># 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk</span>
|
||
<span class="n">shape_strides</span> <span class="o">=</span> <span class="p">[(</span><span class="n">s</span><span class="p">,</span> <span class="n">st</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span><span class="p">,</span><span class="n">st</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span> <span class="k">if</span> <span class="n">s</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">]</span>
|
||
<span class="n">permute_indexes</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">shape_strides</span><span class="p">)</span><span class="o">-</span><span class="mi">1</span><span class="o">-</span><span class="n">y</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">([</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">shape_strides</span><span class="p">])]</span>
|
||
<span class="k">if</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">))):</span>
|
||
<span class="n">intermediate_shape</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">shape_strides</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)])</span>
|
||
<span class="k">assert</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">shape_strides</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">argsort</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)])</span> <span class="o">==</span> <span class="n">strides_for_shape</span><span class="p">(</span><span class="n">intermediate_shape</span><span class="p">),</span> <span class="s2">"nonpermutable strides"</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">3</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"WARNING: this torch load is slow. to permute </span><span class="si">{</span><span class="n">intermediate_shape</span><span class="si">}</span><span class="s2"> with </span><span class="si">{</span><span class="n">permute_indexes</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">storage</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">,</span> <span class="s2">"can't permute BF16"</span>
|
||
<span class="c1"># TODO: find a nice way to support all movement ops on disktensors</span>
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="n">ret</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">intermediate_shape</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="n">permute_indexes</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">ret</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">Parameter</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">tensor</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||
|
||
<span class="n">deserialized_objects</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="n">intercept</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"HalfStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="s2">"FloatStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="s2">"BFloat16Storage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">,</span>
|
||
<span class="s2">"IntStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="s2">"BoolStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
|
||
<span class="s2">"LongStorage"</span><span class="p">:</span> <span class="n">dtypes</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="s2">"_rebuild_tensor"</span><span class="p">:</span> <span class="n">_rebuild_tensor</span><span class="p">,</span> <span class="s2">"_rebuild_tensor_v2"</span><span class="p">:</span> <span class="n">_rebuild_tensor_v2</span><span class="p">,</span>
|
||
<span class="s2">"FloatTensor"</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Parameter"</span><span class="p">:</span> <span class="n">Parameter</span><span class="p">}</span>
|
||
<span class="n">whitelist</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"torch"</span><span class="p">,</span> <span class="s2">"collections"</span><span class="p">,</span> <span class="s2">"numpy"</span><span class="p">,</span> <span class="s2">"_codecs"</span><span class="p">}</span> <span class="c1"># NOTE: this is not for security, only speed</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">Dummy</span><span class="p">:</span> <span class="k">pass</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">TorchPickle</span><span class="p">(</span><span class="n">pickle</span><span class="o">.</span><span class="n">Unpickler</span><span class="p">):</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">find_class</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||
<span class="n">module_root</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"."</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">module_root</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">whitelist</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">DEBUG</span> <span class="o">>=</span> <span class="mi">2</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"WARNING: returning Dummy for </span><span class="si">{</span><span class="n">module</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">Dummy</span>
|
||
<span class="k">return</span> <span class="n">intercept</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="k">if</span> <span class="n">module_root</span> <span class="o">==</span> <span class="s2">"torch"</span> <span class="k">else</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">find_class</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">persistent_load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pid</span><span class="p">):</span> <span class="k">return</span> <span class="n">deserialized_objects</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">pid</span><span class="p">,</span> <span class="n">pid</span><span class="p">)</span>
|
||
|
||
<span class="n">fobj</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">t</span><span class="p">))</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">passthrough_reset</span><span class="p">(</span><span class="n">v</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="k">return</span> <span class="n">fobj</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">or</span> <span class="n">v</span>
|
||
|
||
<span class="k">if</span> <span class="n">passthrough_reset</span><span class="p">(</span><span class="n">zipfile</span><span class="o">.</span><span class="n">is_zipfile</span><span class="p">(</span><span class="n">fobj</span><span class="p">)):</span> <span class="c1"># NOTE: passthrough_reset required to support python < 3.14</span>
|
||
<span class="n">myzip</span> <span class="o">=</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">fobj</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span>
|
||
<span class="n">base_name</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="n">header_offsets</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="k">for</span> <span class="n">zi</span> <span class="ow">in</span> <span class="n">myzip</span><span class="o">.</span><span class="n">filelist</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">base_name</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">base_name</span> <span class="o">=</span> <span class="n">zi</span><span class="o">.</span><span class="n">filename</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'/'</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">zi</span><span class="o">.</span><span class="n">filename</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">base_name</span><span class="si">}</span><span class="s1">/data/'</span><span class="p">):</span> <span class="n">header_offsets</span><span class="p">[</span><span class="n">zi</span><span class="o">.</span><span class="n">filename</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">zi</span><span class="o">.</span><span class="n">header_offset</span>
|
||
<span class="c1"># sadly there's no way to get the start of the file in the zip without reading the header</span>
|
||
<span class="c1"># at least here we read them in parallel</span>
|
||
<span class="n">header_contents</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="p">[</span><span class="n">v</span><span class="o">+</span><span class="mi">26</span><span class="p">:</span><span class="n">v</span><span class="o">+</span><span class="mi">30</span><span class="p">]</span><span class="o">.</span><span class="n">bitcast</span><span class="p">(</span><span class="n">dtypes</span><span class="o">.</span><span class="n">uint16</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'CPU'</span><span class="p">)</span> <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">header_offsets</span><span class="o">.</span><span class="n">values</span><span class="p">()]</span>
|
||
<span class="n">Tensor</span><span class="o">.</span><span class="n">realize</span><span class="p">(</span><span class="o">*</span><span class="n">header_contents</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span><span class="n">o</span><span class="p">),</span><span class="n">c</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">header_offsets</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">header_contents</span><span class="p">):</span>
|
||
<span class="c1"># header_offset + sizeFileHeader + File name length + Extra field length : https://en.wikipedia.org/wiki/ZIP_(file_format)</span>
|
||
<span class="n">offsets</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">o</span><span class="o">+</span><span class="mi">30</span><span class="o">+</span><span class="nb">sum</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="nb">list</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">c</span><span class="o">.</span><span class="n">tolist</span><span class="p">()))</span>
|
||
<span class="k">with</span> <span class="n">myzip</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">base_name</span><span class="si">}</span><span class="s1">/data.pkl'</span><span class="p">)</span> <span class="k">as</span> <span class="n">myfile</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">myfile</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()</span>
|
||
<span class="k">elif</span> <span class="n">passthrough_reset</span><span class="p">(</span><span class="n">tarfile</span><span class="o">.</span><span class="n">is_tarfile</span><span class="p">(</span><span class="n">fobj</span><span class="p">)):</span> <span class="c1"># NOTE: passthrough_reset required to support python < 3.11</span>
|
||
<span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">fileobj</span><span class="o">=</span><span class="n">fobj</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">tar</span><span class="p">:</span>
|
||
<span class="n">storages_offset</span> <span class="o">=</span> <span class="n">tar</span><span class="o">.</span><span class="n">getmember</span><span class="p">(</span><span class="s1">'storages'</span><span class="p">)</span><span class="o">.</span><span class="n">offset_data</span>
|
||
<span class="n">f</span> <span class="o">=</span> <span class="n">unwrap</span><span class="p">(</span><span class="n">tar</span><span class="o">.</span><span class="n">extractfile</span><span class="p">(</span><span class="s1">'storages'</span><span class="p">))</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()):</span> <span class="c1"># num_storages</span>
|
||
<span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">storage_type</span><span class="p">),</span> <span class="n">sz</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">'<q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">offsets</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">storages_offset</span> <span class="o">+</span> <span class="n">f</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
|
||
<span class="n">f</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">sz</span><span class="o">*</span><span class="n">storage_type</span><span class="o">.</span><span class="n">itemsize</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">f</span> <span class="o">=</span> <span class="n">unwrap</span><span class="p">(</span><span class="n">tar</span><span class="o">.</span><span class="n">extractfile</span><span class="p">(</span><span class="s1">'tensors'</span><span class="p">))</span>
|
||
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()):</span> <span class="c1"># num_tensors</span>
|
||
<span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">storage_id</span><span class="p">,</span> <span class="n">_</span><span class="p">),</span> <span class="n">ndim</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">'<i'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">))[</span><span class="mi">0</span><span class="p">],</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">size</span><span class="p">,</span> <span class="n">stride</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="sa">f</span><span class="s1">'<</span><span class="si">{</span><span class="n">ndim</span><span class="si">}</span><span class="s1">q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">)),</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="sa">f</span><span class="s1">'<</span><span class="si">{</span><span class="n">ndim</span><span class="si">}</span><span class="s1">q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">ndim</span><span class="p">))</span>
|
||
<span class="n">storage_offset</span> <span class="o">=</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="s1">'<q'</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">deserialized_objects</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">key</span><span class="p">)]</span> <span class="o">=</span> <span class="n">_rebuild_tensor_v2</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="n">storage_type</span><span class="p">,</span> <span class="n">storage_id</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">storage_offset</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span><span class="n">v</span><span class="o">.</span><span class="n">tensor</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Parameter</span><span class="p">)</span> <span class="k">else</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">unwrap</span><span class="p">(</span><span class="n">tar</span><span class="o">.</span><span class="n">extractfile</span><span class="p">(</span><span class="s1">'pickle'</span><span class="p">)))</span><span class="o">.</span><span class="n">load</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">pkl</span> <span class="o">=</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">fobj</span><span class="p">)</span>
|
||
<span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">rwd</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">ids</span><span class="p">,</span> <span class="n">base_offset</span> <span class="o">=</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">fobj</span><span class="o">.</span><span class="n">tell</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">pkl</span><span class="o">.</span><span class="n">load</span><span class="p">(),</span> <span class="n">fobj</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ids</span><span class="p">:</span>
|
||
<span class="n">offsets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">base_offset</span> <span class="o">+</span> <span class="mi">8</span>
|
||
<span class="n">base_offset</span> <span class="o">+=</span> <span class="mi">8</span> <span class="o">+</span> <span class="n">lens</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="n">fobj</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">rwd</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">TorchPickle</span><span class="p">(</span><span class="n">fobj</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">()</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="doc doc-object doc-function">
|
||
|
||
|
||
<h3 id="tinygrad.nn.state.gguf_load" class="doc doc-heading">
|
||
<code class="doc-symbol doc-symbol-heading doc-symbol-function"></code> <span class="doc doc-object-name doc-function-name">gguf_load</span>
|
||
|
||
|
||
<a href="#tinygrad.nn.state.gguf_load" class="headerlink" title="Permanent link">¤</a></h3>
|
||
<div class="language-python doc-signature highlight"><pre><span></span><code><span class="nf">gguf_load</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">)</span> <span class="o">-></span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#tuple">tuple</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#dict">dict</a></span><span class="p">[</span><span class="n"><a class="autorefs autorefs-external" href="https://docs.python.org/3/library/stdtypes.html#str">str</a></span><span class="p">,</span> <span class="n"><a class="autorefs autorefs-internal" title="<code class="doc-symbol doc-symbol-heading doc-symbol-class"></code> <span class="doc doc-object-name doc-class-name">Tensor</span> (<code>tinygrad.tensor.Tensor</code>)" href="../tensor/#tinygrad.Tensor">Tensor</a></span><span class="p">]]</span>
|
||
</code></pre></div>
|
||
|
||
<div class="doc doc-contents first">
|
||
|
||
<p>Loads a .gguf file, returning the <code class="language-python highlight"><span class="n">kv_data</span></code> and <code class="language-python highlight"><span class="n">state_dict</span></code>.</p>
|
||
<div class="language-python highlight"><pre><span></span><code><span class="n">gguf_tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="s2">"Meta-Llama-3-8B-Instruct.Q4_0.gguf"</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">Device</span><span class="o">.</span><span class="n">DEFAULT</span><span class="p">)</span>
|
||
<span class="n">kv_data</span><span class="p">,</span> <span class="n">state_dict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">gguf_load</span><span class="p">(</span><span class="n">gguf_tensor</span><span class="p">)</span>
|
||
</code></pre></div>
|
||
<div class="admonition note">
|
||
<p class="admonition-title">Note</p>
|
||
<p>The provided tensor must be on a device that supports execution.</p>
|
||
</div>
|
||
|
||
|
||
<details class="mkdocstrings-source">
|
||
<summary>Source code in <code>tinygrad/nn/state.py</code></summary>
|
||
<div class="language-python highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">357</span>
|
||
<span class="normal">358</span>
|
||
<span class="normal">359</span>
|
||
<span class="normal">360</span>
|
||
<span class="normal">361</span>
|
||
<span class="normal">362</span>
|
||
<span class="normal">363</span>
|
||
<span class="normal">364</span>
|
||
<span class="normal">365</span>
|
||
<span class="normal">366</span>
|
||
<span class="normal">367</span>
|
||
<span class="normal">368</span>
|
||
<span class="normal">369</span>
|
||
<span class="normal">370</span>
|
||
<span class="normal">371</span>
|
||
<span class="normal">372</span>
|
||
<span class="normal">373</span>
|
||
<span class="normal">374</span>
|
||
<span class="normal">375</span>
|
||
<span class="normal">376</span>
|
||
<span class="normal">377</span>
|
||
<span class="normal">378</span>
|
||
<span class="normal">379</span>
|
||
<span class="normal">380</span>
|
||
<span class="normal">381</span>
|
||
<span class="normal">382</span>
|
||
<span class="normal">383</span>
|
||
<span class="normal">384</span>
|
||
<span class="normal">385</span>
|
||
<span class="normal">386</span>
|
||
<span class="normal">387</span>
|
||
<span class="normal">388</span>
|
||
<span class="normal">389</span>
|
||
<span class="normal">390</span>
|
||
<span class="normal">391</span>
|
||
<span class="normal">392</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="nd">@accept_filename</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">gguf_load</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Loads a .gguf file, returning the `kv_data` and `state_dict`.</span>
|
||
|
||
<span class="sd"> ```python</span>
|
||
<span class="sd"> gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)</span>
|
||
<span class="sd"> kv_data, state_dict = nn.state.gguf_load(gguf_tensor)</span>
|
||
<span class="sd"> ```</span>
|
||
|
||
<span class="sd"> NOTE: The provided tensor must be on a device that supports execution.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">reader</span><span class="p">,</span> <span class="n">kv_data</span><span class="p">,</span> <span class="n">state_dict</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">BufferedReader</span><span class="p">(</span><span class="n">TensorIO</span><span class="p">(</span><span class="n">tensor</span><span class="p">),</span> <span class="mi">1_000_000</span><span class="p">),</span> <span class="p">{},</span> <span class="p">{}</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">read_unpack</span><span class="p">(</span><span class="n">fmt</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> <span class="k">return</span> <span class="n">struct</span><span class="o">.</span><span class="n">unpack</span><span class="p">(</span><span class="n">fmt</span><span class="p">,</span> <span class="n">reader</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">n</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">read_str</span><span class="p">():</span> <span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="n">reader</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">read_uint64</span><span class="p">()),</span> <span class="s2">"utf-8"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">read_arr</span><span class="p">():</span>
|
||
<span class="n">reader</span><span class="p">,</span> <span class="n">n</span> <span class="o">=</span> <span class="n">readers</span><span class="p">[</span><span class="n">read_int32</span><span class="p">()],</span> <span class="n">read_uint64</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="p">[</span> <span class="n">reader</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="p">]</span>
|
||
|
||
<span class="n">readers</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Callable</span><span class="p">[[],</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="p">{</span> <span class="mi">8</span><span class="p">:</span> <span class="n">read_str</span><span class="p">,</span> <span class="mi">9</span><span class="p">:</span> <span class="n">read_arr</span><span class="p">,</span> <span class="o">**</span><span class="p">{</span> <span class="n">t</span><span class="p">:</span> <span class="n">functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="n">read_unpack</span><span class="p">,</span> <span class="s2">"<"</span><span class="o">+</span><span class="n">f</span><span class="p">,</span> <span class="n">nb</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span><span class="p">,</span><span class="n">f</span><span class="p">,</span><span class="n">nb</span> <span class="ow">in</span> \
|
||
<span class="p">[</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="s2">"c"</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="s2">"b"</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="s2">"H"</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="s2">"h"</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span><span class="s2">"I"</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span><span class="s2">"i"</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">6</span><span class="p">,</span><span class="s2">"f"</span><span class="p">,</span><span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="mi">7</span><span class="p">,</span><span class="s2">"?"</span><span class="p">,</span><span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">10</span><span class="p">,</span><span class="s2">"Q"</span><span class="p">,</span><span class="mi">8</span><span class="p">),</span> <span class="p">(</span><span class="mi">11</span><span class="p">,</span><span class="s2">"q"</span><span class="p">,</span><span class="mi">8</span><span class="p">),</span> <span class="p">(</span><span class="mi">12</span><span class="p">,</span><span class="s2">"d"</span><span class="p">,</span><span class="mi">8</span><span class="p">)</span> <span class="p">]</span> <span 
class="p">}</span> <span class="p">}</span>
|
||
<span class="n">read_uint32</span><span class="p">,</span> <span class="n">read_int32</span><span class="p">,</span> <span class="n">read_uint64</span><span class="p">,</span> <span class="n">read_int64</span> <span class="o">=</span> <span class="n">readers</span><span class="p">[</span><span class="mi">4</span><span class="p">],</span> <span class="n">readers</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="n">readers</span><span class="p">[</span><span class="mi">10</span><span class="p">],</span> <span class="n">readers</span><span class="p">[</span><span class="mi">11</span><span class="p">]</span>
|
||
|
||
<span class="n">magic</span><span class="p">,</span> <span class="n">version</span><span class="p">,</span> <span class="n">n_tensors</span><span class="p">,</span> <span class="n">n_kv</span> <span class="o">=</span> <span class="n">reader</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">4</span><span class="p">),</span> <span class="n">read_int32</span><span class="p">(),</span> <span class="n">read_int64</span><span class="p">(),</span> <span class="n">read_int64</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">magic</span> <span class="o">!=</span> <span class="sa">b</span><span class="s2">"GGUF"</span> <span class="ow">or</span> <span class="n">version</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Invalid GGUF format!"</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_kv</span><span class="p">):</span>
|
||
<span class="n">k</span><span class="p">,</span> <span class="n">typ</span> <span class="o">=</span> <span class="n">read_str</span><span class="p">(),</span> <span class="n">read_int32</span><span class="p">()</span>
|
||
<span class="n">kv_data</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">readers</span><span class="p">[</span><span class="n">typ</span><span class="p">]()</span>
|
||
|
||
<span class="n">t_infos</span> <span class="o">=</span> <span class="p">[</span> <span class="p">(</span><span class="n">read_str</span><span class="p">(),</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">read_uint64</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">read_uint32</span><span class="p">())),</span> <span class="n">read_int32</span><span class="p">(),</span> <span class="n">read_uint64</span><span class="p">())</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_tensors</span><span class="p">)</span> <span class="p">]</span>
|
||
<span class="n">alignment</span><span class="p">,</span> <span class="n">pos</span> <span class="o">=</span> <span class="n">kv_data</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"general.alignment"</span><span class="p">,</span> <span class="mi">32</span><span class="p">),</span> <span class="n">reader</span><span class="o">.</span><span class="n">tell</span><span class="p">()</span>
|
||
<span class="n">data_start</span> <span class="o">=</span> <span class="n">round_up</span><span class="p">(</span><span class="n">pos</span><span class="p">,</span> <span class="n">alignment</span><span class="p">)</span>
|
||
|
||
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">typ</span><span class="p">,</span> <span class="n">off</span> <span class="ow">in</span> <span class="n">t_infos</span><span class="p">:</span> <span class="n">state_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">ggml_data_to_tensor</span><span class="p">(</span><span class="n">tensor</span><span class="p">[</span><span class="n">data_start</span> <span class="o">+</span> <span class="n">off</span><span class="p">:],</span> <span class="n">prod</span><span class="p">(</span><span class="n">dims</span><span class="p">),</span> <span class="n">typ</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">*</span><span class="nb">reversed</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
|
||
|
||
<span class="k">return</span> <span class="n">kv_data</span><span class="p">,</span> <span class="n">state_dict</span>
|
||
</code></pre></div></td></tr></table></div>
|
||
</details>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
</article>
|
||
</div>
|
||
|
||
|
||
<script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
|
||
</div>
|
||
|
||
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
|
||
Back to top
|
||
</button>
|
||
|
||
</main>
|
||
|
||
<footer class="md-footer">
|
||
|
||
|
||
|
||
<nav class="md-footer__inner md-grid" aria-label="Footer" >
|
||
|
||
|
||
<a href="../dtypes/" class="md-footer__link md-footer__link--prev" aria-label="Previous: dtypes">
|
||
<div class="md-footer__button md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
|
||
</div>
|
||
<div class="md-footer__title">
|
||
<span class="md-footer__direction">
|
||
Previous
|
||
</span>
|
||
<div class="md-ellipsis">
|
||
dtypes
|
||
</div>
|
||
</div>
|
||
</a>
|
||
|
||
|
||
|
||
<a href="../env_vars/" class="md-footer__link md-footer__link--next" aria-label="Next: Environment Variables">
|
||
<div class="md-footer__title">
|
||
<span class="md-footer__direction">
|
||
Next
|
||
</span>
|
||
<div class="md-ellipsis">
|
||
Environment Variables
|
||
</div>
|
||
</div>
|
||
<div class="md-footer__button md-icon">
|
||
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
|
||
</div>
|
||
</a>
|
||
|
||
</nav>
|
||
|
||
|
||
<div class="md-footer-meta md-typeset">
|
||
<div class="md-footer-meta__inner md-grid">
|
||
<div class="md-copyright">
|
||
|
||
|
||
Made with
|
||
<a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
|
||
Material for MkDocs
|
||
</a>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</div>
|
||
</footer>
|
||
|
||
</div>
|
||
<div class="md-dialog" data-md-component="dialog">
|
||
<div class="md-dialog__inner md-typeset"></div>
|
||
</div>
|
||
|
||
|
||
|
||
|
||
|
||
<script id="__config" type="application/json">{"annotate": null, "base": "..", "features": ["announce.dismiss", "content.action.edit", "content.action.view", "content.code.annotate", "content.code.copy", "content.tooltips", "navigation.footer", "navigation.indexes", "navigation.sections", "navigation.expand", "navigation.top", "navigation.path", "search.highlight", "search.suggest", "toc.follow", "toc.integrate"], "search": "../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}, "version": null}</script>
|
||
|
||
|
||
<script src="../assets/javascripts/bundle.79ae519e.min.js"></script>
|
||
|
||
<script src="../assets/_markdown_exec_pyodide.js"></script>
|
||
|
||
|
||
</body>
|
||
</html> |