diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 18e0b89883..364a47d3d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -# on Windows -- $env:SKIP="devicetests,tests,example" +# on Windows -- $env:SKIP="tests,example" repos: - repo: local hooks: diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 324276dc44..b488ddac9a 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -166,6 +166,17 @@ class TestGraphRewrite(unittest.TestCase): self.assertEqual(nout.src[1].op, UOps.CONST) self.assertEqual(nout.src[1].arg, 3.0) + def test_commutative_work(self): + a = UOp.variable('a', 0, 1) + b = UOp.variable('b', 0, 1) + self.assertIs(a+b, b+a) + + def test_consts_go_last_right_away(self): + a = UOp.variable('a', 0, 1) + tst = 2+a + self.assertIs(tst.src[0], a) + self.assertIs(tst.src[1], a.const_like(2)) + def test_consts_go_last(self): a = UOp.variable('a', 0, 1) b = UOp.variable('b', 0, 1) diff --git a/test/test_uops.py b/test/test_uops.py index cb74c0a15e..332cb9741b 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -443,12 +443,12 @@ class TestIndexingOrdering(unittest.TestCase): class TestUPatHelpers(unittest.TestCase): def test_location(self): - self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py") - self.assertEqual(to_si.patterns[0][0].location[0].split("/")[-1], "schedule.py") - self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py") + self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py") + self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py") + self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py") with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*? test_upat = UPat(UOps.CONST, dtypes.bool) - self.assertEqual(test_upat.location[0].split("/")[-1], __file__.split("/")[-1]) + self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1]) if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index c05d2ca72f..e816312b12 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict -import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle +import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib from enum import auto, IntEnum, Enum from dataclasses import dataclass, field from collections import defaultdict @@ -198,6 +198,7 @@ class UOp(MathTrait): __slots__ = ["op", "dtype", "src", "arg"] def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): + if getattr(self, 'op', None) is not None: return # TODO: instant check rules here make debugging easier #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}" #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool @@ -494,7 +495,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: def get_location() -> Tuple[str, int]: frm = sys._getframe(1) # find the real frame in the file that has the UPat, TODO: is there a better way to do this? - while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}: + while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno @functools.lru_cache(None)