mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
uop global buf number tracking try 2 [pr] (#7912)
* uop buffer init small refactor [pr] * add early * this way it doesn't need late * buffer_num * itertools.count * count from 0 * down to 380
This commit is contained in:
4
test/external/speed_v_theoretical.py
vendored
4
test/external/speed_v_theoretical.py
vendored
@@ -90,11 +90,11 @@ class TestKernelSpeed(unittest.TestCase):
|
||||
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=95, amd_tflops=70)
|
||||
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70)
|
||||
|
||||
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=400)
|
||||
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
|
||||
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
|
||||
|
||||
# TODO: tiny7 is slower than tiny12
|
||||
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=18)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -65,7 +65,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
|
||||
assert buf.op is not None, f"base must be base itself {buf}"
|
||||
dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
|
||||
if buf.is_realized:
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype)
|
||||
buffers[ubuf] = buf.buffer
|
||||
op = None
|
||||
elif buf.op is Ops.ASSIGN:
|
||||
@@ -73,7 +73,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
|
||||
ctx.assigns.add(ubuf:=target.buf_uop)
|
||||
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
|
||||
else:
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype)
|
||||
buffers[ubuf] = buf.buffer
|
||||
op = UOp(buf.op, dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
|
||||
|
||||
@@ -378,8 +378,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
|
||||
buffer_num = itertools.count(0)
|
||||
@staticmethod
|
||||
def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
|
||||
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
|
||||
@functools.cached_property
|
||||
def device(self) -> str:
|
||||
match self.op:
|
||||
|
||||
Reference in New Issue
Block a user