mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
uops can have multiple metadata (#10479)
* uops can have multiple metadata * fixups
This commit is contained in:
@@ -775,9 +775,10 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
def setUp(self) -> None: _METADATA.set(None)
|
||||
|
||||
# NOOPs are not included in kernel metadata
|
||||
@unittest.skip("why would this be true?")
|
||||
def test_exclude_noop_metadata(self):
|
||||
a = Tensor.rand(4, 4)*1
|
||||
self.assertEqual(a.lazydata.metadata.name, "__mul__")
|
||||
self.assertEqual(a.lazydata.metadata[0].name, "__mul__")
|
||||
k = a.schedule()[-1]
|
||||
self.assertEqual([m.name for m in k.metadata], ["rand"])
|
||||
|
||||
@@ -794,7 +795,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.lazydata.metadata.name, "matmul")
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "matmul")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
@@ -802,7 +803,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
def test_relu(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.lazydata.metadata.name, "relu")
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "relu")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
@@ -811,9 +812,9 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu() * y.sigmoid()
|
||||
self.assertEqual(out.lazydata.metadata.name, "__mul__")
|
||||
self.assertEqual(out.lazydata.src[0].metadata.name, "relu")
|
||||
self.assertEqual(out.lazydata.src[1].metadata.name, "sigmoid")
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "__mul__")
|
||||
self.assertEqual(out.lazydata.src[0].metadata[0].name, "relu")
|
||||
self.assertEqual(out.lazydata.src[1].metadata[0].name, "sigmoid")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
@@ -822,17 +823,17 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True).realize()
|
||||
y = Tensor.rand(3, requires_grad=True).realize()
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertEqual(out.lazydata.metadata.name, "sum")
|
||||
self.assertEqual(out.lazydata.metadata[0].name, "sum")
|
||||
out.backward()
|
||||
self.assertEqual(x.grad.lazydata.metadata.name, "relu")
|
||||
self.assertTrue(x.grad.lazydata.metadata.backward)
|
||||
self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid")
|
||||
self.assertTrue(y.grad.lazydata.metadata.backward)
|
||||
self.assertEqual(x.grad.lazydata.metadata[0].name, "relu")
|
||||
self.assertTrue(x.grad.lazydata.metadata[0].backward)
|
||||
self.assertEqual(y.grad.lazydata.metadata[0].name, "sigmoid")
|
||||
self.assertTrue(y.grad.lazydata.metadata[0].backward)
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"})
|
||||
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
|
||||
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
|
||||
bw = [m for m in si.metadata if m.backward]
|
||||
self.assertEqual(len(bw), 1)
|
||||
self.assertEqual(len(bw), 2)
|
||||
self.assertEqual(bw[0].name, "sigmoid")
|
||||
|
||||
class TestIdxUpcast(unittest.TestCase):
|
||||
|
||||
@@ -243,20 +243,20 @@ class Kernel:
|
||||
|
||||
def create_kernel(x:UOp, b:UOp|None=None):
|
||||
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=x.metadata) else ()))
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
return buffer.assign(kernel).reshape(x.shape)
|
||||
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
|
||||
def append_to_kernel(x:UOp):
|
||||
new_srcs: list[UOp] = []
|
||||
metadata = dict.fromkeys(x.arg.metadata)
|
||||
metadata = x.arg.metadata
|
||||
for s in x.src:
|
||||
if s.op in DONT_PLACE_IN_KERNEL or s.op is Ops.GBARRIER: new_srcs.append(s)
|
||||
else:
|
||||
new_srcs.extend(s.src)
|
||||
if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=s.metadata): metadata[m] = None
|
||||
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata)))
|
||||
if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=s.metadata): metadata += m
|
||||
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
|
||||
|
||||
create_kernels = PatternMatcher([
|
||||
# always give assign/contiguous a kernel
|
||||
|
||||
@@ -63,5 +63,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
|
||||
if v is None: continue
|
||||
if k in grads: grads[k] = grads[k] + v
|
||||
else: grads[k] = v
|
||||
if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
|
||||
if len(forward_metadata:=all_metadata.get(t0, ())): all_metadata[v] = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
|
||||
return grads
|
||||
|
||||
@@ -177,7 +177,7 @@ class Tensor(MathTrait):
|
||||
|
||||
def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
|
||||
new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
|
||||
if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = metadata
|
||||
if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
|
||||
needs_input_grad = [t.requires_grad for t in (self,)+x]
|
||||
return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
class UOpMetaClass(type):
|
||||
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
|
||||
metadata:Metadata|None=None, _buffer:Buffer|None=None):
|
||||
metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
|
||||
for s in src: s.children.add(ref)
|
||||
@@ -234,7 +234,7 @@ class UOpMetaClass(type):
|
||||
|
||||
# some uops map to other stuff
|
||||
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
|
||||
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary() # TODO: should this be here?
|
||||
all_metadata:weakref.WeakKeyDictionary[UOp, tuple[Metadata, ...]] = weakref.WeakKeyDictionary() # TODO: should this be here?
|
||||
|
||||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
@@ -251,8 +251,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
for s in self.src: s.children.discard(ref)
|
||||
del UOpMetaClass.ucache[k]
|
||||
def __reduce__(self):
|
||||
args = [self.op, self.dtype, self.src, self.arg, self.tag]
|
||||
args.append(self.metadata)
|
||||
args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]
|
||||
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
|
||||
return UOp, tuple(args)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
@@ -494,7 +493,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg)
|
||||
def clone(self) -> UOp: return self.copy_to_device(self.device)
|
||||
@property
|
||||
def metadata(self) -> Metadata|None: return all_metadata.get(self, None)
|
||||
def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
@@ -990,7 +989,7 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
new_map = {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in sink.toposort()}
|
||||
all_metadata.update((v, k.metadata) for k,v in reversed(new_map.items()) if k.metadata is not None)
|
||||
all_metadata.update((v, tuple(dedup(all_metadata.get(v, ())+k.metadata))) for k,v in new_map.items() if k.metadata is not None)
|
||||
if input_map is not None:
|
||||
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
|
||||
return new_map
|
||||
|
||||
Reference in New Issue
Block a user