diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index abeb26d9a7..263465a477 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -131,17 +131,17 @@ class TestGraphRewrite(unittest.TestCase): self.assertEqual(nout.src[1].arg, 3.0) def test_consts_go_last(self): - a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1)) - b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1)) - c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1)) - d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1)) + a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('a', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) + b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('b', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) + c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('c', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) + d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('d', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] for out in outs: sink = graph_rewrite(out, constant_folder) print(sink) self.assertEqual(sink.op, UOps.ALU) self.assertEqual(sink.src[1].op, UOps.CONST) - self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3) + self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1) class TestUOpGraph(unittest.TestCase): def test_add_constant_fold(self): @@ -155,7 +155,7 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(out.arg, 3.0) def test_where_same_fold(self): - v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp(UOps.CONST, dtypes.int, (), 0), UOp(UOps.CONST, dtypes.int, (), 1)), arg=Variable('tmp', 0, 1)) + v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) c0 = UOp(UOps.CONST, dtypes.int, arg=0) vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) @@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -258,7 +258,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -268,19 +268,19 @@ class TestUOpGraph(unittest.TestCase): for i in [4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) + - tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2))) - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) + tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2))) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) for i in [4, 8]: - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) + - tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2))) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) + tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2))) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) @@ -288,17 +288,17 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) for i in [2, 4, 8]: - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) @@ -324,13 +324,13 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1) def test_depth_2_const_fold(self): - v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1)) + v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) c2 = UOp(UOps.CONST, dtypes.int, arg=2) c4 = UOp(UOps.CONST, dtypes.int, arg=4) vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD) out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD) uops = to_uops_list([out]) - self.assertEqual(len(uops), 5) + self.assertEqual(len(uops), 3) out = uops[-1] self.assertEqual(out.op, UOps.ALU) self.assertEqual(out.arg, BinaryOps.ADD) @@ -575,8 +575,8 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_load_dont_fold_different_gated(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1") - gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2") + gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) + gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),)) sink = float4_rewrite(sink) @@ -591,7 +591,7 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_store_fold_gate(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1") + gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(UOps.SINK, None, tuple(load)) sink = float4_rewrite(sink) @@ -602,8 +602,8 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_store_dont_fold(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1") - gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2") + gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) + gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] sink = UOp(UOps.SINK, None, tuple(load)) sink = float4_rewrite(sink) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 6f35e57efd..35ac40e278 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -30,7 +30,7 @@ def Variable(expr, nmin, nmax): # TODO: fix DEFINE_VAR to not need this class TempVar: def __init__(self, x): self.expr = x - return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr)) + return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(TempVar(expr), UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax))) class Node: @staticmethod def sum(ops): return functools.reduce(lambda x,y: x+y, ops) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e56d69b77d..31fd919409 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -333,7 +333,7 @@ class UOp(MathTrait): @functools.cached_property def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]: # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX - return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ + return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0].expr) if self.op is not UOps.ALU else \ self.arg.value, self.dtype, self.src) def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple @functools.cached_property @@ -366,7 +366,7 @@ class UOp(MathTrait): @classmethod def _const(cls, dtype:Optional[DType], b:ConstType|Variable): # TODO: fix dtype of b.max after Variable is just an UOp - if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, (cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))), b) + if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) if dtype is not None and dtype != (sdtype := dtype.scalar()): return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count))) return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) @@ -388,7 +388,7 @@ class UOp(MathTrait): def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR]) def variables(self) -> List[Variable]: st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS] - return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr) + return sorted(set.union(*st_vars, set([x.arg[0] for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr) def const_factor(self) -> int: """largest known int that divides self""" if self.op is UOps.CONST: return self.arg @@ -414,7 +414,8 @@ class UOp(MathTrait): @functools.cached_property def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: # NOTE: returned UOp is assumed to be CONST - if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None + # TODO: fix DEFINE_VAR arg in tests and remove checking len(self.arg) + if self.op is UOps.DEFINE_VAR and self.arg and len(self.arg) > 1: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax # TODO: UOps.SPECIAL is UOps.DEFINE_VAR if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None @@ -475,7 +476,8 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool, def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands)) def uop_alu_resolve(u:UOp) -> sint: - if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg + if u.op is UOps.CONST: return u.arg + if u.op is UOps.DEFINE_VAR: return u.arg[0] if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src))) raise RuntimeError(f"ALU resolve fail @ {u.op}") diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 87c856987a..4b9e064df3 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -34,7 +34,7 @@ class Program: if not self._ran_post_init and self.uops is not None: # single pass through the uops for u in self.uops: - if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg) + if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg[0]) if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg) if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL]) if u.op is UOps.SPECIAL: diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index f54fa19896..6d3addb7f0 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -212,9 +212,9 @@ class PTXRenderer(Renderer): r[u] = "%" + args[0] kernel = [f".reg .u32 %{args[0]};"] + kernel elif uop is UOps.DEFINE_VAR: - bufs.append((args.expr, dtype)) - r[u] = f"%{args.expr}" - kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param")) + bufs.append((args[0].expr, dtype)) + r[u] = f"%{args[0].expr}" + kk(*self.render_load(args[0].expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param")) elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True) elif uop is UOps.GEP: r[u] = r[src[0]][u.arg] elif uop is UOps.LOAD: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 4e184dc788..30b407e702 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -147,10 +147,10 @@ class CStyleLanguage(Renderer): kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */") r[u] = args[0] elif uop is UOps.DEFINE_VAR: - assert args.expr not in seen_vars, f"duplicate variable {args.expr}" - seen_vars.add(args.expr) - bufs[u] = (args.expr, (dtype,False)) - r[u] = args.expr + assert args[0].expr not in seen_vars, f"duplicate variable {args[0].expr}" + seen_vars.add(args[0].expr) + bufs[u] = (args[0].expr, (dtype,False)) + r[u] = args[0].expr elif uop is UOps.LOAD: val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) # NOTE: this relies on the load not happening if it's in the unselected branch diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 7607e82654..9fbe6c1191 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -19,7 +19,7 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx), LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)), Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \ - UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max)), self), + UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))), SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)), AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }