cleanups from Estimate [pr] (#8329)

This commit is contained in:
George Hotz
2024-12-18 23:01:14 -08:00
committed by GitHub
parent 3a9ca62b9e
commit 2aa39d03cd
4 changed files with 9 additions and 8 deletions

View File

@@ -617,7 +617,7 @@ class Kernel:
if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
st = store_st = ShapeTracker.from_shape(local_shape)
-        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
+        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), (f"temp{i + 1}", st.real_size()))
if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
@@ -646,7 +646,8 @@ class Kernel:
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
-      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
+      local_size = st_uop.arg.real_size()
+      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), (f"temp{self.reduceops.index(op)+1}", local_size))
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
if op is self.reduceops[-1]: return grouped_reduce

View File

@@ -49,7 +49,8 @@ def fold_expanded(ex, buf):
rootsrc[0] if isinstance(rootsrc, tuple) else None)
else:
# for non image, we upcast the index pointer
-        new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(local=new_src[0].dtype.local))
+        new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size//fold_length,
+                                                                               local=new_src[0].dtype.local))
# generate the folded new_srcs
if is_load:
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))

View File

@@ -78,7 +78,7 @@ class BufferCopy(Runner):
def __init__(self, total_sz, dest_device, src_device):
if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
-    super().__init__(colored(name, "yellow"), dest_device, Estimates(mem=total_sz))
+    super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
def copy(self, dest, src):
disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
getattr(src.allocator.dev, 'fd', None) is not None

View File

@@ -36,7 +36,7 @@ class Estimates:
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
flops: sint = 0
-  mem: sint = 0
+  lds: sint = 0
mults: sint = 1
mult_stack: List[sint] = []
dont_count: Set[UOp] = set()
@@ -53,11 +53,10 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
mults *= (u.src[1] - u.src[0]).ssimplify()
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
-    elif u.op is Ops.LOAD: mem += u.dtype.itemsize * mults
-    elif u.op is Ops.STORE: mem += u.src[1].dtype.itemsize * mults
+    elif u.op in {Ops.LOAD, Ops.STORE}: lds += u.src[0].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
-  return flops, mem
+  return flops, lds
@dataclass
class ProgramSpec: