mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix define global (#4383)
* fix define global * remove name from DEFINE_GLOBAL * fix fuzzing * fix ptx * fix python
This commit is contained in:
4
test/external/fuzz_schedule.py
vendored
4
test/external/fuzz_schedule.py
vendored
@@ -2,7 +2,7 @@ import itertools
|
||||
import numpy as np
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.engine.realize import CustomOp, ExecItem, capturing, lower_schedule_item
|
||||
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
|
||||
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.engine.schedule import _graph_schedule, _LBScheduleItem
|
||||
@@ -67,7 +67,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
|
||||
raise e
|
||||
|
||||
def _exec_si(si: ScheduleItem, seed:int):
|
||||
ei = ExecItem(lower_schedule_item(si), list(si.bufs))
|
||||
ei = lower_schedule_item(si)
|
||||
if len(capturing): capturing[0].add(ei)
|
||||
if isinstance(ei.prg, CustomOp): Tensor._seed = seed
|
||||
ei.run()
|
||||
|
||||
@@ -14,7 +14,7 @@ from test.helpers import is_dtype_supported
|
||||
def _uops_to_prg(uops):
|
||||
src = Device[Device.DEFAULT].compiler.render("test", uops)
|
||||
has_local = Device[Device.DEFAULT].compiler.compiler_opts.has_local
|
||||
return CompiledRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)
|
||||
return CompiledRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None, uops=uops)
|
||||
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(vin), arg))
|
||||
@@ -23,8 +23,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar
|
||||
def _test_single_value(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
|
||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}',False)) for i,dtype in enumerate(dts)]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
|
||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, False)) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
@@ -39,7 +39,7 @@ def _test_single_value(vals, op, dts):
|
||||
def _test_single_value_const(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
|
||||
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
@@ -52,7 +52,7 @@ def _test_single_value_const(vals, op, dts):
|
||||
|
||||
def _test_uops_result(output_dtype, uops, res):
|
||||
# uops = []
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
|
||||
# res = output_fn(uops)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
@@ -238,7 +238,7 @@ class TestAssembly(unittest.TestCase):
|
||||
def test_pointer_arithmetics_caching(self):
|
||||
from tinygrad.renderer.assembly import ptr_ar
|
||||
uops = UOpGraph()
|
||||
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
|
||||
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, True))
|
||||
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
|
||||
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
|
||||
u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)
|
||||
|
||||
Reference in New Issue
Block a user