mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
create_schedule([x.lazydata]) -> x.schedule() in tests (#8449)
This commit is contained in:
@@ -3,7 +3,6 @@ import numpy as np
|
||||
import torch
|
||||
import unittest, copy, mmap, random, math, array
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.helpers import getenv, temp, _METADATA, mv_address
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
@@ -725,7 +724,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.lazydata.metadata.name, "matmul")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
|
||||
@@ -733,7 +732,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.lazydata.metadata.name, "relu")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
|
||||
@@ -744,7 +743,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(out.lazydata.metadata.name, "__mul__")
|
||||
self.assertEqual(out.lazydata.src[0].metadata.name, "relu")
|
||||
self.assertEqual(out.lazydata.src[1].metadata.name, "sigmoid")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
@@ -758,7 +757,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertTrue(x.grad.lazydata.metadata.backward)
|
||||
self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid")
|
||||
self.assertTrue(y.grad.lazydata.metadata.backward)
|
||||
si = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])[-1]
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"})
|
||||
bw = [m for m in si.metadata if m.backward]
|
||||
|
||||
Reference in New Issue
Block a user