mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
cleanup lbs (#4163)
This commit is contained in:
@@ -1,14 +1,13 @@
|
|||||||
import unittest, math
|
import unittest, math
|
||||||
from tinygrad import Tensor, Device, dtypes
|
from tinygrad import Tensor, Device, dtypes
|
||||||
from tinygrad.engine.schedule import create_schedule
|
from tinygrad.engine.schedule import create_schedule
|
||||||
from tinygrad.features.multi import MultiLazyBuffer
|
|
||||||
from tinygrad.helpers import CI
|
from tinygrad.helpers import CI
|
||||||
from tinygrad.ops import BufferOps
|
from tinygrad.ops import BufferOps
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
def _check_ast_count(desired_count:int, t:Tensor):
|
def _check_ast_count(desired_count:int, t:Tensor):
|
||||||
# NOTE: this has side effect because everything can be scheduled only once
|
# NOTE: this has side effect because everything can be scheduled only once
|
||||||
schedule = create_schedule(t.lazydata.lbs if isinstance(t.lazydata, MultiLazyBuffer) else [t.lazydata])
|
schedule = create_schedule(t.lazydata.lbs)
|
||||||
asts = [s for s in schedule if s.ast[0].op is BufferOps.STORE]
|
asts = [s for s in schedule if s.ast[0].op is BufferOps.STORE]
|
||||||
assert len(asts) == desired_count
|
assert len(asts) == desired_count
|
||||||
|
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class Tensor:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def corealize(lst:Iterable[Tensor]):
|
def corealize(lst:Iterable[Tensor]):
|
||||||
run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
|
run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst])))
|
||||||
|
|
||||||
def realize(self) -> Tensor:
|
def realize(self) -> Tensor:
|
||||||
Tensor.corealize([self])
|
Tensor.corealize([self])
|
||||||
|
|||||||
Reference in New Issue
Block a user