KernelInfo + cleanups [run_process_replay] (#5372)

This commit is contained in:
George Hotz
2024-07-10 21:00:31 -07:00
committed by GitHub
parent 2396ab9b33
commit 3e9f200905
2 changed files with 32 additions and 22 deletions

View File

@@ -21,6 +21,16 @@ class OptOps(Enum):
class KernelOptError(Exception): pass  # raised by check() below when an optimization request's precondition fails
@dataclass(frozen=True)
class KernelInfo:
  """Immutable summary of a kernel's axis layout.

  Constructed positionally in Kernel.get_optimized_ast from the Kernel's
  full_shape/global_dims/first_reduce/group_for_reduces/upcasted, and consumed
  by Lowerer.linearize to slice full_shape into global / local+group / reduce /
  upcast axis regions.
  """
  full_shape: Tuple[sint, ...]  # full shape (findable in AST)
  global_dims: int  # number of global dimensions (this is remapping RANGE to SPECIAL)
  first_reduce: int  # axis of first_reduce (findable in AST)
  group_for_reduces: int  # count that are grouped (findable in AST)
  upcasted: int  # count that are upcasted (this is remapping RANGE to EXPAND)
  @property
  def shape_len(self): return len(self.full_shape)  # total number of axes across all regions
def check(cond:bool, msg:str=""):
  """Guard for optimization preconditions: raise KernelOptError(msg) unless cond holds."""
  if cond: return
  raise KernelOptError(msg)
@@ -630,7 +640,7 @@ class Kernel:
suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else ""
return name+colored(suffix, 'BLACK')
def get_optimized_ast(self) -> Tuple[LazyOp, ...]:
def get_optimized_ast(self) -> Tuple[Tuple[LazyOp, ...], KernelInfo]:
# set the shapetrackers to the optimized ones, fixup reduceop
# transformed to the final LazyOp
@functools.lru_cache(None)
@@ -695,4 +705,5 @@ class Kernel:
else:
arg = op.arg
return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
return tuple(fixup_ast(x) for x in self.ast)
return tuple(fixup_ast(x) for x in self.ast), \
KernelInfo(self.full_shape, self.global_dims, self.first_reduce, self.group_for_reduces, self.upcasted)

View File

@@ -61,11 +61,13 @@ class Lowerer(Kernel):
if x.arg.idx == -1:
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
else:
# NOTE: outbufs is quickly findable in AST
buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
(x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs)))
if x.op is BufferOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
# TODO: what is this?
if self.group_for_reduces > 0 and x.arg.idx != -1: valid, has_valid = valid * self.idxs[self.first_reduce].eq(0), True
return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
@@ -73,7 +75,6 @@ class Lowerer(Kernel):
if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
if x.op in ReduceOps:
# NOTE: always using ridxs is fine here
dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
if x.op is ReduceOps.WMMA:
wmma_sz, upcast_axis = x.arg[4], x.arg[6]
@@ -82,20 +83,20 @@ class Lowerer(Kernel):
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)
def linearize(self) -> Lowerer:
modified_ast = self.get_optimized_ast()
modified_ast, ki = self.get_optimized_ast()
if DEBUG >= 4:
from tinygrad.engine.graph import print_tree
for mast in modified_ast: print_tree(mast)
self.idxs = []
if self.opts.has_local:
# define indexes
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0) # noqa: E501
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, ki.full_shape[:ki.global_dims], 3)
local_idxs, loop_local_idxs = get_grouped_dims("lidx", ki.global_dims, ki.full_shape[ki.global_dims:ki.first_reduce+ki.group_for_reduces], 3)
self.idxs = global_idxs + local_idxs
# define sizes
@@ -104,26 +105,24 @@ class Lowerer(Kernel):
self.global_size += [1]*(3-len(self.global_size))
self.local_size += [1]*(3-len(self.local_size))
else:
# all loops
self.idxs = []
for i,g in enumerate(self.full_shape[:self.first_reduce]):
self.idxs.append(UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, False)))
# all loops are RANGES
self.idxs = [UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, False))
for i,g in enumerate(ki.full_shape[:ki.first_reduce])]
self.global_size, self.local_size = None, None
# reduce loops
for i,g in enumerate(self.full_shape[self.first_reduce+self.group_for_reduces:], start=self.first_reduce+self.group_for_reduces):
unrolled, is_reduce = i >= (self.shape_len-self.upcasted), self.full_shape[i] != self.output_shape[i]
if unrolled:
assert isinstance(g, int), "needs to be int to unroll"
uop = UOp(UOps.EXPAND, dtypes.int32, tuple(UOp.const(dtypes.int32, j) for j in range(0, g)), i)
else:
uop = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, is_reduce))
self.idxs.append(uop)
self.idxs += [UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, True))
for i,g in enumerate(ki.full_shape[ki.first_reduce+ki.group_for_reduces:ki.shape_len-ki.upcasted], start=ki.first_reduce+ki.group_for_reduces)]
# late indexes
# upcast loops
for i,g in enumerate(ki.full_shape[ki.shape_len-ki.upcasted:], start=ki.shape_len-ki.upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
self.idxs.append(UOp(UOps.EXPAND, dtypes.int32, tuple(UOp.const(dtypes.int32, j) for j in range(0, g)), i))
# late indexes (group for reduce)
self.ridxs = self.idxs[:]
for a in range(self.first_reduce, self.first_reduce+self.group_for_reduces):
self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(self.full_shape[a])), (1000+a, True))
for a in range(ki.first_reduce, ki.first_reduce+ki.group_for_reduces):
self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(ki.full_shape[a])), (1000+a, True))
self.uop_cache: Dict[LazyOp, UOp] = {}
self.uops:UOpGraph = UOpGraph([self.to_uop(x) for x in modified_ast], self.opts)