fix define global (#4383)

* fix define global

* remove name from DEFINE_GLOBAL

* fix fuzzing

* fix ptx

* fix python
This commit is contained in:
George Hotz
2024-05-01 19:32:56 -07:00
committed by GitHub
parent ad116dc5c6
commit f635c4d273
10 changed files with 48 additions and 42 deletions

View File

@@ -2,7 +2,7 @@ import itertools
import numpy as np
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
from tinygrad.buffer import Buffer
from tinygrad.engine.realize import CustomOp, ExecItem, capturing, lower_schedule_item
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.engine.schedule import _graph_schedule, _LBScheduleItem
@@ -67,7 +67,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
raise e
def _exec_si(si: ScheduleItem, seed:int):
ei = ExecItem(lower_schedule_item(si), list(si.bufs))
ei = lower_schedule_item(si)
if len(capturing): capturing[0].add(ei)
if isinstance(ei.prg, CustomOp): Tensor._seed = seed
ei.run()

View File

@@ -14,7 +14,7 @@ from test.helpers import is_dtype_supported
def _uops_to_prg(uops):
src = Device[Device.DEFAULT].compiler.render("test", uops)
has_local = Device[Device.DEFAULT].compiler.compiler_opts.has_local
return CompiledRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)
return CompiledRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None, uops=uops)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))
@@ -23,8 +23,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}',False)) for i,dtype in enumerate(dts)]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, False)) for i,dtype in enumerate(dts)]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
@@ -39,7 +39,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
@@ -52,7 +52,7 @@ def _test_single_value_const(vals, op, dts):
def _test_uops_result(output_dtype, uops, res):
# uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
# res = output_fn(uops)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
@@ -238,7 +238,7 @@ class TestAssembly(unittest.TestCase):
def test_pointer_arithmetics_caching(self):
from tinygrad.renderer.assembly import ptr_ar
uops = UOpGraph()
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, True))
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)