mirror of https://github.com/tinygrad/tinygrad.git

KernelInfo + cleanups [run_process_replay] (#5372)
@@ -21,6 +21,16 @@ class OptOps(Enum):
 class KernelOptError(Exception): pass
 
+@dataclass(frozen=True)
+class KernelInfo:
+  full_shape: Tuple[sint, ...]  # full shape (findable in AST)
+  global_dims: int  # number of global dimensions (this is remapping RANGE to SPECIAL)
+  first_reduce: int  # axis of first_reduce (findable in AST)
+  group_for_reduces: int  # count that are grouped (findable in AST)
+  upcasted: int  # count that are upcasted (this is remapping RANGE to EXPAND)
+  @property
+  def shape_len(self): return len(self.full_shape)
+
 def check(cond:bool, msg:str=""):
   if not cond: raise KernelOptError(msg)
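The KernelInfo added above is a small frozen dataclass that packages the shape bookkeeping (full shape plus the global/reduce/grouped/upcast axis counts) so it can travel separately from the Kernel object. A minimal sketch of it in isolation, with tinygrad's sint replaced by a plain int stand-in and made-up field values:

from dataclasses import dataclass
from typing import Tuple

sint = int  # stand-in: in tinygrad, sint also covers symbolic Variables

@dataclass(frozen=True)
class KernelInfo:
  full_shape: Tuple[sint, ...]  # full shape (findable in AST)
  global_dims: int              # number of global dimensions
  first_reduce: int             # axis of the first reduce
  group_for_reduces: int        # count of grouped reduce axes
  upcasted: int                 # count of upcasted axes
  @property
  def shape_len(self): return len(self.full_shape)

# hypothetical values: 4 axes, 1 global dim, reduce starting at axis 2, 1 upcasted axis
ki = KernelInfo(full_shape=(256, 16, 64, 4), global_dims=1, first_reduce=2, group_for_reduces=0, upcasted=1)
assert ki.shape_len == 4
# frozen=True makes instances immutable (and hashable), so they can be passed around or cached safely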
@@ -630,7 +640,7 @@ class Kernel:
     suffix = f"{'n'+str(Kernel.kernel_cnt[function_name]-1)}" if Kernel.kernel_cnt[function_name] > 1 else ""
     return name+colored(suffix, 'BLACK')
 
-  def get_optimized_ast(self) -> Tuple[LazyOp, ...]:
+  def get_optimized_ast(self) -> Tuple[Tuple[LazyOp, ...], KernelInfo]:
     # set the shapetrackers to the optimized ones, fixup reduceop
     # transformed to the final LazyOp
     @functools.lru_cache(None)
@@ -695,4 +705,5 @@ class Kernel:
       else:
         arg = op.arg
       return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
-    return tuple(fixup_ast(x) for x in self.ast)
+    return tuple(fixup_ast(x) for x in self.ast), \
+      KernelInfo(self.full_shape, self.global_dims, self.first_reduce, self.group_for_reduces, self.upcasted)
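With these two hunks, get_optimized_ast returns the fixed-up AST together with a KernelInfo built from the kernel's own fields, so callers unpack a pair instead of reading shape state off the Kernel. A hypothetical, self-contained sketch of that calling pattern (Op and Info are simplified stand-ins for LazyOp and KernelInfo, not tinygrad types):

from typing import NamedTuple, Tuple

class Op(NamedTuple):
  name: str                     # stand-in for a LazyOp node

class Info(NamedTuple):
  full_shape: Tuple[int, ...]   # stand-in for KernelInfo
  global_dims: int

def get_optimized_ast() -> Tuple[Tuple[Op, ...], Info]:
  # mirrors the new signature: the AST tuple and the shape metadata travel together
  return (Op("STORE"),), Info((32, 32), 2)

modified_ast, ki = get_optimized_ast()   # callers now unpack two values instead of one
assert len(modified_ast) == 1 and ki.global_dims == 2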
@@ -61,11 +61,13 @@ class Lowerer(Kernel):
       if x.arg.idx == -1:
         buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
       else:
+        # NOTE: outbufs is quickly findable in AST
+        buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
+                  (x.arg.idx, any(x.arg.idx == y.idx for y in self.outbufs)))
       if x.op is BufferOps.LOAD:
         barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
         return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
       # TODO: what is this?
       if self.group_for_reduces > 0 and x.arg.idx != -1: valid, has_valid = valid * self.idxs[self.first_reduce].eq(0), True
       return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
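On the Lowerer side, the DEFINE_GLOBAL arg now carries a (buffer index, is-output) pair, with the flag computed by scanning outbufs with any(). A toy sketch of just that check, using a made-up Buf record in place of the real buffer argument:

from typing import NamedTuple, Tuple

class Buf(NamedTuple):
  idx: int   # stand-in for the buffer arg; only idx matters for this check

outbufs: Tuple[Buf, ...] = (Buf(0),)        # hypothetical: buffer 0 is written by the kernel
referenced = [Buf(0), Buf(1), Buf(2)]       # hypothetical buffers appearing in LOAD/STORE ops

for x in referenced:
  # same shape as the DEFINE_GLOBAL arg in the diff: (buffer index, is it an output?)
  arg = (x.idx, any(x.idx == y.idx for y in outbufs))
  print(arg)   # (0, True)  (1, False)  (2, False)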
@@ -73,7 +75,6 @@ class Lowerer(Kernel):
     if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
     if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
     if x.op in ReduceOps:
-      # NOTE: always using ridxs is fine here
       dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
       if x.op is ReduceOps.WMMA:
         wmma_sz, upcast_axis = x.arg[4], x.arg[6]
@@ -82,20 +83,20 @@ class Lowerer(Kernel):
           UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
           UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
         return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axis[2])
+      # NOTE: always using ridxs is fine here
       return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
     return UOp.alu(x.op, *in_uops)
 
   def linearize(self) -> Lowerer:
-    modified_ast = self.get_optimized_ast()
+    modified_ast, ki = self.get_optimized_ast()
     if DEBUG >= 4:
       from tinygrad.engine.graph import print_tree
       for mast in modified_ast: print_tree(mast)
 
-    self.idxs = []
     if self.opts.has_local:
       # define indexes
-      global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
-      local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0)  # noqa: E501
+      global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, ki.full_shape[:ki.global_dims], 3)
+      local_idxs, loop_local_idxs = get_grouped_dims("lidx", ki.global_dims, ki.full_shape[ki.global_dims:ki.first_reduce+ki.group_for_reduces], 3)
       self.idxs = global_idxs + local_idxs
 
       # define sizes
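linearize now reads full_shape, global_dims, first_reduce and group_for_reduces from the returned KernelInfo instead of from self, and the redundant "3 if self.opts.has_local else 0" collapses to 3, since this branch only runs when has_local is true. The 3 reflects the usual limit of three global/local launch dimensions on GPU backends. A rough sketch of the dimension-folding idea behind a helper like get_grouped_dims, under that assumption (an illustration only, not tinygrad's actual implementation):

import math
from typing import List, Tuple

def group_dims(shape: Tuple[int, ...], max_dims: int = 3) -> List[int]:
  # illustration: merge leading axes until at most max_dims sizes remain,
  # preserving the total number of launched threads
  sizes = list(shape)
  while len(sizes) > max_dims:
    sizes[0:2] = [sizes[0] * sizes[1]]
  return sizes

print(group_dims((2, 3, 4, 5, 6)))                        # [24, 5, 6]
assert math.prod(group_dims((2, 3, 4, 5, 6))) == 2*3*4*5*6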
@@ -104,26 +105,24 @@ class Lowerer(Kernel):
       self.global_size += [1]*(3-len(self.global_size))
       self.local_size += [1]*(3-len(self.local_size))
     else:
-      # all loops
-      self.idxs = []
-      for i,g in enumerate(self.full_shape[:self.first_reduce]):
-        self.idxs.append(UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, False)))
+      # all loops are RANGES
+      self.idxs = [UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, False))
+        for i,g in enumerate(ki.full_shape[:ki.first_reduce])]
       self.global_size, self.local_size = None, None
 
     # reduce loops
-    for i,g in enumerate(self.full_shape[self.first_reduce+self.group_for_reduces:], start=self.first_reduce+self.group_for_reduces):
-      unrolled, is_reduce = i >= (self.shape_len-self.upcasted), self.full_shape[i] != self.output_shape[i]
-      if unrolled:
-        assert isinstance(g, int), "needs to be int to unroll"
-        uop = UOp(UOps.EXPAND, dtypes.int32, tuple(UOp.const(dtypes.int32, j) for j in range(0, g)), i)
-      else:
-        uop = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, is_reduce))
-      self.idxs.append(uop)
+    self.idxs += [UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(g)), (i, True))
+      for i,g in enumerate(ki.full_shape[ki.first_reduce+ki.group_for_reduces:ki.shape_len-ki.upcasted], start=ki.first_reduce+ki.group_for_reduces)]
 
-    # late indexes
+    # upcast loops
+    for i,g in enumerate(ki.full_shape[ki.shape_len-ki.upcasted:], start=ki.shape_len-ki.upcasted):
+      assert isinstance(g, int), "needs to be int to upcast/unroll"
+      self.idxs.append(UOp(UOps.EXPAND, dtypes.int32, tuple(UOp.const(dtypes.int32, j) for j in range(0, g)), i))
+
+    # late indexes (group for reduce)
     self.ridxs = self.idxs[:]
-    for a in range(self.first_reduce, self.first_reduce+self.group_for_reduces):
-      self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(self.full_shape[a])), (1000+a, True))
+    for a in range(ki.first_reduce, ki.first_reduce+ki.group_for_reduces):
+      self.ridxs[a] = UOp(UOps.RANGE, dtypes.int32, (UOp.const(dtypes.int32, 0), variable_to_uop(ki.full_shape[a])), (1000+a, True))
 
     self.uop_cache: Dict[LazyOp, UOp] = {}
     self.uops:UOpGraph = UOpGraph([self.to_uop(x) for x in modified_ast], self.opts)
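The rewritten loop setup gives every axis of ki.full_shape exactly one index UOp: plain and reduce axes become RANGE, upcasted axes become EXPAND over constants, and grouped-reduce axes additionally get a fresh RANGE in ridxs whose axis id is offset by 1000 so it cannot collide with a real axis. A schematic, self-contained sketch of that per-axis classification for the has_local layout (strings stand in for UOps, and the example values are made up):

from typing import List, Tuple

def classify_axes(full_shape: Tuple[int, ...], first_reduce: int, group_for_reduces: int, upcasted: int) -> List[str]:
  # axis layout assumed here: [globals/locals (incl. grouped reduces) | reduces | upcasted]
  shape_len = len(full_shape)
  kinds = []
  for i in range(shape_len):
    if i < first_reduce + group_for_reduces: kinds.append("RANGE (loop)")
    elif i < shape_len - upcasted: kinds.append("RANGE (reduce)")
    else: kinds.append("EXPAND (upcast)")
  return kinds

# hypothetical kernel: 2 loop axes, 1 grouped reduce, 1 plain reduce, 1 upcast
print(classify_axes((8, 8, 4, 16, 4), first_reduce=2, group_for_reduces=1, upcasted=1))
# ['RANGE (loop)', 'RANGE (loop)', 'RANGE (loop)', 'RANGE (reduce)', 'EXPAND (upcast)']
# the grouped-reduce axis (index 2) would also get a second RANGE in ridxs, with id 1000+2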