mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
test_fold_conv_relu_backward test
This commit is contained in:
@@ -6,9 +6,9 @@ import unittest
|
||||
from typing import List, Optional, Union
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
|
||||
from tinygrad.helpers import DEBUG, GRAPH, flatten
|
||||
from tinygrad.helpers import DEBUG, flatten
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.graph import print_tree, realized_lazybuffer
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad import nn, dtypes
|
||||
from test.helpers import is_dtype_supported
|
||||
@@ -20,12 +20,8 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
|
||||
for pre in to_prerealize:
|
||||
for s in pre.schedule(seen=seen.copy()):
|
||||
for i,out in enumerate(s.outputs):
|
||||
if GRAPH: realized_lazybuffer(out, 0)
|
||||
seen.add(out)
|
||||
sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
|
||||
if GRAPH:
|
||||
for i,s in enumerate(sched):
|
||||
for out in s.outputs: realized_lazybuffer(out, i+1)
|
||||
if filter_loadops: sched = [s for s in sched if s.ast[0].op not in LoadOps]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
@@ -210,6 +206,17 @@ class TestSchedule(unittest.TestCase):
|
||||
# this is too high
|
||||
check_schedule(opt.schedule_step(), 17)
|
||||
|
||||
def test_fold_conv_relu_backward(self):
|
||||
c1 = nn.Conv2d(3,16,3, bias=False)
|
||||
c1.weight.requires_grad = True
|
||||
|
||||
# run
|
||||
img = Tensor.rand(2,3,64,64, requires_grad=True)
|
||||
c1(img).relu().mean().backward()
|
||||
# TODO: this should be 4, not 5
|
||||
# img.grad is requiring two reduces
|
||||
check_schedule([img.grad, c1.weight.grad], 5)
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user