with commutative hack, uops can change. fix that (#7266)

* with commutative hack, uops can change. fix that

* simpler
This commit is contained in:
George Hotz
2024-10-24 17:50:23 +07:00
committed by GitHub
parent d482d927a8
commit 9a3d498d9c
4 changed files with 19 additions and 7 deletions

View File

@@ -1,4 +1,4 @@
# on Windows -- $env:SKIP="devicetests,tests,example" # on Windows -- $env:SKIP="tests,example"
repos: repos:
- repo: local - repo: local
hooks: hooks:

View File

@@ -166,6 +166,17 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].op, UOps.CONST) self.assertEqual(nout.src[1].op, UOps.CONST)
self.assertEqual(nout.src[1].arg, 3.0) self.assertEqual(nout.src[1].arg, 3.0)
def test_commutative_work(self):
a = UOp.variable('a', 0, 1)
b = UOp.variable('b', 0, 1)
self.assertIs(a+b, b+a)
def test_consts_go_last_right_away(self):
a = UOp.variable('a', 0, 1)
tst = 2+a
self.assertIs(tst.src[0], a)
self.assertIs(tst.src[1], a.const_like(2))
def test_consts_go_last(self): def test_consts_go_last(self):
a = UOp.variable('a', 0, 1) a = UOp.variable('a', 0, 1)
b = UOp.variable('b', 0, 1) b = UOp.variable('b', 0, 1)

View File

@@ -443,12 +443,12 @@ class TestIndexingOrdering(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase): class TestUPatHelpers(unittest.TestCase):
def test_location(self): def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py") self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
self.assertEqual(to_si.patterns[0][0].location[0].split("/")[-1], "schedule.py") self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py") self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*? with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool) test_upat = UPat(UOps.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.split("/")[-1]) self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib
from enum import auto, IntEnum, Enum from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field from dataclasses import dataclass, field
from collections import defaultdict from collections import defaultdict
@@ -198,6 +198,7 @@ class UOp(MathTrait):
__slots__ = ["op", "dtype", "src", "arg"] __slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
if getattr(self, 'op', None) is not None: return
# TODO: instant check rules here make debugging easier # TODO: instant check rules here make debugging easier
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}" #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
@@ -494,7 +495,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
def get_location() -> Tuple[str, int]: def get_location() -> Tuple[str, int]:
frm = sys._getframe(1) frm = sys._getframe(1)
# find the real frame in the file that has the UPat, TODO: is there a better way to do this? # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}: while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
frm = frm.f_back frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno return frm.f_code.co_filename, frm.f_lineno
@functools.lru_cache(None) @functools.lru_cache(None)