mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -3,8 +3,8 @@ import unittest, pickle, functools, math
|
||||
import z3
|
||||
|
||||
from tinygrad.dtype import dtypes, ConstType, DType, Invalid
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.helpers import Context
|
||||
from test.helpers import get_uops
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
|
||||
from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
|
||||
from tinygrad.uop.validate import uops_to_z3
|
||||
@@ -747,7 +747,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
# TODO: copied from render, render does not support cast
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
|
||||
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
|
||||
|
||||
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
|
||||
|
||||
Reference in New Issue
Block a user