remove nop, use upat [run_process_replay] (#6489)

* remove nop, use upat [run_process_replay]

* mypy passes

* no wonder nothing worked

* fixes
This commit is contained in:
George Hotz
2024-09-12 12:16:19 +08:00
committed by GitHub
parent f12f0857d8
commit 76487a3533
8 changed files with 165 additions and 171 deletions

View File

@@ -5,16 +5,16 @@ from tinygrad import dtypes, Device
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import NOp, PatternMatcher
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.codegen.lowerer import ast_to_uop
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding
from tinygrad.shape.shapetracker import ShapeTracker, View
simple_pm = PatternMatcher([
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(NOp.cvar('x') + NOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(NOp.cvar('x') * NOp.cvar('y') * NOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
])
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))

View File

@@ -7,7 +7,6 @@ from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UOps, UOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
from tinygrad.ops import NOp
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -389,10 +388,6 @@ class TestUOpStr(unittest.TestCase):
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
assert_equiv_uops(sink, eval(str(sink)))
def test_nop_str(self):
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")
assert str(eval(str(a))) == str(a)
def test_variable_const(self):
# TODO: this is not possible after VALID.
uop = UOp(UOps.CONST, dtypes.int, (), arg=Variable("a",1,10))

View File

@@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase):
def test_arg(self):
matcher = PatternMatcher([
(UPat(UOps.CONST, 0, name="x"), lambda x: x),
(UPat(UOps.CONST, False, name="x"), lambda x: x),
(UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x),
(UPat(UOps.CONST, arg=0, name="x"), lambda x: x),
(UPat(UOps.CONST, arg=False, name="x"), lambda x: x),
(UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
])
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
@@ -47,7 +47,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_filter_arg(self):
matcher = PatternMatcher([
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)], name="x"),
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(UOps.CONST, dtypes.int, arg=1)