test_fold_conv_relu_backward test

This commit is contained in:
George Hotz
2024-05-05 11:13:43 -07:00
parent cc16f644d0
commit 595a6e3069

View File

@@ -6,9 +6,9 @@ import unittest
from typing import List, Optional, Union
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, GRAPH, flatten
from tinygrad.helpers import DEBUG, flatten
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer
from tinygrad.features.graph import print_tree
from tinygrad.engine.schedule import create_schedule
from tinygrad import nn, dtypes
from test.helpers import is_dtype_supported
@@ -20,12 +20,8 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
for pre in to_prerealize:
for s in pre.schedule(seen=seen.copy()):
for i,out in enumerate(s.outputs):
if GRAPH: realized_lazybuffer(out, 0)
seen.add(out)
sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
if GRAPH:
for i,s in enumerate(sched):
for out in s.outputs: realized_lazybuffer(out, i+1)
if filter_loadops: sched = [s for s in sched if s.ast[0].op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
@@ -210,6 +206,17 @@ class TestSchedule(unittest.TestCase):
# this is too high
check_schedule(opt.schedule_step(), 17)
def test_fold_conv_relu_backward(self):
c1 = nn.Conv2d(3,16,3, bias=False)
c1.weight.requires_grad = True
# run
img = Tensor.rand(2,3,64,64, requires_grad=True)
c1(img).relu().mean().backward()
# TODO: this should be 4, not 5
# img.grad is requiring two reduces
check_schedule([img.grad, c1.weight.grad], 5)
def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)