mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
early gate the graph (#3070)
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import List, Optional
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.helpers import DEBUG, GRAPH
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.graph import print_tree, realized_lazybuffer
|
||||
from tinygrad import nn, dtypes
|
||||
@@ -17,10 +17,11 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize:
|
||||
for s in pre.lazydata.schedule(seen.copy()):
|
||||
realized_lazybuffer(s.out, 0)
|
||||
if GRAPH: realized_lazybuffer(s.out, 0)
|
||||
seen.add(s.out)
|
||||
sched = t.lazydata.schedule(seen)
|
||||
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
|
||||
if GRAPH:
|
||||
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
|
||||
if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
|
||||
Reference in New Issue
Block a user