mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
put local on the PtrDType [run_process_replay] (#6656)
* put local on the PtrDType [run_process_replay]
* those are local too
This commit is contained in:
@@ -889,6 +889,8 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_sum_simple(self):
|
||||
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])
|
||||
# NOTE: simple test for locals
|
||||
# FORWARD_ONLY=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sum_full
|
||||
def test_sum_full(self):
|
||||
helper_test_op([(16384)], lambda x: x.sum())
|
||||
def test_sum_relu(self):
|
||||
|
||||
@@ -385,7 +385,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
|
||||
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1))
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
|
||||
@@ -650,7 +650,7 @@ def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander +
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid&(lidx.ne(2))
|
||||
@@ -669,7 +669,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gate = valid&(lidx.ne(2))
|
||||
|
||||
@@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_basic(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16))
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16))
|
||||
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
|
||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
|
||||
@@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16))
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16))
|
||||
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
|
||||
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
|
||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
|
||||
|
||||
@@ -702,7 +702,7 @@ class Kernel:
|
||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
|
||||
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
|
||||
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
|
||||
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
||||
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
||||
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
|
||||
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
|
||||
else:
|
||||
@@ -732,7 +732,7 @@ class Kernel:
|
||||
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
|
||||
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
|
||||
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
|
||||
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
|
||||
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
|
||||
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
|
||||
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
|
||||
@@ -25,17 +25,20 @@ class DType:
|
||||
class ImageDType(DType):
|
||||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
|
||||
base: DType
|
||||
local: bool = False # images are never local
|
||||
def scalar(self): return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
# @dataclass(frozen=True, init=False, repr=False, eq=False)
|
||||
class PtrDType(DType):
|
||||
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
def __init__(self, dt:DType, local=False):
|
||||
self.base, self.local = dt, local
|
||||
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
def __hash__(self): return super().__hash__()
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
def __repr__(self): return f"PtrDType({super().__repr__()})"
|
||||
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
|
||||
@@ -572,8 +572,8 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
# this is the matcher for the final rendered UOps
|
||||
# matcher functions returns True or False (or None to not match)
|
||||
spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))),
|
||||
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)),
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True),
|
||||
lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
||||
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
@@ -56,14 +56,14 @@ class CStyleLanguage(Renderer):
|
||||
return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
|
||||
|
||||
# returns a str expression of the loaded value with the output type
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
|
||||
return f"read_imagef({buf_name}, smp, {idx})"
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.count > 1:
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
|
||||
return f"*(({self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
|
||||
return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
|
||||
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
|
||||
@@ -78,14 +78,14 @@ class CStyleLanguage(Renderer):
|
||||
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
||||
|
||||
# returns a str statement that does the store
|
||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
|
||||
def render_store(self, buf_name:str, buf_dtype:Union[ImageDType, PtrDType], var_name:str, var_dtype:DType, idx:str) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
|
||||
return f"write_imagef({buf_name}, {idx}, {var_name});"
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.count > 1:
|
||||
prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
|
||||
prefix = self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix
|
||||
return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
|
||||
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
|
||||
@@ -124,8 +124,9 @@ class CStyleLanguage(Renderer):
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
# mark DEFINE_GLOBAL buf as writable
|
||||
assert isinstance(src[0].dtype, (ImageDType, PtrDType))
|
||||
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
|
||||
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]))
|
||||
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
|
||||
else:
|
||||
if uop is UOps.RANGE:
|
||||
@@ -150,7 +151,7 @@ class CStyleLanguage(Renderer):
|
||||
bufs[u] = (args[0], (dtype,False))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]))
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
|
||||
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
|
||||
Reference in New Issue
Block a user