tinygrad.nn (#367)
* tinygrad.nn
* flake8
* working on pylint
* more pylint
* more pylint
* pylint passes
* networkx
* mypy can't infer that type
* junk
.github/workflows/test.yml (vendored, 13 lines changed)
@@ -16,7 +16,7 @@ jobs:
      run: sloccount tinygrad test examples extra; if [ $(sloccount tinygrad | sed -n 's/.*Total Physical Source Lines of Code (SLOC)[ ]*= \([^ ]*\).*/\1/p' | tr -d ',') -gt 1000 ]; then exit 1; fi

  linter:
    name: Indentation Linter
    name: Linter
    runs-on: ubuntu-latest

    steps:
@@ -29,11 +29,14 @@ jobs:
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install pylint
        # pip install -r requirements.txt
        python -m pip install pylint flake8
        pip install -e .
    - name: Lint with pylint
      run: |
        python -m pylint --disable=all -e W0311 --jobs=0 --indent-string=' ' **/*.py
      run: python -m pylint --disable=all -e W0311 --jobs=0 --indent-string=' ' **/*.py
    - name: Lint with flake8
      run: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
    - name: Lint tinygrad with pylint
      run: pylint tinygrad/

  testcpu:
    name: CPU Tests
.pylintrc (new file, 470 lines)
@@ -0,0 +1,470 @@
[MASTER]

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av

# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS

# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=

# Use multiple processes to speed up Pylint.
jobs=4

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=

# Pickle collected data for later comparisons.
persistent=yes

# Specify a configuration file.
#rcfile=

# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages
suggestion-mode=yes

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0632
# E1101 for mlops binding
# W0221,W0632 for Function class

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member


[REPORTS]

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=

# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio).You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text

# Tells whether to display a full report or only the messages
reports=no

# Activate the evaluation score.
score=yes


[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=5

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=optparse.Values,sys.exit


[LOGGING]

# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging


[SPELLING]

# Limits count of emitted suggestions for spelling mistakes
max-spelling-suggestions=4

# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
      XXX,
      TODO


[SIMILARITIES]

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no

# Minimum lines number of a similarity.
min-similarity-lines=4


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*

# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes

# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1

# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1


[VARIABLES]

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=

# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
          _cb

# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*|^ignored_|^unused_

# Tells whether we should check for unused import in __init__ files.
init-import=no

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins


[FORMAT]

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '

# Maximum number of characters on a single line.
max-line-length=100

# Maximum number of lines in a module
max-module-lines=1000

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no


[BASIC]

# Naming style matching correct argument names
argument-naming-style=snake_case

# Regular expression matching correct argument names. Overrides argument-
# naming-style
#argument-rgx=

# Naming style matching correct attribute names
attr-naming-style=snake_case

# Regular expression matching correct attribute names. Overrides attr-naming-
# style
#attr-rgx=

# Bad variable names which should always be refused, separated by a comma
bad-names=foo,
          bar,
          baz,
          toto,
          tutu,
          tata

# Naming style matching correct class attribute names
class-attribute-naming-style=any

# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style
#class-attribute-rgx=

# Naming style matching correct class names
class-naming-style=PascalCase

# Regular expression matching correct class names. Overrides class-naming-style
#class-rgx=

# Naming style matching correct constant names
const-naming-style=UPPER_CASE

# Regular expression matching correct constant names. Overrides const-naming-
# style
#const-rgx=

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1

# Naming style matching correct function names
function-naming-style=snake_case

# Regular expression matching correct function names. Overrides function-
# naming-style
#function-rgx=

# Good variable names which should always be accepted, separated by a comma
good-names=i,
           j,
           k,
           ex,
           Run,
           _

# Include a hint for the correct naming format with invalid-name
include-naming-hint=no

# Naming style matching correct inline iteration names
inlinevar-naming-style=any

# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style
#inlinevar-rgx=

# Naming style matching correct method names
method-naming-style=snake_case

# Regular expression matching correct method names. Overrides method-naming-
# style
#method-rgx=

# Naming style matching correct module names
module-naming-style=snake_case

# Regular expression matching correct module names. Overrides module-naming-
# style
#module-rgx=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty

# Naming style matching correct variable names
variable-naming-style=snake_case

# Regular expression matching correct variable names. Overrides variable-
# naming-style
#variable-rgx=


[DESIGN]

# Maximum number of arguments for function / method
max-args=5

# Maximum number of attributes for a class (see R0902).
max-attributes=7

# Maximum number of boolean expressions in a if statement
max-bool-expr=5

# Maximum number of branch for function / method body
max-branches=12

# Maximum number of locals for function / method body
max-locals=15

# Maximum number of parents for a class (see R0901).
max-parents=7

# Maximum number of public methods for a class (see R0904).
max-public-methods=20

# Maximum number of return / yield for function / method body
max-returns=6

# Maximum number of statements in function / method body
max-statements=50

# Minimum number of public methods for a class (see R0903).
min-public-methods=2


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs


[IMPORTS]

# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no

# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
                   TERMIOS,
                   Bastion,
                   rexec

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant

[STRING]

# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=yes

[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
@@ -54,13 +54,13 @@ print(y.grad) # dz/dy

## Neural networks?

It turns out, a decent autograd tensor library is 90% of what you need for neural networks. Add an optimizer (SGD, RMSprop, and Adam implemented) from tinygrad.optim, write some boilerplate minibatching code, and you have all you need.
It turns out, a decent autograd tensor library is 90% of what you need for neural networks. Add an optimizer (SGD, RMSprop, and Adam implemented) from tinygrad.nn.optim, write some boilerplate minibatching code, and you have all you need.

### Neural network example (from test/test_mnist.py)

```python
from tinygrad.tensor import Tensor
import tinygrad.optim as optim
import tinygrad.nn.optim as optim

class TinyBobNet:
  def __init__(self):
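    # (the diff truncates the README example here; a rough sketch of how such a
    #  model continues, based on test/test_mnist.py in this tree -- exact layer
    #  sizes and method names may differ)
    self.l1 = Tensor.uniform(784, 128)
    self.l2 = Tensor.uniform(128, 10)

  def forward(self, x):
    return x.dot(self.l1).relu().dot(self.l2).logsoftmax()

# the only change this commit makes to the example is the optimizer import path
model = TinyBobNet()
optim = optim.SGD([model.l1, model.l2], lr=0.001)
```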
@@ -5,7 +5,7 @@ import os
from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer
from tinygrad.ops import ProcessingOps
from tinygrad.helpers import prod, ConvArgs
from typing import List, Tuple, Optional, Dict
from typing import List, Tuple, Optional, Dict, Set
import numpy as np
import pyopencl as cl

@@ -40,7 +40,7 @@ def get_replacements(prg_src:str, opencl_type:List[str]) -> Dict[str, str]:
      args = [f"(outputRow * get_image_width(output) + outputLocation.x)*4+{i}", acc]+args
      middle_code.append(f"{acc} = _ewop("+', '.join(args)+");\n")
  """
    acc = f"outputValues[i]"
    acc = "outputValues[i]"
    args = [x.split(" ")[-1].replace("*", "") for x in opencl_type]
    args = ["smp", "outputLocation", "(outputLocation.y * get_image_width(output) + outputLocation.x)*4", acc]+args
    middle_code.append(f"{acc} = _ewop("+', '.join(args)+");\n")

@@ -155,8 +155,7 @@ class OpenCLBuffer(GPUBuffer):
    elif buf.st.contiguous:
      # use float4
      ewtypes.append(f"__global const float4 *{name}_g")
      getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{"+
                     f"return x[gid/4]; }}")
      getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{ return x[gid/4]; }}")
    elif UNSAFE_FLOAT4:
      # aggressive constant folding
      fakebufs.append(name)

@@ -4,6 +4,6 @@ import sys
# only pyximport this
import pyximport
py_importer, pyx_importer = pyximport.install()
from accel.rawcpu.buffer import RawCPUBuffer
from accel.rawcpu.buffer import RawCPUBuffer # noqa: F401
sys.meta_path.remove(pyx_importer)
@@ -4,7 +4,7 @@ import time
from tqdm import trange
from extra.utils import get_parameters
from models.efficientnet import EfficientNet
import tinygrad.optim as optim
import tinygrad.nn.optim as optim
from tinygrad.tensor import Tensor
from tinygrad.llops.ops_gpu import CL

@@ -8,7 +8,7 @@ sys.path.append(os.path.join(os.getcwd(), 'test'))

from tinygrad.tensor import Tensor, Function, register
from extra.utils import get_parameters
import tinygrad.optim as optim
import tinygrad.nn.optim as optim
from test_mnist import X_train
from torchvision.utils import make_grid, save_image
import torch

@@ -11,7 +11,7 @@ from tinygrad.nn import BatchNorm2D
from extra.utils import get_parameters
from datasets import fetch_mnist
from extra.training import train, evaluate, sparse_categorical_crossentropy
import tinygrad.optim as optim
import tinygrad.nn.optim as optim
from extra.augment import augment_img
GPU = os.getenv("GPU", None) is not None
QUICK = os.getenv("QUICK", None) is not None

@@ -7,7 +7,7 @@ from tinygrad.tensor import Tensor
from extra.utils import get_parameters
from tqdm import trange
from tinygrad.nn import BatchNorm2D
import tinygrad.optim as optim
import tinygrad.nn.optim as optim
from datasets import fetch_cifar

class TinyConvNet:

@@ -8,10 +8,10 @@ from tinygrad.tensor import Device
from extra.utils import get_parameters
from extra.training import train, evaluate
from models.resnet import ResNet
from tinygrad.optim import Adam
from tinygrad.nn.optim import Adam
from datasets import fetch_mnist

from tinygrad.optim import Adam
from tinygrad.nn.optim import Adam

class ComposeTransforms:
  def __init__(self, trans):

@@ -7,7 +7,7 @@ from tinygrad.tensor import Device
from extra.utils import get_parameters
from extra.training import train, evaluate
from models.transformer import Transformer
from tinygrad.optim import Adam
from tinygrad.nn.optim import Adam

# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb
def make_dataset():

@@ -25,7 +25,7 @@ def make_dataset():

  return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test

from tinygrad.optim import Adam
from tinygrad.nn.optim import Adam
if __name__ == "__main__":
  model = Transformer(10, 6, 2, 128, 4, 32)

@@ -1,6 +1,6 @@
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.optim import SGD
from tinygrad.nn.optim import SGD
import examples.yolo.waifu2x
from examples.yolo.kinne import KinneDir
import sys

@@ -1,5 +1,4 @@
from tinygrad.tensor import Tensor
import tinygrad.nn as nn
import pickle
import numpy as np
setup.py (2 lines changed)
@@ -19,7 +19,7 @@ setup(name='tinygrad',
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License"
      ],
      install_requires=['numpy', 'requests', 'pillow'],
      install_requires=['numpy', 'requests', 'pillow', 'networkx'],
      python_requires='>=3.8',
      extras_require={
        'gpu': ["pyopencl", "six"],

@@ -3,7 +3,7 @@ import os
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
import tinygrad.optim as optim
import tinygrad.nn.optim as optim
from extra.training import train, evaluate
from extra.utils import get_parameters
from datasets import fetch_mnist

@@ -2,7 +2,7 @@ import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
from tinygrad.optim import Adam, SGD, RMSprop
from tinygrad.nn.optim import Adam, SGD, RMSprop
from extra.utils import get_parameters

x_init = np.random.randn(1,3).astype(np.float32)
@@ -128,6 +128,11 @@ class TestSingleShapeTracker(unittest.TestCase):
    self.st.permute(1,0)
    assert not self.st.contiguous

def shapetracker_getitem(st, val):
  locals = {"idx": val, "valid": 1}
  exec(st.expr(), None, locals)
  return locals["idx"] if locals["valid"] else -1

class TestShapeTracker(unittest.TestCase):
  def setUp(self):
    self.st = ShapeTracker((7,4))
@@ -135,7 +140,7 @@ class TestShapeTracker(unittest.TestCase):
    self.apply = lambda fxn: [fxn(x) for x in [self.st, self.dt]]

  def tearDown(self):
    x = [self.st[i] for i in range(prod(self.st.shape))]
    x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))]
    y = [self.dt[i] for i in range(prod(self.dt.shape))]
    print(x,y, self.st.shape, self.dt.shape, self.st.expr())
    assert self.st.shape == self.dt.shape
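The new module-level `shapetracker_getitem` helper replaces the test-only `ShapeTracker.__getitem__` (removed further down in this diff); it works by exec-ing the tracker's index expression against an `idx`/`valid` scope. A minimal standalone sketch of that mechanism, with a made-up expression string rather than real `st.expr()` output:

```python
# toy stand-in for ShapeTracker.expr(): statements that rewrite idx and set valid
expr = "valid = 1 if idx < 12 else 0; idx = (idx // 4) * 5 + (idx % 4)"

def getitem(expr, val):
  # same pattern as shapetracker_getitem: run the expression with idx/valid locals
  scope = {"idx": val, "valid": 1}
  exec(expr, None, scope)
  return scope["idx"] if scope["valid"] else -1

print([getitem(expr, i) for i in range(16)])  # the last four indices map to -1
```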
@@ -1,7 +1,7 @@
import os
import unittest
import time
import tinygrad.optim as optim
import tinygrad.nn.optim as optim
import numpy as np
from tinygrad.tensor import Device
from extra.training import train

@@ -1 +1 @@
from tinygrad import optim, tensor, nn
from tinygrad import tensor, nn # noqa: F401
@@ -11,7 +11,9 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, o
  cout,cin,H,W = w_shape
  sy,sx = (stride, stride) if isinstance(stride, int) else stride
  if not isinstance(padding, int) and len(padding) == 4: px,px_,py,py_ = padding
  else: py,px = (padding, padding) if isinstance(padding, int) else padding; py_, px_ = py, px
  else:
    py,px = (padding, padding) if isinstance(padding, int) else padding
    py_, px_ = py, px
  dy,dx = (dilation, dilation) if isinstance(dilation, int) else dilation
  bs,cin_,iy,ix = x_shape
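The reformatted branch above normalizes `padding` to the four per-side values `px, px_, py, py_`. A small standalone sketch of that normalization (plain Python outside tinygrad; the helper name is illustrative), assuming the same int / 2-tuple / 4-tuple conventions as the code above:

```python
def normalize_padding(padding):
  # 4-tuple: (px, px_, py, py_), matching the first branch above
  if not isinstance(padding, int) and len(padding) == 4:
    px, px_, py, py_ = padding
  else:
    # int -> same pad everywhere; 2-tuple -> (py, px), mirrored on both sides
    py, px = (padding, padding) if isinstance(padding, int) else padding
    py_, px_ = py, px
  return px, px_, py, py_

assert normalize_padding(1) == (1, 1, 1, 1)
assert normalize_padding((2, 3)) == (3, 3, 2, 2)
assert normalize_padding((1, 2, 3, 4)) == (1, 2, 3, 4)
```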
@@ -3,9 +3,9 @@ import os, functools
import numpy as np
import pyopencl as cl # type: ignore
from collections import defaultdict
from typing import List, Tuple, Optional, Dict, Union, Set, Tuple
from typing import List, Tuple, Optional, Dict, Union, Set
from tinygrad.helpers import prod, ConvArgs
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps
from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape

CLCACHE = int(os.getenv("CLCACHE", "1"))
@@ -58,7 +58,7 @@ class CLProgram:
    if DEBUG >= 1:
      CL.time_sum += 0 if DEBUG <= 1 or CL.CACHE is not None else (e.profile.end - e.profile.start)
      CL.ops_sum += op_estimate
      print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:6.1f}M/{CL.ops_sum/1e9:7.2f}G " + \
      print(f"**CL** {CL.kernel_count:6d} {self.name:20s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {op_estimate/1e6:6.1f}M/{CL.ops_sum/1e9:7.2f}G " +
            ("" if DEBUG <= 1 or CL.CACHE is not None else f"tm {(e.profile.end - e.profile.start)/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({op_estimate/(e.profile.end - e.profile.start):8.2f} GFLOPS)"))
    if DEBUG >= 4: print(self.prg)

@@ -1,6 +1,6 @@
import torch
from tinygrad.llops.ops_cpu import CPUBuffer # type: ignore
from tinygrad.ops import MovementOps, ProcessingOps
from tinygrad.ops import ProcessingOps

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class TorchBuffer(torch.Tensor):
@@ -5,106 +5,106 @@ from tinygrad.tensor import Function
# ************* unary ops *************

class ReLU(Function):
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return input.unary_op(UnaryOps.RELU)
  def forward(self, x):
    self.save_for_backward(x)
    return x.unary_op(UnaryOps.RELU)

  def backward(ctx, grad_output):
    return ctx.saved_tensors[0].unary_op(UnaryOps.SIGN).unary_op(UnaryOps.RELU).binary_op(BinaryOps.MUL, grad_output)
  def backward(self, grad_output):
    return self.saved_tensors[0].unary_op(UnaryOps.SIGN).unary_op(UnaryOps.RELU).binary_op(BinaryOps.MUL, grad_output)

class Log(Function):
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return input.unary_op(UnaryOps.LOG)
  def forward(self, x):
    self.save_for_backward(x)
    return x.unary_op(UnaryOps.LOG)

  def backward(ctx, grad_output):
    return grad_output.binary_op(BinaryOps.DIV, ctx.saved_tensors[0])
  def backward(self, grad_output):
    return grad_output.binary_op(BinaryOps.DIV, self.saved_tensors[0])

class Exp(Function):
  def forward(ctx, input):
    ret = input.unary_op(UnaryOps.EXP)
    ctx.save_for_backward(ret) # we save the output here, not the input
  def forward(self, x):
    ret = x.unary_op(UnaryOps.EXP)
    self.save_for_backward(ret) # we save the output here, not the input
    return ret

  def backward(ctx, grad_output):
    return ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
  def backward(self, grad_output):
    return self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)

# TODO: add Neg? confirm the optimizer on Sub good enough

# ************* reduce ops *************

class Sum(Function):
  def forward(ctx, input, axis=None):
    ctx.input_shape = input.shape
    return input.reduce_op(ReduceOps.SUM, reduce_shape(input.shape, axis))
  def forward(self, x, axis=None):
    self.input_shape = x.shape
    return x.reduce_op(ReduceOps.SUM, reduce_shape(x.shape, axis))

  def backward(ctx, grad_output):
    return grad_output.movement_op(MovementOps.EXPAND, ctx.input_shape)
  def backward(self, grad_output):
    return grad_output.movement_op(MovementOps.EXPAND, self.input_shape)

class Max(Function):
  def forward(ctx, input, axis=None):
    ret = input.reduce_op(ReduceOps.MAX, reduce_shape(input.shape, axis))
    ctx.save_for_backward(input, ret)
  def forward(self, x, axis=None):
    ret = x.reduce_op(ReduceOps.MAX, reduce_shape(x.shape, axis))
    self.save_for_backward(x, ret)
    return ret

  def backward(ctx, grad_output):
    input, ret = ctx.saved_tensors
  def backward(self, grad_output):
    x, ret = self.saved_tensors

    # 1s in locations where the max was chosen (can be two locations)
    max_is_1s = input.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.EXPAND, input.shape))
    max_is_1s = x.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.EXPAND, x.shape))

    # sum of locations, averaged
    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape)
    div = div.movement_op(MovementOps.EXPAND, input.shape)
    div = div.movement_op(MovementOps.EXPAND, x.shape)
    max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)

    grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, input.shape)
    grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, x.shape)
    return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)

# ************* binary ops *************

class Add(Function):
  def forward(ctx, x, y):
  def forward(self, x, y):
    return x.binary_op(BinaryOps.ADD, y)

  def backward(ctx, grad_output):
    return grad_output if ctx.needs_input_grad[0] else None, \
           grad_output if ctx.needs_input_grad[1] else None
  def backward(self, grad_output):
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output if self.needs_input_grad[1] else None

class Sub(Function):
  def forward(ctx, x, y):
  def forward(self, x, y):
    return x.binary_op(BinaryOps.SUB, y)

  def backward(ctx, grad_output):
    return grad_output if ctx.needs_input_grad[0] else None, \
           grad_output.unary_op(UnaryOps.NEG) if ctx.needs_input_grad[1] else None
  def backward(self, grad_output):
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output.unary_op(UnaryOps.NEG) if self.needs_input_grad[1] else None

class Mul(Function):
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
  def forward(self, x, y):
    self.save_for_backward(x, y)
    return x.binary_op(BinaryOps.MUL, y)

  def backward(ctx, grad_output):
    grad_x = ctx.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
    grad_y = ctx.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
  def backward(self, grad_output):
    grad_x = self.saved_tensors[1].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None
    grad_y = self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
    return grad_x, grad_y

# TODO: add Div? is the optimizer on Pow good enough?
# nope, we def need div, can't optimize that

class Pow(Function):
  def forward(ctx, x, y):
  def forward(self, x, y):
    ret = x.binary_op(BinaryOps.POW, y)
    ctx.save_for_backward(x, y, ret)
    self.save_for_backward(x, y, ret)
    return ret

  def backward(ctx, grad_output):
    x,y,powxy = ctx.saved_tensors
  def backward(self, grad_output):
    x,y,powxy = self.saved_tensors
    grad_x, grad_y = None, None
    if ctx.needs_input_grad[0]:
    if self.needs_input_grad[0]:
      tmp = y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x)) # y * (pow(x,y)/x)
      grad_x = grad_output.binary_op(BinaryOps.MUL, tmp)
    if ctx.needs_input_grad[1]:
    if self.needs_input_grad[1]:
      tmp = x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy) # log(x) * pow(x,y)
      grad_y = grad_output.binary_op(BinaryOps.MUL, tmp)
    return grad_x, grad_y
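Aside from the ctx → self rename, Pow.backward is unchanged; the two `tmp` expressions are the usual power-rule gradients, y·x^(y-1) for the base and log(x)·x^y for the exponent. A quick numpy sanity check of those formulas against central differences (independent of tinygrad):

```python
import numpy as np

x, y = 1.7, 2.3
powxy = x ** y
grad_output = 1.0

grad_x = grad_output * y * (powxy / x)    # y * (pow(x,y)/x) == y * x**(y-1)
grad_y = grad_output * np.log(x) * powxy  # log(x) * pow(x,y)

eps = 1e-6
assert abs(grad_x - ((x + eps)**y - (x - eps)**y) / (2 * eps)) < 1e-4
assert abs(grad_y - (x**(y + eps) - x**(y - eps)) / (2 * eps)) < 1e-4
```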
@@ -113,61 +113,61 @@ class Pow(Function):

# NOTE: this is sum in reverse
class Expand(Function):
  def forward(ctx, x, shape):
    ctx.input_shape = x.shape
  def forward(self, x, shape):
    self.input_shape = x.shape
    return x.movement_op(MovementOps.EXPAND, shape)

  def backward(ctx, grad_output):
    return grad_output.reduce_op(ReduceOps.SUM, ctx.input_shape)
  def backward(self, grad_output):
    return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)

class Reshape(Function):
  def forward(ctx, x, shape):
    ctx.input_shape = x.shape
  def forward(self, x, shape):
    self.input_shape = x.shape
    shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
    return x.movement_op(MovementOps.RESHAPE, shape)

  def backward(ctx, grad_output):
    return grad_output.movement_op(MovementOps.RESHAPE, ctx.input_shape)
  def backward(self, grad_output):
    return grad_output.movement_op(MovementOps.RESHAPE, self.input_shape)

class Permute(Function):
  def forward(ctx, x, order=(1,0)):
    ctx.input_order = order
  def forward(self, x, order=(1,0)):
    self.input_order = order
    return x.movement_op(MovementOps.PERMUTE, order)

  def backward(ctx, grad_output):
    return grad_output.movement_op(MovementOps.PERMUTE, tuple(argsort(ctx.input_order)))
  def backward(self, grad_output):
    return grad_output.movement_op(MovementOps.PERMUTE, tuple(argsort(self.input_order)))

# TODO: merge Slice and Flip into Stride with the 3 arguments
class Slice(Function):
  def forward(ctx, x, arg=None):
    ctx.narg = tuple((0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg))
  def forward(self, x, arg=None):
    self.narg = tuple((0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg))
    return x.slice(tuple(arg))

  def backward(ctx, grad_output):
    return grad_output.slice(ctx.narg)
  def backward(self, grad_output):
    return grad_output.slice(self.narg)

class Flip(Function):
  def forward(ctx, x, axis):
    ctx.axis = axis
  def forward(self, x, axis):
    self.axis = axis
    return x.movement_op(MovementOps.FLIP, axis)

  def backward(ctx, grad_output):
    return grad_output.movement_op(MovementOps.FLIP, ctx.axis)
  def backward(self, grad_output):
    return grad_output.movement_op(MovementOps.FLIP, self.axis)

# ************* processing ops *************

class Conv2D(Function):
  def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
    ctx.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
    ctx.save_for_backward(x,w)
    return x.processing_op(ProcessingOps.CONV, w, ctx.C)
  def forward(self, x, w, stride=1, groups=1, dilation=1, padding=0):
    self.C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
    self.save_for_backward(x,w)
    return x.processing_op(ProcessingOps.CONV, w, self.C)

  def backward(ctx, grad_output):
    x, w = ctx.saved_tensors
    C = ctx.C # conv args from the context
  def backward(self, grad_output):
    x, w = self.saved_tensors
    C = self.C # conv args from the context
    dx, dw = None, None

    if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
    if self.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
      xt = grad_output
      if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides. (but only when we contiguous it)
        xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
@@ -179,7 +179,7 @@ class Conv2D(Function):
      Cdx = get_conv_args(xt.shape, wt.shape, out_shape=x.shape, dilation=(C.dy, C.dx), padding=(py, px), groups=C.groups)
      dx = xt.processing_op(ProcessingOps.CONV, wt, Cdx)

    if ctx.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV
    if self.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV
      xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)).movement_op(MovementOps.PERMUTE, (2, 1, 0, 3, 4))
      xdw = xdw.movement_op(MovementOps.RESHAPE, (C.cin, C.groups*C.bs, C.iy, C.ix))
      grad_output_dw = grad_output.movement_op(MovementOps.PERMUTE, (1,0,2,3))
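The whole hunk above is the mechanical ctx → self rename: because the Function instance is itself the autograd context, forward can stash whatever backward needs directly on `self`. A minimal standalone sketch of that pattern (not tinygrad's actual Function class, just the idea, with plain floats instead of lazy buffers):

```python
class Function:
  def __init__(self):
    self.saved_tensors = []
    self.needs_input_grad = [True, True]
  def save_for_backward(self, *xs):
    self.saved_tensors.extend(xs)

class Mul(Function):
  def forward(self, x, y):
    self.save_for_backward(x, y)   # state lives on the instance, no separate ctx object
    return x * y
  def backward(self, grad_output):
    x, y = self.saved_tensors
    return (y * grad_output if self.needs_input_grad[0] else None,
            x * grad_output if self.needs_input_grad[1] else None)

op = Mul()
out = op.forward(3.0, 4.0)
print(out, op.backward(1.0))  # 12.0 (4.0, 3.0)
```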
@@ -6,7 +6,7 @@ def batch_normalize(x, weight, bias, mean, var, eps):

class BatchNorm2D:
  def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    assert affine == True, "BatchNorm2D is only supported with affine"
    assert affine, "BatchNorm2D is only supported with affine"
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

    self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
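For reference, the normalization that `batch_normalize(x, weight, bias, mean, var, eps)` applies is the standard affine batch-norm formula. A numpy restatement (per-channel statistics broadcast over an NCHW tensor; the exact reshapes tinygrad performs may differ), with an illustrative helper name:

```python
import numpy as np

def batch_normalize_ref(x, weight, bias, mean, var, eps):
  # y = (x - mean) / sqrt(var + eps) * weight + bias, broadcast over the channel axis
  shape = (1, x.shape[1], 1, 1)
  return (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps) * weight.reshape(shape) + bias.reshape(shape)

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
out = batch_normalize_ref(x, np.ones(3, np.float32), np.zeros(3, np.float32), mean, var, 1e-5)
assert np.allclose(out.mean(axis=(0, 2, 3)), 0, atol=1e-5)
assert np.allclose(out.var(axis=(0, 2, 3)), 1, atol=1e-3)
```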
@@ -9,9 +9,9 @@ class Optimizer:
    for param in self.params:
      param.grad = None

  def realize(self, extra=[]):
  def realize(self, extra=None):
    # TODO: corealize
    for p in self.params + extra: p.realize()
    for p in self.params + extra if extra is not None else self.params: p.realize()

class SGD(Optimizer):
  def __init__(self, params, lr=0.001):
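The `extra=[]` to `extra=None` change avoids Python's shared-mutable-default pitfall: a default list is created once at definition time and reused across calls. A tiny demonstration of the failure mode the new signature sidesteps (generic Python, not tinygrad-specific):

```python
def bad(item, acc=[]):     # one list object shared by every call
  acc.append(item)
  return acc

def good(item, acc=None):  # fresh list per call unless one is passed in
  acc = [] if acc is None else acc
  acc.append(item)
  return acc

print(bad(1), bad(2))    # [1, 2] [1, 2]  -- state leaks between calls
print(good(1), good(2))  # [1] [2]
```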
@@ -51,7 +51,7 @@ if GRAPH:
  G = nx.DiGraph()
  def save_graph_exit():
    for k,v in cnts.items(): print(k, v)
    if int(os.getenv("PRUNEGRAPH", 0)):
    if int(os.getenv("PRUNEGRAPH", "0")):
      dead_nodes = []
      for n in G.nodes:
        # prune movementops and loadops
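`os.getenv` returns a string when the variable is set but hands back the default unchanged when it is not, so `int(os.getenv("PRUNEGRAPH", 0))` only worked because `int(0)` happens to be valid; the string default `"0"` keeps the value passed to `int()` a consistent type. A small illustration with a hypothetical variable name:

```python
import os

# with a string default, int() always receives a string, whether or not the var is set
print(int(os.getenv("EXAMPLE_FLAG", "0")))  # hypothetical env var; prints 0 here

os.environ["EXAMPLE_FLAG"] = "1"
print(int(os.getenv("EXAMPLE_FLAG", "0")))  # now prints 1
```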
@@ -123,7 +123,7 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
    real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:x.realize(self.device) for x in get_lazybuffers(src.op)}
    buf_names : Dict[LazyBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(real_srcs.keys())}

    return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()], \
    return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()],
      earlycode=_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), earlybufs=buf_names.values(), start=self.dbuffer.start_for_op[self.op.op]), \
      list(real_srcs.values()), ReduceOps
  else:
@@ -167,7 +167,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer

    for x in real_srcs.keys(): real_srcs[x] = x.realize(self.device)
    # fast path, no middle buffers
    return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()], \
    return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()],
      _ast(self.op, buf_names, self.dbuffer.code_for_op), earlycode=earlycode, earlybufs=set(x for x in buf_names.values() if x.startswith("earlyarg_")),
      C=conv_args, reduce_shape=reduce_shape), \
      list(real_srcs.values()), ProcessingOps if conv_args is not None else (ReduceOps if reduce_shape[0] != reduce_shape[1] else BinaryOps)
@@ -176,8 +176,8 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
  # slow path, creates middle buffers
  def ast_eval(x: Union[LazyBuffer, LazyOp]) -> DeviceBuffer:
    if isinstance(x, LazyBuffer): return real_srcs[x]
    if isinstance(x.op, UnaryOps): return ast_eval(x.src[0]).unary_op(x.op)
    if isinstance(x.op, BinaryOps): return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
    if x.op in UnaryOps: return ast_eval(x.src[0]).unary_op(x.op)
    if x.op in BinaryOps: return ast_eval(x.src[0]).binary_op(x.op, ast_eval(x.src[1]))
  return ast_eval(self.op), list(real_srcs.values()), BinaryOps

_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
@@ -205,14 +205,13 @@ class LazyBuffer:
    if optype == LoadOps: return super().__new__(cls)
    wop = (device, optype, get_weakop(op)) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
    # NOTE: we need "ret" to prevent the new buffer from being immediately deleted
    if wop not in LazyBuffer.lazycache: LazyBuffer.lazycache[wop] = ret = super().__new__(cls)
    if wop not in LazyBuffer.lazycache: LazyBuffer.lazycache[wop] = ret = super().__new__(cls) # noqa: F841, pylint: disable=W0612
    return LazyBuffer.lazycache[wop]

  def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
    if getattr(self, 'device', None) is not None: return # cache hit, we return and don't reinit
    self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
    self.shape = self.st.shape
    self.optype, self.op = optype, op
    self.shape, self.optype, self.op = self.st.shape, optype, op
    self.realized : Optional[DeviceBuffer] = None
    self.device, self.dbuffer = device, Device._buffers[device]
    self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
@@ -239,69 +238,70 @@ class LazyBuffer:

  @staticmethod
  def fromCPU(x, device): return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()))
  def toCPU(x): return x.realize().toCPU()
  def toCPU(self): return self.realize().toCPU()

  def unary_op(x:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, x)
  def binary_op(x:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, x, y)
  def contiguous_op(x:LazyBuffer) -> LazyBuffer: return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)
  def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
  def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
  def contiguous_op(self:LazyBuffer) -> LazyBuffer: return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP)

  # TODO: permute to put all the reduce axis at the end
  def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
    if x.shape == tuple(new_shape): return x
    if getattr(x.dbuffer, "REQUIRES_SIMPLE_REDUCE", False) and (len(new_shape) != 2 or new_shape[1] != 1):
      num, red = prod([s for s,n in zip(x.shape, new_shape) if n != 1]), prod([s for s,n in zip(x.shape, new_shape) if n == 1])
      x = x.movement_op(MovementOps.PERMUTE, [i for i,n in enumerate(new_shape) if n != 1] + [i for i,n in enumerate(new_shape) if n == 1])
  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
    if self.shape == tuple(new_shape): return self
    if getattr(self.dbuffer, "REQUIRES_SIMPLE_REDUCE", False) and (len(new_shape) != 2 or new_shape[1] != 1):
      num, red = prod([s for s,n in zip(self.shape, new_shape) if n != 1]), prod([s for s,n in zip(self.shape, new_shape) if n == 1])
      x = self.movement_op(MovementOps.PERMUTE, [i for i,n in enumerate(new_shape) if n != 1] + [i for i,n in enumerate(new_shape) if n == 1])
      x = x.movement_op(MovementOps.RESHAPE, (num, red)) # remove this reshape, at the end is enough
      return x.reduce_op(op, (num, 1)).movement_op(MovementOps.RESHAPE, new_shape)
    else:
      return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape)))
      return LazyBuffer(self.device, tuple(new_shape), ReduceOps, LazyOp(op, (self,), tuple(new_shape)))

  # syntactic sugar around PAD and SHRINK
  # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
  def slice(x:LazyBuffer, arg):
    padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
    return x.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
  def slice(self:LazyBuffer, arg):
    padding = [(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg)]
    return self.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))

  def movement_op(x:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
  def movement_op(self:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
    # TODO: look into why that copy is needed
    arg = tuple(copy(arg))

    # instant nops
    if op in [MovementOps.RESHAPE, MovementOps.EXPAND] and arg == x.shape: return x
    if op == MovementOps.PERMUTE and arg == tuple(range(len(x.shape))): return x
    if op == MovementOps.SHRINK and arg == tuple((0,i) for i in x.shape): return x
    if op == MovementOps.PAD and arg == tuple((0,0) for _ in x.shape): return x
    if op == MovementOps.FLIP and all(s == 1 or i not in arg for i,s in enumerate(x.shape)): return x
    if op in [MovementOps.RESHAPE, MovementOps.EXPAND] and arg == self.shape: return self
    if op == MovementOps.PERMUTE and arg == tuple(range(len(self.shape))): return self
    if op == MovementOps.SHRINK and arg == tuple((0,i) for i in self.shape): return self
    if op == MovementOps.PAD and arg == tuple((0,0) for _ in self.shape): return self
    if op == MovementOps.FLIP and all(s == 1 or i not in arg for i,s in enumerate(self.shape)): return self

    # two ops in a row is one op
    if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK] and x.realized is None and x.op.op == op: return x.op.src[0].movement_op(op, arg)
    if op == MovementOps.PERMUTE and x.realized is None and x.op.op == op: return x.op.src[0].movement_op(op, tuple(x.op.arg[i] for i in arg))
    if op == MovementOps.PAD and x.realized is None and x.op.op == op: return x.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(x.op.arg, arg)))
    if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK] and self.realized is None and self.op.op == op: return self.op.src[0].movement_op(op, arg)
    if op == MovementOps.PERMUTE and self.realized is None and self.op.op == op: return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
    if op == MovementOps.PAD and self.realized is None and self.op.op == op: return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))

    # some permutes are actually just reshapes
    if op == MovementOps.PERMUTE and ShapeTracker(x.shape).movement_op(op, arg).contiguous: return x.movement_op(MovementOps.RESHAPE, tuple(x.shape[i] for i in arg))
    if op == MovementOps.PERMUTE and ShapeTracker(self.shape).movement_op(op, arg).contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))

    if SHUFFLE_MOVEMENT_OPS and x.optype == BinaryOps and x.realized is None and len(x.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
      # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
      def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
        if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
        assert isinstance(y.op, BinaryOps) or isinstance(y.op, UnaryOps)
        return elementwise_op(y.op, *[replace_with_movement_op(z) for z in y.src])
      return replace_with_movement_op(x.op)
        assert y.op in BinaryOps or y.op in UnaryOps
        return elementwise_op(y.op, *[replace_with_movement_op(z) for z in y.src]) # type: ignore
      return replace_with_movement_op(self.op)

    # create the buffer
    ret = LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps, LazyOp(op, (x,), arg))
    ret = LazyBuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg))

    # NOTE: if ret is in the cache, it can already be realized
    if REMOVE_MOVEMENT_NOPS and ret.realized is None and x.realized is None and ret.st.contiguous:
    if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
      # MovementOps aren't stacked any more, they each have one parent, find the root
      root = get_movementroot(x)
      if root.st.contiguous and root != x and prod(ret.st.shape) == prod(root.shape):
      root = get_movementroot(self)
      if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
        return root.movement_op(MovementOps.RESHAPE, ret.st.shape) if ret.st.shape != root.shape else root

    return ret

  def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
  def processing_op(self:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
    x = self
    # TODO: fixup C?
    if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False): x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
@@ -83,15 +83,11 @@ class ShapeTracker:
  def offset(self): return self.views[-1].offset

  def expr(self): return ';'.join([v.expr for v in self.views[::-1] if v.expr != 'idx=idx' and v.expr != 'valid=valid'])
  def movement_op(self, op, arg): getattr(self, str(op).split(".")[1].lower())(*arg); return self
  def movement_op(self, op, arg):
    getattr(self, str(op).split(".")[1].lower())(*arg)
    return self
  def needs_valid(self): return any(isinstance(v, ZeroView) for v in self.views)

  # TODO: this is not really needed, only for testing
  def __getitem__(self, val):
    locals = {"idx": val, "valid": 1}
    exec(self.expr(), None, locals)
    return locals["idx"] if locals["valid"] else -1

  # TODO: do we really need this for conv?
  # if we replace, confirm the ops taken fold into one view
  def strided(self, *arg):
@@ -146,7 +146,7 @@ class Tensor:
    s = [[] for _ in range(len(args))]
    for i in range(len(self.shape)):
      if i != dim:
        assert self.shape[i] == y.shape[i]
        for y in args: assert self.shape[i] == y.shape[i]
        for j in range(len(args)):
          s[j].append((0, self.shape[i]))
      else:
@@ -161,18 +161,18 @@ class Tensor:
      ret += y.slice(arg=ts)
    return ret

  def matmul(x:Tensor, w:Tensor):
  def matmul(self:Tensor, w:Tensor):
    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
    bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
    bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
    cin, cout = w.shape[-2], w.shape[-1]
    out_shape_t = tuple(list(x.shape[0:-2])+[cout,-1])
    if len(x.shape) > 1: order = tuple(list(range(len(x.shape)-2))+[len(x.shape)-1, len(x.shape)-2])
    out_shape_t = tuple(list(self.shape[0:-2])+[cout,-1])
    if len(self.shape) > 1: order = tuple(list(range(len(self.shape)-2))+[len(self.shape)-1, len(self.shape)-2])
    else: order, out_shape_t = (0,), (cout, )
    worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])

    # NOTE: with NHWC we can remove the transposes
    # bs x groups*cin x H x W
    cx = x.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
    cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
    # groups*cout x cin x H, W
    cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
    return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
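A quick numpy check of the shape trick in the NOTE above (an m×k @ k×n matmul expressed as a 1×1 convolution over k input channels), written with einsum instead of tinygrad's ops:

```python
import numpy as np

m, k, n = 3, 4, 5
A = np.random.randn(m, k).astype(np.float32)
B = np.random.randn(k, n).astype(np.float32)

inp = A.T.reshape(1, k, m, 1)   # (1, k, m, 1): k input channels, m "pixels"
wt = B.T.reshape(n, k, 1, 1)    # (n, k, 1, 1): n output channels, 1x1 kernels

# a 1x1 conv is just a per-pixel contraction over the channel axis
out = np.einsum("bchw,ocij->bohw", inp, wt)   # shape (1, n, m, 1)
assert np.allclose(out[0, :, :, 0].T, A @ B, atol=1e-5)
```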
@@ -234,7 +234,7 @@ class Tensor:

  def __neg__(self): return 0.0-self
  def sqrt(self): return self.pow(0.5)
  def clip(self, min, max): return ((self-min).relu()+min) - (self-max).relu()
  def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu()
  def abs(self): return self.relu() + (-self).relu()
  def sign(self): return self / (self.abs() + 1e-10)
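The renamed arguments aside (avoiding the `min`/`max` builtin shadowing), `clip` is still built from two relus. A small numpy check of that identity, valid when min_ <= max_:

```python
import numpy as np

def relu(v): return np.maximum(v, 0)

def clip_via_relu(x, lo, hi):
  # mirrors Tensor.clip above: ((x-lo).relu()+lo) - (x-hi).relu()
  return (relu(x - lo) + lo) - relu(x - hi)

x = np.linspace(-5, 5, 11)
assert np.allclose(clip_via_relu(x, -1.0, 2.0), np.clip(x, -1.0, 2.0))
```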
@@ -247,7 +247,7 @@ class Tensor:
  def relu6(self): return self.relu() - (self-6).relu()
  def hardswish(self): return self * (self+3).relu6() * (1/6)
  def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
  def gelu(x): return 0.5 * x * (1 + (x * 0.7978845608 * (1 + 0.044715 * x * x)).tanh())
  def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
  def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
  def mish(self): return self * self.softplus().tanh()
  def softplus(self, limit=20, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
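The new `gelu` is the usual tanh approximation (0.7978845608 ≈ sqrt(2/π)); a quick standard-library check against the exact erf-based definition:

```python
import math

def gelu_tanh(x):   # the tanh approximation used in the diff
  return 0.5 * x * (1 + math.tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))

def gelu_exact(x):  # reference definition via the Gaussian CDF
  return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))

for x in (-3.0, -1.0, 0.0, 0.5, 2.0):
  assert abs(gelu_tanh(x) - gelu_exact(x)) < 1e-2
```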
@@ -285,8 +285,8 @@ class Tensor:

  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)

  def layernorm(x, eps=1e-5):
    y = (x - x.mean(axis=-1, keepdim=True))
  def layernorm(self, eps=1e-5):
    y = (self - self.mean(axis=-1, keepdim=True))
    return y.div((y*y).mean(axis=-1, keepdim=True).add(eps).sqrt())
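`layernorm` only changes its receiver name here; the computation is mean-centering followed by division by the rms of the centered values (with eps inside the square root). A numpy restatement and rough sanity check:

```python
import numpy as np

def layernorm_ref(x, eps=1e-5):
  # same computation as Tensor.layernorm above, written with numpy
  y = x - x.mean(axis=-1, keepdims=True)
  return y / np.sqrt((y * y).mean(axis=-1, keepdims=True) + eps)

x = np.random.randn(2, 8).astype(np.float32)
out = layernorm_ref(x)
# each row now has (approximately) zero mean and unit variance
assert np.allclose(out.mean(axis=-1), 0, atol=1e-5)
assert np.allclose(out.var(axis=-1), 1, atol=1e-3)
```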
# An instantiation of the Function is the Context