From 595a6e3069a980b3cec2d3746b6d617942740a43 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 5 May 2024 11:13:43 -0700
Subject: [PATCH] test_fold_conv_relu_backward test

---
 test/test_schedule.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index 98cc2ca5a9..3349d9238d 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -6,9 +6,9 @@ import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
 from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
-from tinygrad.helpers import DEBUG, GRAPH, flatten
+from tinygrad.helpers import DEBUG, flatten
 from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.features.graph import print_tree, realized_lazybuffer
+from tinygrad.features.graph import print_tree
 from tinygrad.engine.schedule import create_schedule
 from tinygrad import nn, dtypes
 from test.helpers import is_dtype_supported
@@ -20,12 +20,8 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
   for pre in to_prerealize:
     for s in pre.schedule(seen=seen.copy()):
       for i,out in enumerate(s.outputs):
-        if GRAPH: realized_lazybuffer(out, 0)
         seen.add(out)
   sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
-  if GRAPH:
-    for i,s in enumerate(sched):
-      for out in s.outputs: realized_lazybuffer(out, i+1)
   if filter_loadops: sched = [s for s in sched if s.ast[0].op not in LoadOps]
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
@@ -210,6 +206,17 @@ class TestSchedule(unittest.TestCase):
     # this is too high
     check_schedule(opt.schedule_step(), 17)
 
+  def test_fold_conv_relu_backward(self):
+    c1 = nn.Conv2d(3,16,3, bias=False)
+    c1.weight.requires_grad = True
+
+    # run
+    img = Tensor.rand(2,3,64,64, requires_grad=True)
+    c1(img).relu().mean().backward()
+    # TODO: this should be 4, not 5
+    # img.grad is requiring two reduces
+    check_schedule([img.grad, c1.weight.grad], 5)
+
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
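
For reference, the new test boils down to the standalone sketch below. It mirrors the scheduling and LoadOps filtering that check_schedule does in the hunk context above (create_schedule over the gradients' lazydata.lbs, then dropping kernels whose s.ast[0].op is a LoadOps), and it assumes the same-era tinygrad API the patch itself imports; treat it as illustrative, not part of the patch.

from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.helpers import flatten
from tinygrad.engine.schedule import create_schedule

# same setup as the new test: conv without bias, grads on weight and input
c1 = nn.Conv2d(3, 16, 3, bias=False)
c1.weight.requires_grad = True
img = Tensor.rand(2, 3, 64, 64, requires_grad=True)
c1(img).relu().mean().backward()

# schedule both gradients together, then drop the LoadOps kernels,
# the same filtering check_schedule applies
sched = create_schedule(flatten([t.lazydata.lbs for t in [img.grad, c1.weight.grad]]))
sched = [s for s in sched if s.ast[0].op not in LoadOps]
print(len(sched))  # 5 today; the TODO in the test says this should be 4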