mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 08:17:58 -05:00
cleanups from Estimate [pr] (#8329)
This commit is contained in:
@@ -617,7 +617,7 @@ class Kernel:
|
||||
if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
|
||||
local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
|
||||
st = store_st = ShapeTracker.from_shape(local_shape)
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), (f"temp{i + 1}", st.real_size()))
|
||||
if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
|
||||
local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
|
||||
srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
|
||||
@@ -646,7 +646,8 @@ class Kernel:
|
||||
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
|
||||
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
|
||||
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
|
||||
local_size = st_uop.arg.real_size()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), (f"temp{self.reduceops.index(op)+1}", local_size))
|
||||
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
|
||||
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
|
||||
@@ -49,7 +49,8 @@ def fold_expanded(ex, buf):
|
||||
rootsrc[0] if isinstance(rootsrc, tuple) else None)
|
||||
else:
|
||||
# for non image, we upcast the index pointer
|
||||
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(local=new_src[0].dtype.local))
|
||||
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size//fold_length,
|
||||
local=new_src[0].dtype.local))
|
||||
# generate the folded new_srcs
|
||||
if is_load:
|
||||
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
|
||||
|
||||
@@ -78,7 +78,7 @@ class BufferCopy(Runner):
|
||||
def __init__(self, total_sz, dest_device, src_device):
|
||||
if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
|
||||
else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
|
||||
super().__init__(colored(name, "yellow"), dest_device, Estimates(mem=total_sz))
|
||||
super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
|
||||
def copy(self, dest, src):
|
||||
disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
|
||||
getattr(src.allocator.dev, 'fd', None) is not None
|
||||
|
||||
@@ -36,7 +36,7 @@ class Estimates:
|
||||
|
||||
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
flops: sint = 0
|
||||
mem: sint = 0
|
||||
lds: sint = 0
|
||||
mults: sint = 1
|
||||
mult_stack: List[sint] = []
|
||||
dont_count: Set[UOp] = set()
|
||||
@@ -53,11 +53,10 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
mults *= (u.src[1] - u.src[0]).ssimplify()
|
||||
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.LOAD: mem += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE: mem += u.src[1].dtype.itemsize * mults
|
||||
elif u.op in {Ops.LOAD, Ops.STORE}: lds += u.src[0].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return flops, mem
|
||||
return flops, lds
|
||||
|
||||
@dataclass
|
||||
class ProgramSpec:
|
||||
|
||||
Reference in New Issue
Block a user