fix no locals behavior (#5593)

This commit is contained in:
George Hotz
2024-07-19 14:35:09 -07:00
committed by GitHub
parent 2e617ca59e
commit a02998472b
3 changed files with 15 additions and 6 deletions

View File

@@ -725,7 +725,7 @@ class Kernel:
local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
elif op.op is MetaOps.KERNEL:
arg = KernelInfo(self.local_dims, self.upcasted)
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
else:
arg = op.arg
return LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
@@ -749,8 +749,12 @@ class Kernel:
self.local_size: Optional[List[int]] = [1,1,1]
for u in uop_sink.parents:
if u.op is UOps.SPECIAL:
if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
else: self.global_size[u.arg[0]] = u.arg[2]
if u.arg[1][0] == 'i': self.local_size = None
if u.arg[1][0] == 'l':
assert self.local_size is not None
self.local_size[u.arg[0]] = u.arg[2]
else:
self.global_size[u.arg[0]] = u.arg[2]
else:
self.global_size, self.local_size = None, None

View File

@@ -104,9 +104,13 @@ class IndependentLowerer:
global_dims = first_reduce-ki.local_dims
if opts.has_local:
# define indexes for GPU-like execution
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
if ki.dont_use_locals:
assert ki.local_dims == 0, "can't use locals if there's no local dims"
self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
else:
# define indexes for GPU-like execution
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))

View File

@@ -49,6 +49,7 @@ class ConstBuffer:
class KernelInfo:
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
dont_use_locals: bool = False # don't use local indexing
@dataclass(frozen=True, eq=False)
class LazyOp: