uop global buf number tracking try 2 [pr] (#7912)

* uop buffer init small refactor [pr]

* add early

* this way it doesn't need late

* buffer_num

* itertools.count

* count buffer_num from 0 (itertools.count(0))

* lower amd_gbs for test_gemv_16384_4096 down to 380 (AMD was flaky at 400)
This commit is contained in:
qazal
2024-12-02 01:45:17 -05:00
committed by GitHub
parent cbcc1c20eb
commit b797aee720
3 changed files with 6 additions and 5 deletions

View File

@@ -90,11 +90,11 @@ class TestKernelSpeed(unittest.TestCase):
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=95, amd_tflops=70)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70)
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=400)
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
# TODO: tiny7 is slower than tiny12
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=18)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -65,7 +65,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
assert buf.op is not None, f"base must be base itself {buf}"
dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
if buf.is_realized:
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype)
buffers[ubuf] = buf.buffer
op = None
elif buf.op is Ops.ASSIGN:
@@ -73,7 +73,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
ctx.assigns.add(ubuf:=target.buf_uop)
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
else:
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype)
buffers[ubuf] = buf.buffer
op = UOp(buf.op, dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)

View File

@@ -378,8 +378,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop Buffer stuff ***
buffer_num = itertools.count(0)
@staticmethod
def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
@functools.cached_property
def device(self) -> str:
match self.op: