mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
with commutative hack, uops can change. fix that (#7266)
* with commutative hack, uops can change. fix that * simpler
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
# on Windows -- $env:SKIP="devicetests,tests,example"
|
# on Windows -- $env:SKIP="tests,example"
|
||||||
repos:
|
repos:
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
|
|||||||
@@ -166,6 +166,17 @@ class TestGraphRewrite(unittest.TestCase):
|
|||||||
self.assertEqual(nout.src[1].op, UOps.CONST)
|
self.assertEqual(nout.src[1].op, UOps.CONST)
|
||||||
self.assertEqual(nout.src[1].arg, 3.0)
|
self.assertEqual(nout.src[1].arg, 3.0)
|
||||||
|
|
||||||
|
def test_commutative_work(self):
|
||||||
|
a = UOp.variable('a', 0, 1)
|
||||||
|
b = UOp.variable('b', 0, 1)
|
||||||
|
self.assertIs(a+b, b+a)
|
||||||
|
|
||||||
|
def test_consts_go_last_right_away(self):
|
||||||
|
a = UOp.variable('a', 0, 1)
|
||||||
|
tst = 2+a
|
||||||
|
self.assertIs(tst.src[0], a)
|
||||||
|
self.assertIs(tst.src[1], a.const_like(2))
|
||||||
|
|
||||||
def test_consts_go_last(self):
|
def test_consts_go_last(self):
|
||||||
a = UOp.variable('a', 0, 1)
|
a = UOp.variable('a', 0, 1)
|
||||||
b = UOp.variable('b', 0, 1)
|
b = UOp.variable('b', 0, 1)
|
||||||
|
|||||||
@@ -443,12 +443,12 @@ class TestIndexingOrdering(unittest.TestCase):
|
|||||||
|
|
||||||
class TestUPatHelpers(unittest.TestCase):
|
class TestUPatHelpers(unittest.TestCase):
|
||||||
def test_location(self):
|
def test_location(self):
|
||||||
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
|
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
|
||||||
self.assertEqual(to_si.patterns[0][0].location[0].split("/")[-1], "schedule.py")
|
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
|
||||||
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
|
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
|
||||||
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
||||||
test_upat = UPat(UOps.CONST, dtypes.bool)
|
test_upat = UPat(UOps.CONST, dtypes.bool)
|
||||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.split("/")[-1])
|
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
|
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
|
||||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
|
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib
|
||||||
from enum import auto, IntEnum, Enum
|
from enum import auto, IntEnum, Enum
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@@ -198,6 +198,7 @@ class UOp(MathTrait):
|
|||||||
|
|
||||||
__slots__ = ["op", "dtype", "src", "arg"]
|
__slots__ = ["op", "dtype", "src", "arg"]
|
||||||
def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||||
|
if getattr(self, 'op', None) is not None: return
|
||||||
# TODO: instant check rules here make debugging easier
|
# TODO: instant check rules here make debugging easier
|
||||||
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
|
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
|
||||||
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
|
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
|
||||||
@@ -494,7 +495,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
|||||||
def get_location() -> Tuple[str, int]:
|
def get_location() -> Tuple[str, int]:
|
||||||
frm = sys._getframe(1)
|
frm = sys._getframe(1)
|
||||||
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
||||||
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
|
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
|
||||||
frm = frm.f_back
|
frm = frm.f_back
|
||||||
return frm.f_code.co_filename, frm.f_lineno
|
return frm.f_code.co_filename, frm.f_lineno
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
|
|||||||
Reference in New Issue
Block a user