cleanup lbs (#4163)

This commit is contained in:
George Hotz
2024-04-12 22:32:16 -07:00
committed by GitHub
parent a7c6864260
commit ba7314c26b
2 changed files with 2 additions and 3 deletions

View File

@@ -1,14 +1,13 @@
import unittest, math import unittest, math
from tinygrad import Tensor, Device, dtypes from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.schedule import create_schedule from tinygrad.engine.schedule import create_schedule
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.helpers import CI from tinygrad.helpers import CI
from tinygrad.ops import BufferOps from tinygrad.ops import BufferOps
import numpy as np import numpy as np
def _check_ast_count(desired_count:int, t:Tensor): def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once # NOTE: this has side effect because everything can be scheduled only once
schedule = create_schedule(t.lazydata.lbs if isinstance(t.lazydata, MultiLazyBuffer) else [t.lazydata]) schedule = create_schedule(t.lazydata.lbs)
asts = [s for s in schedule if s.ast[0].op is BufferOps.STORE] asts = [s for s in schedule if s.ast[0].op is BufferOps.STORE]
assert len(asts) == desired_count assert len(asts) == desired_count

View File

@@ -138,7 +138,7 @@ class Tensor:
@staticmethod @staticmethod
def corealize(lst:Iterable[Tensor]): def corealize(lst:Iterable[Tensor]):
run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst]))) run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst])))
def realize(self) -> Tensor: def realize(self) -> Tensor:
Tensor.corealize([self]) Tensor.corealize([self])