remove becomes map (#10533)

* remove becomes map

* add comment and delete dead code

* multi is a view
George Hotz
2025-05-27 11:47:11 -07:00
committed by GitHub
parent 271110bb5a
commit 5b268121d4
4 changed files with 6 additions and 17 deletions

View File

@@ -3,7 +3,6 @@ from tinygrad.helpers import GlobalCounters
 from tinygrad.nn.datasets import mnist
 class TestDataset(unittest.TestCase):
-  @unittest.expectedFailure
   def test_dataset_is_realized(self):
     X_train, _, _, _ = mnist()
     X_train[0].contiguous().realize()

View File

@@ -4,7 +4,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
 from tinygrad.uop.ops import can_pad, sint, track_rewrites, _substitute
 from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction
 from tinygrad.codegen.symbolic import symbolic_simple
-from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put, flatten
+from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY
 from tinygrad.dtype import ImageDType
 from tinygrad.engine.multi import multi_pm, replace_allreduce
@@ -67,9 +67,6 @@ sym = symbolic_simple+PatternMatcher([
     and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
   # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
   (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
-  # MULTI in SINK just flattens srcs
-  (UPat(Ops.SINK, name="x"),
-   lambda x: UOp.sink(*new_src) if (new_src:=tuple(flatten([s.src if s.op is Ops.MULTI else [s] for s in x.src]))) != x.src else None),
   # reduce of size 0 is the identity element
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
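For reference, the deleted rule spliced each MULTI child's sources directly into the surrounding SINK. A minimal standalone sketch of that flattening, using a hypothetical Node dataclass in place of tinygrad's UOp (op names are plain strings here, not Ops members):

from dataclasses import dataclass

# hypothetical stand-in for UOp: just an op name and a tuple of sources
@dataclass(frozen=True)
class Node:
  op: str
  src: tuple["Node", ...] = ()

def flatten_multi_in_sink(sink: Node) -> Node | None:
  # splice each MULTI child's sources into the SINK, keep other children as-is
  new_src = tuple(s2 for s in sink.src for s2 in (s.src if s.op == "MULTI" else (s,)))
  # returning None means "no rewrite", mirroring the PatternMatcher convention
  return Node("SINK", new_src) if new_src != sink.src else None

a, b = Node("BUFFER"), Node("BUFFER")
assert flatten_multi_in_sink(Node("SINK", (Node("MULTI", (a, b)),))) == Node("SINK", (a, b))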
@@ -241,7 +238,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
   buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
   return buffer.assign(kernel).reshape(x.shape)
-DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT}
+DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MULTI}
 def append_to_kernel(x:UOp):
   new_srcs: list[UOp] = []
   metadata = x.arg.metadata
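Adding Ops.MULTI to this set means append_to_kernel now keeps a MULTI source as an external dependency of the kernel instead of walking into it, consistent with MULTI acting as a view over its shards. A simplified sketch of the gating this set performs; the N namedtuple and gather_kernel_srcs are illustrative names only, and the real append_to_kernel also tracks kernel metadata (as the hunk shows) and does more bookkeeping:

from collections import namedtuple

# hypothetical stand-in for UOp nodes: an op name and a tuple of sources
N = namedtuple("N", ["op", "src"])

DONT_PLACE_IN_KERNEL = {"KERNEL", "ASSIGN", "BUFFER", "MSELECT", "MULTI"}

def gather_kernel_srcs(srcs):
  # ops in the set stay as external dependencies of the kernel being built;
  # anything else is walked through and its own sources are pulled in instead
  out = []
  for s in srcs:
    out.extend([s] if s.op in DONT_PLACE_IN_KERNEL else gather_kernel_srcs(s.src))
  return out

buf = N("BUFFER", ())
multi = N("MULTI", (buf, N("BUFFER", ())))
# with MULTI in the set, the MULTI node itself is kept rather than its per-device sources
assert gather_kernel_srcs((multi,)) == [multi]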
@@ -582,14 +579,4 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
     kernel_asts = [u.arg.ast for u in sched_sink.toposort() if u.op is Ops.KERNEL]
     PROCESS_REPLAY_CAPTURE[id(big_sink)] = pickle.dumps((big_sink, ContextVar._cache, kernel_asts))
-  # map tensors to buffer/assign/const
-  # TODO: this is not right, and causes TestDataset.test_dataset_is_realized to fail unless I unprincipledly add Ops.COPY, which breaks others
-  becomes_map: dict[UOp, UOp] = {}
-  for k,v in tensor_map.items():
-    if k is v: continue
-    op = v.base.op
-    if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
-    if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
-    if op is Ops.MULTI and all(x.base in becomes_map for x in v.base.src): becomes_map[k] = v
-  return becomes_map
+  return tensor_map
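The deleted block post-filtered tensor_map, keeping a rewrite only when its target's base was a BUFFER, an ASSIGN, a CONST with an all-integer shape, or a MULTI of already-kept bases; the TODO notes that this dropped legitimate rewrites (e.g. ones landing on a COPY) and broke TestDataset.test_dataset_is_realized. After this commit the full tensor_map is returned and callers decide what to substitute. A rough miniature of the old filter, with a made-up Node class standing in for UOp (illustrative only):

from dataclasses import dataclass

# made-up stand-in for UOp: the op of the node's base and its shape are enough for the filter
@dataclass(frozen=True, eq=False)
class Node:
  base_op: str
  shape: tuple = ()

def old_becomes_map(tensor_map):
  # miniature of the deleted filter: keep a rewrite only if its target already looks realized
  out = {}
  for k, v in tensor_map.items():
    if k is v: continue
    if v.base_op in {"BUFFER", "ASSIGN"}: out[k] = v
    if v.base_op == "CONST" and all(isinstance(s, int) for s in v.shape): out[k] = v
  return out

t = Node("RESHAPE")
realized, pending = Node("BUFFER"), Node("COPY")
assert old_becomes_map({t: realized}) == {t: realized}  # survives the filter
assert old_becomes_map({t: pending}) == {}              # silently dropped, the problem the TODO describes
# after this commit there is no filter: get_kernelize_map returns tensor_map unchanged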

View File

@@ -501,6 +501,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def base(self) -> UOp:
     if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
+    if self.op is Ops.MULTI: return self.src[0].base  # MULTI is really a VIEW
     return self
   def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
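The added line makes base treat MULTI exactly like VIEW and the movement ops: a transparent wrapper whose base is the base of its first source. A tiny standalone illustration with a made-up Node class (not tinygrad's UOp):

class Node:
  def __init__(self, op, src=()):
    self.op, self.src = op, src
  @property
  def base(self):
    # VIEW and MULTI are both "transparent": defer to the first source's base
    if self.op in {"VIEW", "MULTI"} and self.src: return self.src[0].base
    return self

buf = Node("BUFFER")
multi = Node("MULTI", (Node("VIEW", (buf,)), Node("BUFFER")))
assert multi.base is buf  # MULTI -> first shard's VIEW -> underlying BUFFER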

View File

@@ -39,6 +39,8 @@ buffer_spec = PatternMatcher([
    lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
    lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
+  # allow VIEW here. TODO: what views specifically are allowed? does this mess with gradient?
+  (UPat(Ops.VIEW), lambda: True),
 ])
 def validate_kernel(k:UOp):
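buffer_spec is a validation PatternMatcher: each (UPat, fn) pair matches one node shape and returns True if it is well-formed, so an unconditional lambda: True simply whitelists an op. Presumably VIEW has to be allowed here now that MULTI behaves as a view and plain VIEW nodes can reach this graph; the TODO leaves open which views should really pass. A tiny standalone sketch of this validation style (a hypothetical dict-based mini spec, not tinygrad's PatternMatcher):

# hypothetical mini spec: map each op name to a predicate over the node
class Node:
  def __init__(self, op, arg=None): self.op, self.arg = op, arg

spec = {
  "BUFFER": lambda n: isinstance(n.arg, int),  # BUFFER must carry an integer size
  "VIEW":   lambda n: True,                    # allow VIEW unconditionally, like the new rule
}

def verify(nodes):
  # every node must match a rule in the spec and satisfy its predicate
  return all(spec.get(n.op, lambda n: False)(n) for n in nodes)

assert verify([Node("BUFFER", 1024), Node("VIEW")])
assert not verify([Node("BUFFER", "not an int")])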