make uops.add private (#4950)

* make uops.add private

* modernize all tests
This commit is contained in:
George Hotz
2024-06-14 03:23:25 -07:00
committed by GitHub
parent dc9e9e4363
commit 9823752397
6 changed files with 94 additions and 106 deletions

View File

@@ -15,8 +15,8 @@ def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
u.vin = tuple(n if x == o else x for x in u.vin)
# pointer indexing
if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1:
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_before=uops.uops.index(u))
ptr = uops.add(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(u))
val = UOp(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_before=uops.uops.index(u))
ptr = UOp(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(u))
u.vin = (u.vin[0], ptr) + u.vin[2:]
#uops.print()

View File

@@ -141,7 +141,7 @@ class TestPatternMatcher(unittest.TestCase):
@unittest.skip("no longer supported")
def test_rewrite_graph_folds(self):
uops = UOpGraph()
uops.add(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
UOp(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float),
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
matcher.rewrite_graph(uops)
@@ -153,7 +153,7 @@ class TestPatternMatcher(unittest.TestCase):
@unittest.skip("no longer supported")
def test_rewrite_graph_adds(self):
uops = UOpGraph()
uops.add(UOps.CONST, dtypes.int, arg=2, simplify=False)
UOp(UOps.CONST, dtypes.int, arg=2, simplify=False)
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
matcher.rewrite_graph(uops)

View File

@@ -2,75 +2,69 @@ import unittest
from tinygrad import dtypes, Variable
from tinygrad.dtype import PtrDType
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
g = UOpGraph()
c1 = g.add(UOps.CONST, dtypes.float, arg=1.0)
c2 = g.add(UOps.CONST, dtypes.float, arg=2.0)
out = g.add(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
g.add(UOps.SINK, None, (out,))
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.arg, 3.0)
def test_where_same_fold(self):
g = UOpGraph()
v = g.add(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
c0 = g.add(UOps.CONST, dtypes.int, arg=0)
vc = g.add(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = g.add(UOps.CONST, dtypes.float, arg=1.0)
out = g.add(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
g.add(UOps.SINK, None, (out,))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.arg, 1.0)
def test_where_const_fold(self):
g = UOpGraph()
bf = g.add(UOps.CONST, dtypes.bool, arg=False)
c1 = g.add(UOps.CONST, dtypes.float, arg=1.0)
c2 = g.add(UOps.CONST, dtypes.float, arg=2.0)
out = g.add(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
g.add(UOps.SINK, None, (out,))
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.arg, 2.0)
def test_const_cast(self):
g = UOpGraph()
bf = g.add(UOps.CONST, dtypes.bool, arg=False)
out = g.add(UOps.CAST, dtypes.int, (bf,))
g.add(UOps.SINK, None, (out,))
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
out = UOp(UOps.CAST, dtypes.int, (bf,))
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.arg, 0)
def test_cast_vectorized_fold(self):
g = UOpGraph()
d0 = g.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
idx = g.add(UOps.CONST, dtypes.int, arg=0)
ld = g.add(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
cast = g.add(UOps.CAST, dtypes.float.vec(2), (ld,))
x = g.add(UOps.GEP, dtypes.float, (cast, ), arg=0)
alu = g.add(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = g.add(UOps.STORE, dtypes.float, (d0, idx, alu))
g.add(UOps.SINK, None, (out,))
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
idx = UOp(UOps.CONST, dtypes.int, arg=0)
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
cast = UOp(UOps.CAST, dtypes.float.vec(2), (ld,))
x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
self.assertEqual(len([x for x in g.uops if x.uop is UOps.CAST]), 0)
def test_depth_2_const_fold(self):
g = UOpGraph()
v = g.add(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
c2 = g.add(UOps.CONST, dtypes.int, arg=2)
c4 = g.add(UOps.CONST, dtypes.int, arg=4)
vc = g.add(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = g.add(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
g.add(UOps.SINK, None, (out,))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
self.assertEqual(len(g.uops), 3)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.ALU)

View File

@@ -14,8 +14,7 @@ from tinygrad.codegen.uops import UOpGraph
from test.helpers import is_dtype_supported
def _uops_to_prg(uops_list, print=False):
uops = UOpGraph()
for l in uops_list: uops.add(l.uop, l.dtype, l.vin, l.arg)
uops = UOpGraph(uops_list)
src = Device[Device.DEFAULT].renderer.render("test", uops)
if print: uops.print()
has_local = Device[Device.DEFAULT].renderer.has_local
@@ -32,10 +31,10 @@ def _test_single_value(vals, op, dts):
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, False)) for i,dtype in enumerate(dts)]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg(uops)
prg = _uops_to_prg([out])
prg.exec([buf]+buf2)
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
@@ -47,9 +46,9 @@ def _test_single_value_const(vals, op, dts):
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg(uops)
prg = _uops_to_prg([out])
prg.exec([buf])
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
@@ -59,9 +58,9 @@ def _test_uops_result(output_dtype, uops, res):
# uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
# res = output_fn(uops)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg(uops, print=True)
prg = _uops_to_prg([out], print=True)
prg.exec([buf])
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
@@ -235,14 +234,13 @@ class TestGatedStoreRewrite(unittest.TestCase):
@unittest.skip("not yet implemented")
def test_wrap_store_parents(self):
# wraps all store parents in the valid branch
uops = UOpGraph()
gmem = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
value = uops.add(UOps.CONST, dtypes.float, (), 42.0)
value = UOp(UOps.CONST, dtypes.float, (), 42.0)
gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops.add(UOps.STORE, None, (gmem, idx, value, gate))
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops = UOpGraph([UOp(UOps.STORE, None, (gmem, idx, value, gate))])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.uop is UOps.IF)
endif = next(u for u in uops if u.uop is UOps.ENDIF)
@@ -253,17 +251,17 @@ class TestGatedStoreRewrite(unittest.TestCase):
@unittest.skip("not yet implemented")
def test_wrap_some_parents(self):
# some parents are used outside the branch
uops = UOpGraph()
gmem0 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gmem1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
value0 = uops.add(UOps.CONST, dtypes.float, (), 42.0)
value1 = uops.add(UOps.CONST, dtypes.float, (), 43.0)
value0 = UOp(UOps.CONST, dtypes.float, (), 42.0)
value1 = UOp(UOps.CONST, dtypes.float, (), 43.0)
gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops.add(UOps.STORE, None, (gmem0, idx, value0, gate))
uops.add(UOps.STORE, None, (gmem1, idx, value1))
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
outs = [UOp(UOps.STORE, None, (gmem0, idx, value0, gate))]
outs.append(UOp(UOps.STORE, None, (gmem1, idx, value1)))
uops = UOpGraph(outs)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.uop is UOps.IF)
endif = next(u for u in uops if u.uop is UOps.ENDIF)
@@ -297,27 +295,25 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT in {"CUDA"} and getenv("PTX"), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_bitshift_left(self):
uops = UOpGraph()
g1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
c1 = uops.add(UOps.CONST, dtypes.int, (), 2)
c2 = uops.add(UOps.CONST, dtypes.int, (), 3)
l1 = uops.add(UOps.LOAD, dtypes.int, (g1, c1))
a1 = uops.add(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = uops.add(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops.add(UOps.SINK, None, (a1,a2))
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops.uops[-1].arg, BinaryOps.MUL)
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHL)
def test_bitshift_right(self):
uops = UOpGraph()
g1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
c1 = uops.add(UOps.CONST, dtypes.int, (), 2)
c2 = uops.add(UOps.CONST, dtypes.int, (), 3)
l1 = uops.add(UOps.LOAD, dtypes.int, (g1, c1))
a1 = uops.add(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = uops.add(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops.add(UOps.SINK, None, (a1,a2))
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops.uops[-1].arg, BinaryOps.IDIV)
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHR)

View File

@@ -2,7 +2,7 @@ import unittest
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.dtype import dtypes
@@ -55,26 +55,24 @@ class TestUOpsStats(unittest.TestCase):
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
uops = UOpGraph()
globl = uops.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
o1 = uops.add(UOps.CONST, dtypes.int, tuple(), 1)
o2 = uops.add(UOps.CONST, dtypes.int, tuple(), 2)
u1 = uops.add(UOps.LOAD, dtypes.int, (globl, o1))
u2 = uops.add(UOps.LOAD, dtypes.int, (globl, o2))
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), 3)
u4 = uops.add(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = uops.add(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops.add(UOps.SINK, None, (u5,))
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops = UOpGraph([UOp(UOps.SINK, None, (u5,))])
uops_fma = UOpGraph()
globl = uops_fma.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
o1 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 1)
o2 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 2)
u1 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o1))
u2 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o2))
u3 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 3)
u4 = uops_fma.add(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma.add(UOps.SINK, None, (u4,))
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma = UOpGraph([UOp(UOps.SINK, None, (u4,))])
self.assertEqual(uops.flops_mem(), uops_fma.flops_mem())

View File

@@ -254,12 +254,12 @@ class UOpGraph:
def __init__(self, add_nodes:Optional[List[UOp]]=None):
self.nodes: Dict[Tuple, UOp] = {}
self._uops: Optional[List[UOp]] = None
if add_nodes is not None: self.multiadd(add_nodes)
if add_nodes is not None: self._multiadd(add_nodes)
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
def __getitem__(self, index) -> UOp: return self.uops[index]
def multiadd(self, unprocessed_nodes:List[UOp]):
def _multiadd(self, unprocessed_nodes:List[UOp]):
# add nodes to graph in reverse BFS order
# TODO: i feel like this is written in a few places, possible to library it?
in_degree: DefaultDict[UOp, int] = defaultdict(int)
@@ -278,7 +278,7 @@ class UOpGraph:
while len(queue):
n = queue.pop(0)
if n in replace_nodes: continue
replace_nodes[n] = self.add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
replace_nodes[n] = self._add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
for x in children[n]:
in_degree[x] -= 1
if in_degree[x] == 0:
@@ -411,7 +411,7 @@ class UOpGraph:
if type_verify: self.type_verify()
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
def _add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
self.nodes[key] = ret = UOp(*key)
return ret