mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
create_schedule([x.lazydata]) -> x.schedule() in tests (#8449)
This commit is contained in:
@@ -3,7 +3,6 @@ import unittest
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.helpers import prod
|
||||
from test.unit.test_shapetracker import shapetracker_getitem
|
||||
@@ -11,11 +10,10 @@ from test.unit.test_shapetracker import shapetracker_getitem
|
||||
class TestConvShapetracker(unittest.TestCase):
|
||||
def test_conv_3x3_one_view(self):
|
||||
conv = Conv2d(16, 32, (3, 3))
|
||||
|
||||
# first run to init the weights, they are scheduled.
|
||||
conv(Tensor.empty(1, 16, 10, 10)).schedule()
|
||||
# run it again to get the kernels
|
||||
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
|
||||
sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).schedule() if si.ast.op is Ops.SINK]
|
||||
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
|
||||
for st in [x.st_arg for x in sched[0].ast.toposort if x.op is Ops.LOAD]:
|
||||
assert len(st.views) == 1
|
||||
|
||||
Reference in New Issue
Block a user