mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
tc3 shape expand [pr] (#11043)
* tc3 shape expand [pr]
* remove unused stuff in lowerer
This commit is contained in:
@@ -44,12 +44,11 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
|
||||
alu_op: Ops = x.arg[0]
|
||||
ret = x.src[0]
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
# REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
|
||||
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op)
|
||||
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), x.arg[0])
|
||||
|
||||
def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if buf.op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
@@ -58,12 +57,8 @@ def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
|
||||
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
if cast(PtrDType, buf.dtype).local and x.src[1].op is Ops.REDUCE:
|
||||
reduce_input = x.src[1].src[0]
|
||||
store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
|
||||
else: store_back = False
|
||||
if (not cast(PtrDType, buf.dtype).local) or store_back:
|
||||
if not cast(PtrDType, buf.dtype).local:
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[1]))
|
||||
|
||||
@@ -486,7 +486,7 @@ class Kernel:
|
||||
if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
|
||||
local_shape = tuple(1 if st == 0 or i < wd or (i >= self.first_reduce and i < tcd) else src_st.shape[i] \
|
||||
for i,st in enumerate(src_st.real_strides()))
|
||||
st = store_st = ShapeTracker.from_shape(local_shape)
|
||||
st = store_st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:wd]+local_shape[wd:])
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
|
||||
if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
|
||||
local_store = UOp.store(local_buffer.view(store_st), srcs[i])
|
||||
|
||||
Reference in New Issue
Block a user