mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
remove RawConst and add test (#1939)
This commit is contained in:
@@ -5,13 +5,13 @@
|
||||
import unittest
|
||||
from typing import List, Optional
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.ops import LoadOps, Device, Compiled
|
||||
from tinygrad.helpers import DEBUG, dtypes
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.graph import log_schedule_item
|
||||
from tinygrad import nn
|
||||
|
||||
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
|
||||
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
|
||||
seen = set()
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize:
|
||||
@@ -20,7 +20,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
|
||||
seen.add(s[1])
|
||||
sched = t.lazydata.schedule(seen)
|
||||
for s in sched: log_schedule_item(*s)
|
||||
sched = [s for s in sched if s[0].op not in LoadOps]
|
||||
if filter_loadops: sched = [s for s in sched if s[0].op not in LoadOps]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
from extra.utils import print_tree
|
||||
@@ -28,8 +28,9 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
|
||||
print("op", i)
|
||||
print_tree(s[0])
|
||||
assert len(sched) == allowed
|
||||
# test the ops linearize
|
||||
# test the (non loadops) ops linearize
|
||||
for s in sched:
|
||||
if s[0].op in LoadOps: continue
|
||||
l = Linearizer(s[0])
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
@@ -76,6 +77,11 @@ class TestSchedule(unittest.TestCase):
|
||||
d = (a+b).permute(1,0)+c
|
||||
check_schedule(d, 1)
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM", "only test for compiled backends")
|
||||
def test_constants_are_embedded(self):
|
||||
a = Tensor.empty(3,3) * 2
|
||||
check_schedule(a, 2, filter_loadops=False)
|
||||
|
||||
def test_binop_elu_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = a.elu()
|
||||
|
||||
Reference in New Issue
Block a user