This commit is contained in:
George Hotz
2025-10-30 11:01:53 +08:00
parent 79f98a6624
commit 80f9347e53
2 changed files with 14 additions and 15 deletions

View File

@@ -5,10 +5,10 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Timing
from tinygrad.dtype import dtypes, DType, AddrSpace
from tinygrad.device import Buffer, Device
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
from tinygrad.uop.spec import shared_spec
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner, ExecItem
from tinygrad.codegen import full_rewrite
from tinygrad.uop.symbolic import sym
from tinygrad.device import is_dtype_supported
@@ -567,24 +567,23 @@ class TestZeroRange(unittest.TestCase):
class TestUOpPrograms(unittest.TestCase):
def test_matmul(self):
A, B, C = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(100), arg=i) for i in range(3)]
i = UOp.range(10, 0)
j = UOp.range(10, 1)
C = C.after(C[i*10+j].store(UOp.const(dtypes.float, 0.0)))
k = UOp.range(10, 2)
store = C[i*10+j].store(C.after(k)[i*10+j] + (A[i*10+k] * B[k*10+j]))
prog = store.end(i,j,k).sink()
runner = get_runner(Device.DEFAULT, prog)
print(runner.p.src)
a = Tensor.rand(10,10)
b = Tensor.rand(10,10)
c = Tensor.empty(10,10)
ref = a@b
Tensor.realize(a, b, c, ref)
runner([a.uop.buffer, b.uop.buffer, c.uop.buffer.ensure_allocated()], wait=True)
A, B, C = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(100), arg=i) for i in range(3)]
i = UOp.range(10, 0)
j = UOp.range(10, 1)
C = C.after(C[i*10+j].store(UOp.const(dtypes.float, 0.0)))
k = UOp.range(10, 2, AxisType.REDUCE) # <-- this tells the GPU it can't be a global dim
store = C[i*10+j].store(C.after(k)[i*10+j] + (A[i*10+k] * B[k*10+j]))
prog = store.end(i,j,k).sink(arg=KernelInfo(opts_to_apply=()))
runner = get_runner(Device.DEFAULT, prog)
print(runner.p.src)
ExecItem(runner, [a.uop.buffer, b.uop.buffer, c.uop.buffer]).run(wait=True)
self.assertLessEqual((c-ref).square().mean().item(), 1e-6)
if __name__ == '__main__':

View File

@@ -64,7 +64,7 @@ class Scheduler:
return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
def _output_rngs(self) -> list[UOp]:
return flatten([list(UOp.sink(*s.src[1:]).ranges) for s in self.ast.src if s.op is Ops.END])
return flatten([[r for r in UOp.sink(*s.src[1:]).ranges if r.arg[-1] != AxisType.REDUCE] for s in self.ast.src if s.op is Ops.END])
def _globalizable_rngs(self) -> list[UOp]:
ret = self._output_rngs()
# exclude any output ranges from global that don't appear in all BUFFERIZE