with commutative hack, uops can change. fix that (#7266)

* with commutative hack, uops can change. fix that

* simpler
This commit is contained in:
George Hotz
2024-10-24 17:50:23 +07:00
committed by GitHub
parent d482d927a8
commit 9a3d498d9c
4 changed files with 19 additions and 7 deletions

View File

@@ -1,4 +1,4 @@
# on Windows -- $env:SKIP="devicetests,tests,example"
# on Windows -- $env:SKIP="tests,example"
repos:
- repo: local
hooks:

View File

@@ -166,6 +166,17 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].op, UOps.CONST)
self.assertEqual(nout.src[1].arg, 3.0)
def test_commutative_work(self):
a = UOp.variable('a', 0, 1)
b = UOp.variable('b', 0, 1)
self.assertIs(a+b, b+a)
def test_consts_go_last_right_away(self):
a = UOp.variable('a', 0, 1)
tst = 2+a
self.assertIs(tst.src[0], a)
self.assertIs(tst.src[1], a.const_like(2))
def test_consts_go_last(self):
a = UOp.variable('a', 0, 1)
b = UOp.variable('b', 0, 1)

View File

@@ -443,12 +443,12 @@ class TestIndexingOrdering(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
self.assertEqual(to_si.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.split("/")[-1])
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from collections import defaultdict
@@ -198,6 +198,7 @@ class UOp(MathTrait):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
if getattr(self, 'op', None) is not None: return
# TODO: instant check rules here make debugging easier
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
@@ -494,7 +495,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
def get_location() -> Tuple[str, int]:
frm = sys._getframe(1)
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@functools.lru_cache(None)