put local on the PtrDtype [run_process_replay] (#6656)

* put local on the PtrDtype [run_process_replay]

* those are local too
This commit is contained in:
George Hotz
2024-09-23 10:29:17 +08:00
committed by GitHub
parent 90c1ccc402
commit e945fa9c5c
7 changed files with 23 additions and 17 deletions

View File

@@ -889,6 +889,8 @@ class TestOps(unittest.TestCase):
def test_sum_simple(self):
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])
# NOTE: simple test for locals
# FORWARD_ONLY=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sum_full
def test_sum_full(self):
helper_test_op([(16384)], lambda x: x.sum())
def test_sum_relu(self):

View File

@@ -385,7 +385,7 @@ class TestUOpGraph(unittest.TestCase):
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1))
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
@@ -650,7 +650,7 @@ def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander +
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2))
@@ -669,7 +669,7 @@ class TestIFUOps(unittest.TestCase):
def test_expand_ifs_one_gate(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
gate = valid&(lidx.ne(2))

View File

@@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_basic(self):
uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16))
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16))
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
@@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_indirect(self):
uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16))
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16))
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))

View File

@@ -702,7 +702,7 @@ class Kernel:
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
else:
@@ -732,7 +732,7 @@ class Kernel:
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
if op is self.reduceops[-1]: return grouped_reduce

View File

@@ -25,17 +25,20 @@ class DType:
class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
local: bool = False # images are never local
def scalar(self): return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
# @dataclass(frozen=True, init=False, repr=False, eq=False)
class PtrDType(DType):
# DType subclass built from another DType `dt`: it forwards dt's priority, itemsize, name, fmt and count
# to the DType constructor, so those fields mirror the wrapped dtype.
# NOTE(review): two __init__ lines are shown back to back in this diff view; the one-argument form looks
# like the pre-change line and the two-argument form below like its replacement -- confirm against the repo.
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
# dt: the wrapped dtype, kept as self.base. local: flag kept as self.local, defaults to False.
def __init__(self, dt:DType, local=False):
self.base, self.local = dt, local
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
# hashing is delegated unchanged to the parent class
def __hash__(self): return super().__hash__()
# NOTE(review): equality compares priority, itemsize, name and count only. `local` (and `base`) are not
# compared, so PtrDType(x) == PtrDType(x, local=True) evaluates to True here -- confirm this is intended.
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)
# NOTE(review): as with __init__, the first __repr__ looks like the pre-change line of the diff.
def __repr__(self): return f"PtrDType({super().__repr__()})"
# the repr gains ", local=True" only when self.local is truthy; otherwise it matches the older format
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})"
class dtypes:
@staticmethod

View File

@@ -572,8 +572,8 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
# this is the matcher for the final rendered UOps
# matcher functions return True or False (or None to not match)
spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)),
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True),
lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

View File

@@ -56,14 +56,14 @@ class CStyleLanguage(Renderer):
return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
# returns a str expression of the loaded value with the output type
# Builds the source-text expression that reads one value of `output_dtype` from buffer `buf_name` at `idx`.
# NOTE(review): two signatures are shown back to back in this diff view; the first (with `local=False`)
# looks like the pre-change line and the second like its replacement -- confirm against the repo.
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
def render_load(self, output_dtype, buf_name, buf_dtype, idx) -> str:
# image buffers: only a float4 output is accepted, and the read is emitted as a read_imagef(...) call
if isinstance(buf_dtype, ImageDType):
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
return f"read_imagef({buf_name}, smp, {idx})"
# buffer holds float16 but the output scalar type differs: emit vload_half, suffixed with the vector
# width unless the output count is 1 (only when the renderer sets uses_vload)
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
# vector output: cast (buf_name+idx) to a pointer to the vector type and dereference it. The cast prefix
# is smem_prefix when the buffer is local and smem_prefix_for_cast is set, otherwise buffer_prefix.
# The first return reads the `local` argument, the second reads buf_dtype.local (old/new diff lines).
if output_dtype.count > 1:
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
return f"*(({self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
# scalar fallthrough: pointer-arithmetic dereference or plain indexing, chosen by uses_ptr_arithmetic
return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
@@ -78,14 +78,14 @@ class CStyleLanguage(Renderer):
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
# returns a str statement that does the store
# Builds the source-text statement that writes `var_name` (of `var_dtype`) into buffer `buf_name` at `idx`.
# NOTE(review): two signatures are shown back to back in this diff view; the first (with `local=False`)
# looks like the pre-change line and the second like its replacement -- confirm against the repo.
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
def render_store(self, buf_name:str, buf_dtype:Union[ImageDType, PtrDType], var_name:str, var_dtype:DType, idx:str) -> str:
# image buffers: only a float4 value is accepted, and the write is emitted as a write_imagef(...) call
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
return f"write_imagef({buf_name}, {idx}, {var_name});"
# buffer holds float16 but the stored scalar type differs: emit vstore_half, suffixed with the vector
# width unless the value count is 1 (only when the renderer sets uses_vload)
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
# vector value: cast (buf_name+idx) to a pointer to the vector type and assign through it. The cast prefix
# is smem_prefix when the buffer is local and smem_prefix_for_cast is set, otherwise buffer_prefix.
# The first `prefix` line reads the `local` argument, the second reads buf_dtype.local (old/new diff lines).
if var_dtype.count > 1:
prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
prefix = self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix
return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
# scalar fallthrough: pointer-arithmetic store or plain indexed store, chosen by uses_ptr_arithmetic
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
@@ -124,8 +124,9 @@ class CStyleLanguage(Renderer):
kk("}")
elif uop is UOps.STORE:
# mark DEFINE_GLOBAL buf as writable
assert isinstance(src[0].dtype, (ImageDType, PtrDType))
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]))
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
else:
if uop is UOps.RANGE:
@@ -150,7 +151,7 @@ class CStyleLanguage(Renderer):
bufs[u] = (args[0], (dtype,False))
r[u] = args[0]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]))
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")