From 00d9eda961df53026e6b89dcf86306f5e506b0d2 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 7 Dec 2023 16:32:30 -0800
Subject: [PATCH] FROM -> COPY, move vars_from_ast (#2675)

---
 .github/workflows/test.yml              |  2 +-
 .pre-commit-config.yaml                 |  2 +-
 docs/abstractions.py                    |  6 +++---
 docs/{beautiful.py => abstractions2.py} |  0
 examples/handcode_resnet50_opt.py       |  4 +---
 test/external/fuzz_linearizer.py        |  2 +-
 tinygrad/codegen/kernel.py              |  3 +--
 tinygrad/codegen/linearizer.py          |  3 +--
 tinygrad/device.py                      |  3 +--
 tinygrad/features/search.py             |  3 +--
 tinygrad/lazy.py                        | 12 ++++++------
 tinygrad/ops.py                         |  4 +++-
 tinygrad/realize.py                     |  4 ++--
 13 files changed, 22 insertions(+), 26 deletions(-)
 rename docs/{beautiful.py => abstractions2.py} (100%)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 6eb14c7c7a..7345b9161f 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -50,7 +50,7 @@ jobs:
     - name: Test Docs
       run: |
         python docs/abstractions.py
-        python docs/beautiful.py
+        python docs/abstractions2.py
     - name: Test Quickstart
       run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py && PYTHONPATH=. python quickstart.py
     - name: Fuzz Test symbolic
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bf8e2d8791..b377e14ed3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
         name: docs
         entry: |
          python3 docs/abstractions.py
-         python3 docs/beautiful.py
+         python3 docs/abstractions2.py
        language: system
        always_run: true
        pass_filenames: false
diff --git a/docs/abstractions.py b/docs/abstractions.py
index 6971389de2..2608e1ef79 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -104,7 +104,7 @@ class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto
 class ReduceOps(Enum): SUM = auto(); MAX = auto()
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
-class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
 # NOTE: if you have a CompiledBuffer(DeviceBuffer)
 # you do not need to implement the MovementOps
 # as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
@@ -133,9 +133,9 @@ assert len(lazyop.src) == 2
 # the first source is the 2, it comes from the CPU
 # the source is a LazyBuffer that is a "CPU" Tensor
 # again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
-assert lazyop.src[0].op.op == LoadOps.FROM
+assert lazyop.src[0].op.op == LoadOps.COPY
 assert lazyop.src[0].op.src[0].device == "CPU"
-assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
+assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the COPY LazyOP is a LazyBuffer on the CPU holding [2.]"
 assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
 
 # now we realize the LazyBuffer
diff --git a/docs/beautiful.py b/docs/abstractions2.py
similarity index 100%
rename from docs/beautiful.py
rename to docs/abstractions2.py
diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py
index 720b9b0899..31c58bafd1 100644
--- a/examples/handcode_resnet50_opt.py
+++ b/examples/handcode_resnet50_opt.py
@@ -1,15 +1,13 @@
 from typing import List
 from extra.models.resnet import ResNet50
 from tinygrad.tensor import Tensor
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, vars_from_ast
 from tinygrad.device import Device, Compiled
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
 from tinygrad.helpers import ansilen, DEBUG, getenv
-from tinygrad.lazy import vars_from_ast
 from tinygrad.shape.symbolic import sym_infer
 
-
 if __name__ == "__main__":
   mdl = ResNet50()
   seen = set()
diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 133bc03642..f515fb01a1 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -7,7 +7,7 @@ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tupl
 from tinygrad.graph import print_tree
 from tinygrad.helpers import getenv
 from tinygrad.device import Device, Compiled, Interpreted
-from tinygrad.lazy import vars_from_ast
+from tinygrad.ops import vars_from_ast
 
 device = Device[Device.DEFAULT]
 
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 6d39053da0..f3be3bda08 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 import os, math, itertools
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
-from tinygrad.lazy import vars_from_ast
-from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
+from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, vars_from_ast
 from tinygrad.device import Device, Compiled
 from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG, round_up
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 901e9799ce..93821568a7 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -6,11 +6,10 @@ from enum import Enum, auto
 from dataclasses import dataclass
 
 from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten
-from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
-from tinygrad.lazy import vars_from_ast
 from tinygrad.features.image import to_image_idx
 
 # bottom ones are asm only
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 3600d7ac90..b6b40d78c9 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
 import importlib, inspect, functools, pathlib, time, re, ctypes
 from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes, flat_mv, ImageDType
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
-from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op
+from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast
 
 if TYPE_CHECKING:
   from tinygrad.codegen.linearizer import Linearizer
@@ -233,7 +233,6 @@ class CompiledASTRunner(JITRunner):
     if ast:
       info = get_lazyop_info(ast)
       self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-      from tinygrad.lazy import vars_from_ast
       self.vars = vars_from_ast(ast)
       assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
 
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 6aad28af93..21da35118d 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -1,8 +1,7 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, random, math, time, multiprocessing, traceback, signal
-from tinygrad.lazy import vars_from_ast
 from tinygrad.device import Device, Compiled, Buffer
-from tinygrad.ops import MemBuffer
+from tinygrad.ops import MemBuffer, vars_from_ast
 from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.codegen.linearizer import Linearizer, UOp
 from tinygrad.shape.symbolic import sym_infer
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index d7e54db32e..44a7ea778e 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -5,9 +5,9 @@ from weakref import ref, WeakSet, WeakValueDictionary
 import numpy as np
 
 from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int, ImageDType, DEBUG
-from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps, get_lazyop_info
+from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps, get_lazyop_info, vars_from_ast
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
-from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.shape.symbolic import sint
 from tinygrad.device import Buffer
 
 # lazy can recurse a lot
@@ -78,7 +78,7 @@ def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: ret
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
 
 # NOTE: this is the canonical order
-def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))
+
 lazycache: WeakValueDictionary = WeakValueDictionary()
 
 def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
@@ -206,9 +206,9 @@ class LazyBuffer:
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
 
   def copy_to_device(self, device:str) -> LazyBuffer:
-    # back off a FROM if it's a double FROM
-    if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
-    return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous())
+    # back off a COPY if it's a double COPY
+    if not self.realized and self.op.op == LoadOps.COPY and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
+    return LazyBuffer.loadop(LoadOps.COPY, self.shape, self.dtype, device, src=self.contiguous())
 
   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self  # all LoadOps are already contiguous (except CONST)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ec0e3775bd..8173090982 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -17,7 +17,7 @@ class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
 # Ops below this line are not allowed in ASTs
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
 
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
 OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
@@ -80,6 +80,8 @@ class LazyOp:
   def shrink(self, _): raise NotImplementedError
   def stride(self, _): raise NotImplementedError
 
+def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))
+
 # **************** independent FlopCounter ****************
 
 @dataclass
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index e43ac139a8..3fdb5da107 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -12,9 +12,9 @@ class CustomOp(JITRunner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
 
 def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
-  assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
+  assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
   if si.ast.op is LoadOps.EMPTY: return None
-  if si.ast.op is LoadOps.FROM: return BufferCopy
+  if si.ast.op is LoadOps.COPY: return BufferCopy
   if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
   return Device[si.out.device].get_runner(si.ast)
 
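
Usage sketch (not part of the diff; assumes a tinygrad checkout at this commit): the two caller-visible changes are that vars_from_ast now lives in tinygrad.ops rather than tinygrad.lazy, and that device transfers are tagged LoadOps.COPY instead of LoadOps.FROM.

    # sketch only -- demonstrates the relocated helper and the renamed enum member
    from tinygrad.ops import LazyOp, LoadOps, vars_from_ast  # was: from tinygrad.lazy import vars_from_ast

    # device-to-device transfers are now tagged LoadOps.COPY; LoadOps.FROM is gone
    assert hasattr(LoadOps, "COPY") and not hasattr(LoadOps, "FROM")

    # vars_from_ast still returns an AST's symbolic Variables in a canonical (name-sorted) order;
    # an AST with no BufferOps (e.g. a bare LoadOps.EMPTY node) references no Variables
    assert vars_from_ast(LazyOp(LoadOps.EMPTY, ())) == []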