mirror of https://github.com/tinygrad/tinygrad.git
work
@@ -5,10 +5,10 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.helpers import CI, DEBUG, getenv, Timing
 from tinygrad.dtype import dtypes, DType, AddrSpace
 from tinygrad.device import Buffer, Device
-from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
+from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
 from tinygrad.uop.spec import shared_spec
 from tinygrad.renderer import ProgramSpec
-from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
+from tinygrad.engine.realize import CompiledRunner, get_program, get_runner, ExecItem
 from tinygrad.codegen import full_rewrite
 from tinygrad.uop.symbolic import sym
 from tinygrad.device import is_dtype_supported
@@ -567,24 +567,23 @@ class TestZeroRange(unittest.TestCase):
 
 class TestUOpPrograms(unittest.TestCase):
   def test_matmul(self):
     A, B, C = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(100), arg=i) for i in range(3)]
     i = UOp.range(10, 0)
     j = UOp.range(10, 1)
     C = C.after(C[i*10+j].store(UOp.const(dtypes.float, 0.0)))
-    k = UOp.range(10, 2)
+    k = UOp.range(10, 2, AxisType.REDUCE) # <-- this tells the GPU it can't be a global dim
     store = C[i*10+j].store(C.after(k)[i*10+j] + (A[i*10+k] * B[k*10+j]))
-    prog = store.end(i,j,k).sink()
-
-    runner = get_runner(Device.DEFAULT, prog)
-    print(runner.p.src)
+    prog = store.end(i,j,k).sink(arg=KernelInfo(opts_to_apply=()))
 
     a = Tensor.rand(10,10)
     b = Tensor.rand(10,10)
     c = Tensor.empty(10,10)
     ref = a@b
     Tensor.realize(a, b, c, ref)
 
-    runner([a.uop.buffer, b.uop.buffer, c.uop.buffer.ensure_allocated()], wait=True)
+    runner = get_runner(Device.DEFAULT, prog)
+    print(runner.p.src)
+    ExecItem(runner, [a.uop.buffer, b.uop.buffer, c.uop.buffer]).run(wait=True)
     self.assertLessEqual((c-ref).square().mean().item(), 1e-6)
 
 if __name__ == '__main__':
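
Note: the updated test above leans on two new pieces, the AxisType.REDUCE tag on a range and sink(arg=KernelInfo(opts_to_apply=())) to skip automatic opts. Below is a minimal sketch (not from this commit) of the same builder API on a simpler row-sum kernel; it assumes the UOp methods behave as used in test_matmul, and that runner.p.src makes it visible whether a range became a global dim or stayed a loop.

    from tinygrad.tensor import Tensor
    from tinygrad.device import Device
    from tinygrad.dtype import dtypes
    from tinygrad.uop.ops import Ops, UOp, KernelInfo, AxisType
    from tinygrad.engine.realize import get_runner, ExecItem

    A = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(100), arg=0)  # 10x10 input
    OUT = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=1) # 10 outputs
    i = UOp.range(10, 0)                   # output axis: free to become a global dim
    OUT = OUT.after(OUT[i].store(UOp.const(dtypes.float, 0.0)))  # zero the accumulator
    k = UOp.range(10, 1, AxisType.REDUCE)  # reduce axis: must stay a serial loop
    store = OUT[i].store(OUT.after(k)[i] + A[i*10+k])
    prog = store.end(i, k).sink(arg=KernelInfo(opts_to_apply=()))

    runner = get_runner(Device.DEFAULT, prog)
    print(runner.p.src)  # expect a loop over k, not a second global dim

    a = Tensor.rand(10, 10)
    out = Tensor.empty(10)
    Tensor.realize(a, out)
    ExecItem(runner, [a.uop.buffer, out.uop.buffer]).run(wait=True)
    print((out - a.sum(axis=1)).square().mean().item())  # ~0 if the kernel is right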
@@ -64,7 +64,7 @@ class Scheduler:
     return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
 
   def _output_rngs(self) -> list[UOp]:
-    return flatten([list(UOp.sink(*s.src[1:]).ranges) for s in self.ast.src if s.op is Ops.END])
+    return flatten([[r for r in UOp.sink(*s.src[1:]).ranges if r.arg[-1] != AxisType.REDUCE] for s in self.ast.src if s.op is Ops.END])
 
   def _globalizable_rngs(self) -> list[UOp]:
     ret = self._output_rngs()
     # exclude any output ranges from global that don't appear in all BUFFERIZE
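
Note: the Scheduler hunk is the other half of the AxisType.REDUCE story: _output_rngs now drops REDUCE-tagged ranges, so they are never offered to _globalizable_rngs. A tiny stand-alone sketch (not tinygrad code; the tuple "ranges" and the two-member enum are hypothetical stand-ins) of the filtering shape, assuming each range UOp carries its AxisType as the last element of .arg:

    from enum import Enum, auto

    class AxisType(Enum):  # hypothetical stand-in for tinygrad.uop.ops.AxisType
      GLOBAL = auto(); REDUCE = auto()

    def flatten(ls): return [x for l in ls for x in l]  # mirrors tinygrad.helpers.flatten

    # one list of (axis_id, AxisType) "ranges" per END op, mimicking r.arg[-1]
    ranges_per_end = [[(0, AxisType.GLOBAL), (2, AxisType.REDUCE)], [(1, AxisType.GLOBAL)]]
    out = flatten([[r for r in rngs if r[-1] != AxisType.REDUCE] for rngs in ranges_per_end])
    print(out)  # only the two GLOBAL ranges survive; the REDUCE range is filtered out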