mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix test
This commit is contained in:
@@ -445,10 +445,10 @@ class TestUOpGraph(unittest.TestCase):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
|
||||
v = Variable("v", 0, 20)
|
||||
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v), UOp.const(dtypes.int, 0), UOp(Ops.IF, src=(v<16,))))
|
||||
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v, v<16), UOp.const(dtypes.int, 0)))
|
||||
to_uops_list([st0])
|
||||
|
||||
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20))
|
||||
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v, v<20), v))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([st1])
|
||||
|
||||
def test_in_bounds_access_gated_local(self):
|
||||
@@ -463,7 +463,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
gate = (gidx<400) & (lidx<8)
|
||||
|
||||
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.uint, 1), UOp(Ops.IF, src=(lidx<8,))))
|
||||
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
|
||||
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
|
||||
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
|
||||
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp|None:
|
||||
"""
|
||||
Optimize an AST based on heuristics or BEAM search.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user