mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
with commutative hack, uops can change. fix that (#7266)
* with commutative hack, uops can change. fix that * simpler
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# on Windows -- $env:SKIP="devicetests,tests,example"
|
||||
# on Windows -- $env:SKIP="tests,example"
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
@@ -166,6 +166,17 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
self.assertEqual(nout.src[1].op, UOps.CONST)
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_commutative_work(self):
|
||||
a = UOp.variable('a', 0, 1)
|
||||
b = UOp.variable('b', 0, 1)
|
||||
self.assertIs(a+b, b+a)
|
||||
|
||||
def test_consts_go_last_right_away(self):
|
||||
a = UOp.variable('a', 0, 1)
|
||||
tst = 2+a
|
||||
self.assertIs(tst.src[0], a)
|
||||
self.assertIs(tst.src[1], a.const_like(2))
|
||||
|
||||
def test_consts_go_last(self):
|
||||
a = UOp.variable('a', 0, 1)
|
||||
b = UOp.variable('b', 0, 1)
|
||||
|
||||
@@ -443,12 +443,12 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||
|
||||
class TestUPatHelpers(unittest.TestCase):
|
||||
def test_location(self):
|
||||
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(to_si.patterns[0][0].location[0].split("/")[-1], "schedule.py")
|
||||
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
|
||||
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
|
||||
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
|
||||
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
||||
test_upat = UPat(UOps.CONST, dtypes.bool)
|
||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.split("/")[-1])
|
||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
@@ -198,6 +198,7 @@ class UOp(MathTrait):
|
||||
|
||||
__slots__ = ["op", "dtype", "src", "arg"]
|
||||
def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
if getattr(self, 'op', None) is not None: return
|
||||
# TODO: instant check rules here make debugging easier
|
||||
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
|
||||
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
|
||||
@@ -494,7 +495,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
def get_location() -> Tuple[str, int]:
|
||||
frm = sys._getframe(1)
|
||||
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
||||
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
|
||||
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
|
||||
frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
@functools.lru_cache(None)
|
||||
|
||||
Reference in New Issue
Block a user