diff --git a/test/test_ops.py b/test/test_ops.py index 742d40e058..d9713781f3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -889,6 +889,8 @@ class TestOps(unittest.TestCase): def test_sum_simple(self): helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]]) + # NOTE: simple test for locals + # FORWARD_ONLY=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sum_full def test_sum_full(self): helper_test_op([(16384)], lambda x: x.sum()) def test_sum_relu(self): diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index fd72d24f1e..51570b21a8 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -385,7 +385,7 @@ class TestUOpGraph(unittest.TestCase): def test_fold_gated_load_local(self): glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) - smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1)) + smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1)) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int))) barrier = UOp(UOps.BARRIER, dtypes.void, (st, )) @@ -650,7 +650,7 @@ def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + class TestIFUOps(unittest.TestCase): def test_create_ifs(self): gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4)) + sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) gate = valid&(lidx.ne(2)) @@ -669,7 +669,7 @@ class TestIFUOps(unittest.TestCase): def test_expand_ifs_one_gate(self): gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16)) + sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) gate = valid&(lidx.ne(2)) diff --git a/test/test_uops.py b/test/test_uops.py index c3c547ba8e..b191fac730 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_basic(self): uops = [] - smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16)) + smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16)) st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0))) barr = uop(uops, UOps.BARRIER, dtypes.void, (st,)) sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr)) @@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_indirect(self): uops = [] - smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16)) + smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16)) st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2))) st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42))) barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2)) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 9b4f06f2fe..1325556bc9 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -702,7 +702,7 @@ class Kernel: st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD] local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape)) st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop() - membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size())) + membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size())) local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn) srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store))) else: @@ -732,7 +732,7 @@ class Kernel: for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \ (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)]) st_uop = ShapeTracker.from_shape(local_shape).to_uop() - local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size())) + local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size())) local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start))) grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis)) if op is self.reduceops[-1]: return grouped_reduce diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index ca92e6adee..3440b6c7a4 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -25,17 +25,20 @@ class DType: class ImageDType(DType): shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape base: DType + local: bool = False # images are never local def scalar(self): return self.base def vec(self, sz:int): return self.base.vec(sz) def __repr__(self): return f"dtypes.{self.name}({self.shape})" # @dataclass(frozen=True, init=False, repr=False, eq=False) class PtrDType(DType): - def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count) + def __init__(self, dt:DType, local=False): + self.base, self.local = dt, local + super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count) def __hash__(self): return super().__hash__() def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count def __ne__(self, dt): return not (self == dt) - def __repr__(self): return f"PtrDType({super().__repr__()})" + def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})" class dtypes: @staticmethod diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 308e6dd680..9e2d6cdfd1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -572,8 +572,8 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: # this is the matcher for the final rendered UOps # matcher functions returns True or False (or None to not match) spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [ - (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))), - (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)), + (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local), + (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local), (UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True), lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype), (UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index ddf3d6643c..f746d032a0 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -56,14 +56,14 @@ class CStyleLanguage(Renderer): return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val) # returns a str expression of the loaded value with the output type - def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: + def render_load(self, output_dtype, buf_name, buf_dtype, idx) -> str: if isinstance(buf_dtype, ImageDType): assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}" return f"read_imagef({buf_name}, smp, {idx})" if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16: return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})" if output_dtype.count > 1: - return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501 + return f"*(({self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501 return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" def get_kernel_modifier(self, uops:List[UOp]) -> str: return "" @@ -78,14 +78,14 @@ class CStyleLanguage(Renderer): return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" # returns a str statement that does the store - def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str: + def render_store(self, buf_name:str, buf_dtype:Union[ImageDType, PtrDType], var_name:str, var_dtype:DType, idx:str) -> str: if isinstance(buf_dtype, ImageDType): assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}" return f"write_imagef({buf_name}, {idx}, {var_name});" if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16: return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});" if var_dtype.count > 1: - prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix + prefix = self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};" return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};" @@ -124,8 +124,9 @@ class CStyleLanguage(Renderer): kk("}") elif uop is UOps.STORE: # mark DEFINE_GLOBAL buf as writable + assert isinstance(src[0].dtype, (ImageDType, PtrDType)) if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True)) - rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) + rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]])) kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store) else: if uop is UOps.RANGE: @@ -150,7 +151,7 @@ class CStyleLanguage(Renderer): bufs[u] = (args[0], (dtype,False)) r[u] = args[0] elif uop is UOps.LOAD: - val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) + val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]])) # NOTE: this relies on the load not happening if it's in the unselected branch if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype) kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")