mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
small changes from lowerer. [run_process_replay] [no_assert] (#5102)
This commit is contained in:
@@ -329,6 +329,8 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_tiny_add(self):
|
||||
helper_test_op([(3), (3)], lambda x,y: x+y, Tensor.add, forward_only=True)
|
||||
def test_tiny_mul(self):
|
||||
helper_test_op([(64), (64)], lambda x,y: x*y, Tensor.mul, forward_only=True)
|
||||
|
||||
def test_add(self):
|
||||
helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add)
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestBEAM(unittest.TestCase):
|
||||
assert GlobalCounters.kernel_count == kernel_count + 1
|
||||
k_beam_0 = capturing[0].captured
|
||||
capturing.clear()
|
||||
assert k_beam_0[-1].prg.p.src != k_beam_1[-1].prg.p.src
|
||||
self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src)
|
||||
|
||||
def test_get_linearizer_actions(self):
|
||||
from test.test_linearizer import helper_realized_ast
|
||||
|
||||
@@ -56,10 +56,8 @@ class UOp:
|
||||
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
|
||||
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
|
||||
def ge(self, x): return -self.lt(x)
|
||||
@staticmethod
|
||||
def max(x, y): return UOp.alu(BinaryOps.MAX, x, y)
|
||||
@staticmethod
|
||||
def min(x, y): return -UOp.alu(BinaryOps.MAX, -x, -y)
|
||||
def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
|
||||
def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -x)
|
||||
@staticmethod
|
||||
def const(dtype:Optional[DType], b:ConstType|Variable):
|
||||
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
|
||||
@@ -214,6 +212,7 @@ constant_folder = PatternMatcher([
|
||||
(UOp.var('x') * 1, lambda x: x), # x*1 -> x
|
||||
(UOp.var('x') // 1, lambda x: x), # x/1 -> x
|
||||
(UOp.var('x') // -1, lambda x: -x), # x/-1 -> -x
|
||||
(UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
|
||||
# ** zero folding **
|
||||
#x*0 -> 0 or 0*x -> 0
|
||||
#if x is nan it should render the nan value.
|
||||
|
||||
@@ -95,6 +95,6 @@ def graph_uops(uops:List[UOp]):
|
||||
G = nx.DiGraph()
|
||||
for u in uops:
|
||||
if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
|
||||
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
|
||||
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
|
||||
for v in u.src: G.add_edge(uops.index(v), uops.index(u))
|
||||
save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
|
||||
|
||||
Reference in New Issue
Block a user