mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
tests from new lazy branch (#2774)
* tests from new lazy branch * fix lin 11 * that was needed * doesn't fail * mark * meant that * llvm passes
This commit is contained in:
@@ -25,7 +25,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
for i, s in enumerate(sched):
|
||||
print("op", i)
|
||||
print("kernel", i+1)
|
||||
print_tree(s.ast)
|
||||
assert len(sched) == allowed
|
||||
# test the (non loadops) ops linearize
|
||||
@@ -260,6 +260,22 @@ class TestSchedule(unittest.TestCase):
|
||||
check_schedule(c, 1)
|
||||
check_schedule(e, 1)
|
||||
|
||||
def test_shrink_fuse(self):
|
||||
a = Tensor.empty(8192, 16)
|
||||
b = Tensor.empty(8192, 16)
|
||||
c = a * b
|
||||
d = Tensor.empty(1, 16)
|
||||
e = c[0] * d
|
||||
check_schedule(e, 1)
|
||||
|
||||
def test_expand_nofuse(self):
|
||||
a = Tensor.empty(1, 16)
|
||||
b = Tensor.empty(1, 16)
|
||||
c = a * b
|
||||
d = Tensor.empty(8192, 16)
|
||||
e = c * d
|
||||
check_schedule(e, 2)
|
||||
|
||||
# this is the failing case in openpilot...it's very simple like this
|
||||
@unittest.skip("failing in old lazy")
|
||||
def test_image_conv_fusion(self):
|
||||
@@ -304,12 +320,18 @@ class TestSchedule(unittest.TestCase):
|
||||
check_schedule(x, 3)
|
||||
|
||||
def test_resnet_block(self):
|
||||
from extra.models.resnet import BasicBlock
|
||||
Tensor.training = False
|
||||
bb = BasicBlock(64,64)
|
||||
|
||||
in_planes, planes = 64, 64
|
||||
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
bn1 = nn.BatchNorm2d(planes)
|
||||
conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
|
||||
bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
x = Tensor.empty(1, 64, 32, 32)
|
||||
out = bb(x)
|
||||
out = bn1(conv1(x)).relu()
|
||||
out = bn2(conv2(out))
|
||||
out = (out + x).relu()
|
||||
check_schedule(out, 4)
|
||||
|
||||
def test_contiguous_while_contiguous(self):
|
||||
|
||||
Reference in New Issue
Block a user