minor clean up in kernels (#2832)

clean up some long lines and combined some short lines
This commit is contained in:
chenyu
2023-12-18 19:35:59 -05:00
committed by GitHub
parent 954a2fef75
commit dad9253d52

View File

@@ -66,7 +66,7 @@ class LinearizerOptions(NamedTuple):
class Kernel:
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions()) # noqa: E501
self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions())
self.ast = ast
assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"
@@ -263,7 +263,7 @@ class Kernel:
strides.append(special_strides)
# merge dimensions if we can, multi get_shape_strides
# TODO: does this always preserve the reduce dimension, NO
# NOTE: this does not always preserve the reduce dimension
# TODO: move this into shapetracker, with tests!
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
for i in range(1, len(shapes[0])):
@@ -375,8 +375,7 @@ class Kernel:
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
if extra_opts is not None:
for opt in extra_opts:
self.apply_opt(opt)
for opt in extra_opts: self.apply_opt(opt)
else:
# hand-coded TC opts
if s1_exists:
@@ -402,7 +401,7 @@ class Kernel:
assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
self.applied_opts.append(opt)
if opt.axis is not None:
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0)) # noqa: E501
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
else:
axis = -1
if opt.amt is not None:
@@ -490,12 +489,9 @@ class Kernel:
for global_idx in range(self.global_dims):
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}") # noqa: E501
if MV_THREADS_PER_ROW > 1:
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1:
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1:
self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
return
if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
@@ -544,8 +540,7 @@ class Kernel:
prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]:
self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
# potentially do more upcasts of non reduce axes based on a heuristic
upcasted_axis = set()
@@ -560,8 +555,7 @@ class Kernel:
if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
upcasted_axis.add(xb_choices[0][2])
else:
break
else: break
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501