mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove becomes map (#10533)
* remove becomes map * add comment and delete dead code * multi is a view
This commit is contained in:
@@ -3,7 +3,6 @@ from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
class TestDataset(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_dataset_is_realized(self):
|
||||
X_train, _, _, _ = mnist()
|
||||
X_train[0].contiguous().realize()
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
|
||||
from tinygrad.uop.ops import can_pad, sint, track_rewrites, _substitute
|
||||
from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put, flatten
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.engine.multi import multi_pm, replace_allreduce
|
||||
@@ -67,9 +67,6 @@ sym = symbolic_simple+PatternMatcher([
|
||||
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
|
||||
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
# MULTI in SINK just flattens srcs
|
||||
(UPat(Ops.SINK, name="x"),
|
||||
lambda x: UOp.sink(*new_src) if (new_src:=tuple(flatten([s.src if s.op is Ops.MULTI else [s] for s in x.src]))) != x.src else None),
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
@@ -241,7 +238,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
return buffer.assign(kernel).reshape(x.shape)
|
||||
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT}
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MULTI}
|
||||
def append_to_kernel(x:UOp):
|
||||
new_srcs: list[UOp] = []
|
||||
metadata = x.arg.metadata
|
||||
@@ -582,14 +579,4 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
kernel_asts = [u.arg.ast for u in sched_sink.toposort() if u.op is Ops.KERNEL]
|
||||
PROCESS_REPLAY_CAPTURE[id(big_sink)] = pickle.dumps((big_sink, ContextVar._cache, kernel_asts))
|
||||
|
||||
# map tensors to buffer/assign/const
|
||||
# TODO: this is not right, and causes TestDataset.test_dataset_is_realized to fail unless I unprincipledly add Ops.COPY, which breaks others
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for k,v in tensor_map.items():
|
||||
if k is v: continue
|
||||
op = v.base.op
|
||||
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
|
||||
if op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
|
||||
if op is Ops.MULTI and all(x.base in becomes_map for x in v.base.src): becomes_map[k] = v
|
||||
|
||||
return becomes_map
|
||||
return tensor_map
|
||||
|
||||
@@ -501,6 +501,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def base(self) -> UOp:
|
||||
if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
|
||||
if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW
|
||||
return self
|
||||
def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
||||
|
||||
|
||||
@@ -39,6 +39,8 @@ buffer_spec = PatternMatcher([
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
|
||||
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
|
||||
# allow VIEW here. TODO: what views specifically are allowed? does this mess with gradient?
|
||||
(UPat(Ops.VIEW), lambda: True),
|
||||
])
|
||||
|
||||
def validate_kernel(k:UOp):
|
||||
|
||||
Reference in New Issue
Block a user