small changes from lowerer. [run_process_replay] [no_assert] (#5102)

This commit is contained in:
George Hotz
2024-06-22 11:09:35 -07:00
committed by GitHub
parent e468601226
commit 9f875123b6
4 changed files with 7 additions and 6 deletions

View File

@@ -329,6 +329,8 @@ class TestOps(unittest.TestCase):
def test_tiny_add(self):
helper_test_op([(3), (3)], lambda x,y: x+y, Tensor.add, forward_only=True)
def test_tiny_mul(self):
helper_test_op([(64), (64)], lambda x,y: x*y, Tensor.mul, forward_only=True)
def test_add(self):
helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add)

View File

@@ -63,7 +63,7 @@ class TestBEAM(unittest.TestCase):
assert GlobalCounters.kernel_count == kernel_count + 1
k_beam_0 = capturing[0].captured
capturing.clear()
assert k_beam_0[-1].prg.p.src != k_beam_1[-1].prg.p.src
self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src)
def test_get_linearizer_actions(self):
from test.test_linearizer import helper_realized_ast

View File

@@ -56,10 +56,8 @@ class UOp:
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
def ge(self, x): return -self.lt(x)
@staticmethod
def max(x, y): return UOp.alu(BinaryOps.MAX, x, y)
@staticmethod
def min(x, y): return -UOp.alu(BinaryOps.MAX, -x, -y)
def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -x)
@staticmethod
def const(dtype:Optional[DType], b:ConstType|Variable):
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
@@ -214,6 +212,7 @@ constant_folder = PatternMatcher([
(UOp.var('x') * 1, lambda x: x), # x*1 -> x
(UOp.var('x') // 1, lambda x: x), # x/1 -> x
(UOp.var('x') // -1, lambda x: -x), # x/-1 -> -x
(UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
# ** zero folding **
#x*0 -> 0 or 0*x -> 0
#if x is nan it should render the nan value.

View File

@@ -95,6 +95,6 @@ def graph_uops(uops:List[UOp]):
G = nx.DiGraph()
for u in uops:
if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
for v in u.src: G.add_edge(uops.index(v), uops.index(u))
save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')