mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix this
This commit is contained in:
15
test/external/external_benchmark_schedule.py
vendored
15
test/external/external_benchmark_schedule.py
vendored
@@ -1,8 +1,9 @@
|
||||
from extra.models.resnet import ResNet50
|
||||
from tinygrad import Tensor, nn, Device
|
||||
from tinygrad.helpers import Profiling, Timing, getenv
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
|
||||
from tinygrad.codegen.late.control_flow import schedule
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -28,10 +29,10 @@ if __name__ == "__main__":
|
||||
asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
|
||||
if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
|
||||
|
||||
rewrites = get_rewrites_for_renderer(Device.default.renderer, linearizer=False)
|
||||
rewrites = get_rewrites_for_renderer(Device.default.renderer, linearizer=LINEARIZE)
|
||||
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
|
||||
with Timing("***** model rewrite in "):
|
||||
rewritten_uops = []
|
||||
rewritten_uops: list[UOp] = []
|
||||
for u in asts:
|
||||
rewritten_uops.append(apply_rewrites(u, rewrites))
|
||||
|
||||
@@ -39,7 +40,7 @@ if __name__ == "__main__":
|
||||
with Timing("***** model linearize in "):
|
||||
uops_line = []
|
||||
for u in rewritten_uops:
|
||||
uops_line.append(apply_rewrites(u, rewrites_for_linearizer))
|
||||
uops_line.append(schedule(list(u.toposort())))
|
||||
with Timing("***** model verify in "):
|
||||
for u in uops_line: type_verify(u.arg.lst)
|
||||
print(sum(len(u.arg.lst) for u in uops_line))
|
||||
for u in uops_line: type_verify(u)
|
||||
print(sum(len(u) for u in uops_line))
|
||||
|
||||
@@ -67,7 +67,7 @@ class CFGContext:
|
||||
def __init__(self, sink:UOp):
|
||||
# there are 3 relationships between ranges:
|
||||
# nested, meaning endrange y is a dependency of endrange x and range x is a dependency of endrange y
|
||||
# dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y (i.e. load in range x depends on store in range y)
|
||||
# dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y
|
||||
# independent, endrange y is not a dependency of endrange x
|
||||
# ifs are always independent
|
||||
deps: dict[UOp, set[UOp]] = {}
|
||||
|
||||
Reference in New Issue
Block a user