mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
fix no locals behavior (#5593)
This commit is contained in:
@@ -725,7 +725,7 @@ class Kernel:
|
||||
local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
|
||||
return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
|
||||
elif op.op is MetaOps.KERNEL:
|
||||
arg = KernelInfo(self.local_dims, self.upcasted)
|
||||
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
|
||||
else:
|
||||
arg = op.arg
|
||||
return LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
|
||||
@@ -749,8 +749,12 @@ class Kernel:
|
||||
self.local_size: Optional[List[int]] = [1,1,1]
|
||||
for u in uop_sink.parents:
|
||||
if u.op is UOps.SPECIAL:
|
||||
if u.arg[1][0] == 'l': self.local_size[u.arg[0]] = u.arg[2]
|
||||
else: self.global_size[u.arg[0]] = u.arg[2]
|
||||
if u.arg[1][0] == 'i': self.local_size = None
|
||||
if u.arg[1][0] == 'l':
|
||||
assert self.local_size is not None
|
||||
self.local_size[u.arg[0]] = u.arg[2]
|
||||
else:
|
||||
self.global_size[u.arg[0]] = u.arg[2]
|
||||
else:
|
||||
self.global_size, self.local_size = None, None
|
||||
|
||||
|
||||
@@ -104,9 +104,13 @@ class IndependentLowerer:
|
||||
global_dims = first_reduce-ki.local_dims
|
||||
|
||||
if opts.has_local:
|
||||
# define indexes for GPU-like execution
|
||||
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
|
||||
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
|
||||
if ki.dont_use_locals:
|
||||
assert ki.local_dims == 0, "can't use locals if there's no local dims"
|
||||
self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
|
||||
else:
|
||||
# define indexes for GPU-like execution
|
||||
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
|
||||
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
|
||||
else:
|
||||
# all loops are RANGES
|
||||
self.idxs = [UOp(UOps.RANGE, dtypes.bigint, (UOp.const(dtypes.bigint, 0), variable_to_uop(g)), (i, False))
|
||||
|
||||
@@ -49,6 +49,7 @@ class ConstBuffer:
|
||||
class KernelInfo:
|
||||
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
|
||||
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
|
||||
dont_use_locals: bool = False # don't use local indexing
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class LazyOp:
|
||||
|
||||
Reference in New Issue
Block a user