From 83cdc85790488960a87f8c967f48e943f2cdc6fd Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 29 Feb 2024 15:22:26 -0800 Subject: [PATCH] add index to DEFINE_GLOBAL (#3542) * remove DEFINE_GLOBAL from uops with side effects * add index to DEFINE_GLOBAL * bugfix * better var name --- test/test_uops.py | 6 +++--- tinygrad/codegen/linearizer.py | 13 ++++++++++--- tinygrad/codegen/uops.py | 2 +- tinygrad/renderer/cstyle.py | 8 +++++--- tinygrad/runtime/ops_python.py | 2 +- 5 files changed, 20 insertions(+), 11 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index 48703932a3..0794e59f52 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -22,8 +22,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar def _test_single_value(vals, op, dts): uops = [] output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 'data0') - buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), f'data{i+1}') for i,dtype in enumerate(dts)] + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0')) + buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}')) for i,dtype in enumerate(dts)] loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) @@ -38,7 +38,7 @@ def _test_single_value(vals, op, dts): def _test_single_value_const(vals, op, dts): uops = [] output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 'data0') + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0')) loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 23a5f30c66..78edf9982d 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -190,13 +190,20 @@ class Linearizer(Kernel): self.loop_uops: Dict[str, UOp] = {} # add global buffers + buf_count = 0 + buf_index = {} for i,buf in enumerate(self.bufs): if isinstance(buf, MemBuffer): - self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}") + if buf.idx not in buf_index: + buf_index[buf.idx] = buf_count + buf_count += 1 + self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, + buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), + (buf_index[buf.idx], f"data{buf.idx}")) # add var vals - for var in self.ast.vars(): + for i,var in enumerate(self.ast.vars()): assert var.expr is not None - self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr) + self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (len(buf_index)+i, var.expr)) # define local buffers for lb in self.local_alias.values(): self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size)) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index b004ab8d9d..5513bd774f 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -85,7 +85,7 @@ def uops_type_verify(uops:List[UOp]): def uops_alu_resolve(u:UOp, vars:Dict[str, Variable]) -> sint: if u.uop == UOps.CONST: return u.arg - elif u.uop == UOps.DEFINE_GLOBAL: return vars[u.arg] + elif u.uop == UOps.DEFINE_GLOBAL: return vars[u.arg[1]] elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL: return uops_alu_resolve(u.vin[0], vars) * uops_alu_resolve(u.vin[1], vars) elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 1652b50c61..cc4bd3331f 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -89,7 +89,8 @@ class CStyleLanguage(NamedTuple): def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str: local_size: List[int] = [] - kernel,bufs = [],[] + kernel = [] + bufs: List[Tuple[str, DType]] = [] #pend_close = None depth = 1 def kk(s): kernel.append(" "*depth+s) @@ -161,8 +162,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st kk(lang.render_local(args[0], dtype, args[1])) r[u] = args[0] elif uop is UOps.DEFINE_GLOBAL: - bufs.append((args, dtype)) - r[u] = args + assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}" + bufs.append((args[1], dtype)) + r[u] = args[1] elif uop is UOps.WMMA: kk(f"{dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") elif uop is UOps.DEFINE_ACC: kk(f"{dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};") elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})" diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index b80bc7aad0..ef8f3d03de 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -92,7 +92,7 @@ class PythonProgram: dl[i] = dtype if uop is UOps.DEFINE_GLOBAL: assert dtype.fmt is not None - ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size + ul[i] = [pbufs[arg[0]].cast(dtype.fmt)] * warp_size elif uop is UOps.DEFINE_LOCAL: assert dtype.fmt is not None lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))