remove RawConst and add test (#1939)

This commit is contained in:
George Hotz
2023-09-29 01:21:51 -07:00
committed by GitHub
parent 22b8576887
commit d52df788d3
3 changed files with 21 additions and 17 deletions

View File

@@ -5,13 +5,13 @@
import unittest
from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.ops import LoadOps, Device, Compiled
from tinygrad.helpers import DEBUG, dtypes
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.graph import log_schedule_item
from tinygrad import nn
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
seen = set()
if to_prerealize:
for pre in to_prerealize:
@@ -20,7 +20,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
seen.add(s[1])
sched = t.lazydata.schedule(seen)
for s in sched: log_schedule_item(*s)
sched = [s for s in sched if s[0].op not in LoadOps]
if filter_loadops: sched = [s for s in sched if s[0].op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
from extra.utils import print_tree
@@ -28,8 +28,9 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
print("op", i)
print_tree(s[0])
assert len(sched) == allowed
# test the ops linearize
# test the (non loadops) ops linearize
for s in sched:
if s[0].op in LoadOps: continue
l = Linearizer(s[0])
l.hand_coded_optimizations()
l.linearize()
@@ -76,6 +77,11 @@ class TestSchedule(unittest.TestCase):
d = (a+b).permute(1,0)+c
check_schedule(d, 1)
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM", "only test for compiled backends")
def test_constants_are_embedded(self):
a = Tensor.empty(3,3) * 2
check_schedule(a, 2, filter_loadops=False)
def test_binop_elu_fusion(self):
a = Tensor.empty(10)
b = a.elu()