tc3 shape expand [pr] (#11043)

* tc3 shape expand [pr]

* remove unused stuff in lowerer
commit 752c76ceb7 (parent 539b17fcbf)
Author: George Hotz
Committed by: GitHub
Date: 2025-06-30 13:38:14 -07:00

2 changed files with 4 additions and 9 deletions

@@ -44,12 +44,11 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
   reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
   assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
-  alu_op: Ops = x.arg[0]
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
     ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
   # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
-  return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op)
+  return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), x.arg[0])
 
 def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
   idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if buf.op is Ops.DEFINE_LOCAL else ctx.idxs)
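
A note on the helpers in this hunk: partition (from tinygrad.helpers) splits the reduce axes into loop-carried RANGE indices and unrolled UNROLL indices, and the CONTRACT dtype is vectorized by the product of the unrolled axis sizes. A minimal standalone sketch of that behavior (toy stand-ins for illustration, not the library code):

  from math import prod

  # toy stand-in for tinygrad.helpers.partition: split a list by a predicate
  def partition(itr, fn):
    a, b = [], []
    for x in itr: (a if fn(x) else b).append(x)
    return a, b

  # plain strings stand in for UOps to mimic the RANGE/UNROLL split above
  axes = ["RANGE", "UNROLL", "RANGE", "UNROLL"]
  reduce_range, reduce_expand = partition(axes, lambda y: y == "RANGE")
  assert reduce_range == ["RANGE", "RANGE"] and reduce_expand == ["UNROLL", "UNROLL"]

  # CONTRACT's vector width is the product of the unrolled sizes: e.g.
  # unrolls of size 4 and 2 contract into one 8-wide vector
  contract_axis = [(3, 4), (4, 2)]  # hypothetical (axis, size) pairs
  assert prod(sz for _, sz in contract_axis) == 8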

@@ -58,12 +57,8 @@ def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
 
 def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
   idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
-  # NOTE: only store the local reduceop in the threads that are actually doing the reduce
-  if cast(PtrDType, buf.dtype).local and x.src[1].op is Ops.REDUCE:
-    reduce_input = x.src[1].src[0]
-    store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
-  else: store_back = False
-  if (not cast(PtrDType, buf.dtype).local) or store_back:
+  if not cast(PtrDType, buf.dtype).local:
+    # NOTE: only store the local reduceop in the threads that are actually doing the reduce
     for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
       if oidx is not ridx: valid = valid * oidx.eq(0)
   return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[1]))
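
The masking here is what keeps redundant threads from writing: wherever the reduce replaced a thread index (oidx is not ridx), only the thread sitting at index 0 of that axis keeps a valid store. A standalone sketch of the same predicate over plain integers (hypothetical model, with reduced axes flagged explicitly instead of compared by identity):

  # a thread's store stays valid only if it is at index 0 of every axis the
  # reduce took over; mirrors valid * oidx.eq(0) in the lowerer
  def store_valid(oidxs: list[int], reduced: list[bool]) -> bool:
    valid = True
    for oidx, is_reduced in zip(oidxs, reduced):
      if is_reduced: valid = valid and (oidx == 0)
    return valid

  assert store_valid([0, 3, 0], reduced=[True, False, True])      # index 0 on reduced axes: stores
  assert not store_valid([1, 3, 0], reduced=[True, False, True])  # off-zero on a reduced axis: masked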

@@ -486,7 +486,7 @@ class Kernel:
       if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
         local_shape = tuple(1 if st == 0 or i < wd or (i >= self.first_reduce and i < tcd) else src_st.shape[i] \
           for i,st in enumerate(src_st.real_strides()))
-        st = store_st = ShapeTracker.from_shape(local_shape)
+        st = store_st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:wd]+local_shape[wd:])
         local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
         if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
         local_store = UOp.store(local_buffer.view(store_st), srcs[i])
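
What the added .expand buys: expanding a size-1 dim of a ShapeTracker produces a stride-0 (broadcast) view, so the tracker's shape lines up with full_shape on the leading warp dims while st.real_size(), which sizes the DEFINE_LOCAL above, stays that of the small local buffer. A quick illustrative check against the ShapeTracker API the diff itself uses (toy shapes, and the usual import path is assumed):

  from tinygrad.shape.shapetracker import ShapeTracker  # assumed import path

  st = ShapeTracker.from_shape((1, 4))  # stand-in for a local_shape with a leading 1
  st = st.expand((8, 4))                # only size-1 dims may expand
  print(st.shape)                       # (8, 4): matches the fuller shape
  print(st.real_size())                 # 4: the local allocation stays small
  print(st.real_strides())              # (0, 1): the expanded dim is stride 0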