FROM -> COPY, move vars_from_ast (#2675)

George Hotz
2023-12-07 16:32:30 -08:00
committed by GitHub
parent 51af99367f
commit 00d9eda961
13 changed files with 22 additions and 26 deletions

View File

@@ -50,7 +50,7 @@ jobs:
- name: Test Docs
run: |
python docs/abstractions.py
-python docs/beautiful.py
+python docs/abstractions2.py
- name: Test Quickstart
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py && PYTHONPATH=. python quickstart.py
- name: Fuzz Test symbolic

View File

@@ -23,7 +23,7 @@ repos:
name: docs
entry: |
python3 docs/abstractions.py
-python3 docs/beautiful.py
+python3 docs/abstractions2.py
language: system
always_run: true
pass_filenames: false

View File

@@ -104,7 +104,7 @@ class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto
class ReduceOps(Enum): SUM = auto(); MAX = auto()
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
-class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
# NOTE: if you have a CompiledBuffer(DeviceBuffer)
# you do not need to implement the MovementOps
# as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
@@ -133,9 +133,9 @@ assert len(lazyop.src) == 2
# the first source is the 2, it comes from the CPU
# the source is a LazyBuffer that is a "CPU" Tensor
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
-assert lazyop.src[0].op.op == LoadOps.FROM
+assert lazyop.src[0].op.op == LoadOps.COPY
assert lazyop.src[0].op.src[0].device == "CPU"
-assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
+assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the COPY LazyOP is a LazyBuffer on the CPU holding [2.]"
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
# now we realize the LazyBuffer
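As a minimal, self-contained sketch of the idea behind this rename (toy classes, not tinygrad's real LazyBuffer; Buf and to_device are illustrative names invented here): a kernel can only read buffers on its own device, so a source living on another device gets wrapped in a COPY load op, which is exactly the node the asserts above inspect.

from enum import Enum, auto

class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto()

class Buf:
  # toy stand-in for LazyBuffer: a device tag plus the op that produced it
  def __init__(self, device, op=None, src=()): self.device, self.op, self.src = device, op, src

def to_device(buf, device):
  # a kernel on `device` can only read buffers on `device`; foreign buffers get a COPY node
  return buf if buf.device == device else Buf(device, LoadOps.COPY, (buf,))

cpu_two = Buf("CPU")             # the constant 2 above, living on the CPU
src = to_device(cpu_two, "GPU")  # becomes a COPY node whose source stays on the CPU
assert src.op is LoadOps.COPY and src.src[0].device == "CPU"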

View File

@@ -1,15 +1,13 @@
from typing import List
from extra.models.resnet import ResNet50
from tinygrad.tensor import Tensor
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, vars_from_ast
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import ansilen, DEBUG, getenv
-from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
if __name__ == "__main__":
mdl = ResNet50()
seen = set()

View File

@@ -7,7 +7,7 @@ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tupl
from tinygrad.graph import print_tree
from tinygrad.helpers import getenv
from tinygrad.device import Device, Compiled, Interpreted
-from tinygrad.lazy import vars_from_ast
+from tinygrad.ops import vars_from_ast
device = Device[Device.DEFAULT]

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
-from tinygrad.lazy import vars_from_ast
-from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
+from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, vars_from_ast
from tinygrad.device import Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG, round_up
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction

View File

@@ -6,11 +6,10 @@ from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten
-from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps
+from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
from tinygrad.codegen.kernel import LocalBuffer, Kernel
-from tinygrad.lazy import vars_from_ast
from tinygrad.features.image import to_image_idx
# bottom ones are asm only

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes, flat_mv, ImageDType
from tinygrad.shape.symbolic import Variable, sym_infer, sint
-from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op
+from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast
if TYPE_CHECKING:
from tinygrad.codegen.linearizer import Linearizer
@@ -233,7 +233,6 @@ class CompiledASTRunner(JITRunner):
if ast:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-from tinygrad.lazy import vars_from_ast
self.vars = vars_from_ast(ast)
assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
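The invariant guarded here: a compiled runner is built against purely symbolic Variables, and concrete values only arrive per call through var_vals. A rough illustration with a toy Variable class (hypothetical; real tinygrad Variables also carry min/max bounds):

class Var:
  # toy symbolic variable: an expression name and an optional bound value
  def __init__(self, expr, val=None): self.expr, self._val = expr, val

kernel_vars = [Var("i"), Var("j")]               # what vars_from_ast returns: unbound only
assert all(v._val is None for v in kernel_vars)  # build-time check, as in the assert above
var_vals = {v.expr: 4 for v in kernel_vars}      # call-time binding supplied alongside buffers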

View File

@@ -1,8 +1,7 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time, multiprocessing, traceback, signal
-from tinygrad.lazy import vars_from_ast
from tinygrad.device import Device, Compiled, Buffer
-from tinygrad.ops import MemBuffer
+from tinygrad.ops import MemBuffer, vars_from_ast
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.shape.symbolic import sym_infer

View File

@@ -5,9 +5,9 @@ from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int, ImageDType, DEBUG
-from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps, get_lazyop_info
+from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps, get_lazyop_info, vars_from_ast
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
-from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.shape.symbolic import sint
from tinygrad.device import Buffer
# lazy can recurse a lot
@@ -78,7 +78,7 @@ def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: ret
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
-# NOTE: this is the canonical order
-def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
@@ -206,9 +206,9 @@ class LazyBuffer:
return LazyBuffer.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
def copy_to_device(self, device:str) -> LazyBuffer:
-# back off a FROM if it's a double FROM
-if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
-return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous())
+# back off a COPY if it's a double COPY
+if not self.realized and self.op.op == LoadOps.COPY and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
+return LazyBuffer.loadop(LoadOps.COPY, self.shape, self.dtype, device, src=self.contiguous())
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST)
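The "double COPY" backoff above is a peephole rule: copying a buffer from device A to B and straight back to A should hand back the original buffer rather than schedule two transfers. A sketch of just that rule, reusing the toy Buf and LoadOps classes from the earlier sketch (the real method additionally requires the intermediate buffer to be unrealized):

def copy_to_device(buf, device):
  # A -> B -> A collapses to the original buffer on A; otherwise emit a COPY
  if buf.op is LoadOps.COPY and buf.src[0].device == device: return buf.src[0]
  return Buf(device, LoadOps.COPY, (buf,))

a = Buf("CPU")
roundtrip = copy_to_device(copy_to_device(a, "GPU"), "CPU")
assert roundtrip is a  # no transfer is scheduled for the round trip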

View File

@@ -17,7 +17,7 @@ class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
@@ -80,6 +80,8 @@ class LazyOp:
def shrink(self, _): raise NotImplementedError
def stride(self, _): raise NotImplementedError
+def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))
# **************** independent FlopCounter ****************
@dataclass
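vars_from_ast, now living in ops.py, defines the canonical variable order for a kernel: union the Variables referenced by every BufferOp's ShapeTracker, then sort by the string form of the expression so every consumer (Linearizer, runners, search) agrees on one ordering. A toy version over plain sets shows the shape of that computation (vars_from_ast_toy is a hypothetical name; real inputs are ShapeTracker variable sets, not strings):

def vars_from_ast_toy(per_buffer_vars):
  # union then sort by name: a deterministic order regardless of AST traversal order
  return sorted(set.union(*per_buffer_vars, set()), key=str)

assert vars_from_ast_toy([{"i", "j"}, {"j", "k"}]) == ["i", "j", "k"]
assert vars_from_ast_toy([]) == []  # the trailing set() covers ASTs with no BufferOps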

View File

@@ -12,9 +12,9 @@ class CustomOp(JITRunner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
-assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
+assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
if si.ast.op is LoadOps.EMPTY: return None
-if si.ast.op is LoadOps.FROM: return BufferCopy
+if si.ast.op is LoadOps.COPY: return BufferCopy
if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
return Device[si.out.device].get_runner(si.ast)