This commit is contained in:
ttomsa
2025-10-07 21:21:53 +01:00
parent 381df7f45f
commit 4687d000a7
2 changed files with 9 additions and 8 deletions

View File

@@ -1,8 +1,9 @@
from extra.models.resnet import ResNet50
from tinygrad import Tensor, nn, Device
from tinygrad.helpers import Profiling, Timing, getenv
from tinygrad.uop.ops import Ops
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer
from tinygrad.uop.ops import Ops, UOp
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
from tinygrad.codegen.late.control_flow import schedule
from tinygrad.uop.spec import type_verify
if __name__ == "__main__":
@@ -28,10 +29,10 @@ if __name__ == "__main__":
asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
rewrites = get_rewrites_for_renderer(Device.default.renderer, linearizer=False)
rewrites = get_rewrites_for_renderer(Device.default.renderer, linearizer=LINEARIZE)
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
with Timing("***** model rewrite in "):
rewritten_uops = []
rewritten_uops: list[UOp] = []
for u in asts:
rewritten_uops.append(apply_rewrites(u, rewrites))
@@ -39,7 +40,7 @@ if __name__ == "__main__":
with Timing("***** model linearize in "):
uops_line = []
for u in rewritten_uops:
uops_line.append(apply_rewrites(u, rewrites_for_linearizer))
uops_line.append(schedule(list(u.toposort())))
with Timing("***** model verify in "):
for u in uops_line: type_verify(u.arg.lst)
print(sum(len(u.arg.lst) for u in uops_line))
for u in uops_line: type_verify(u)
print(sum(len(u) for u in uops_line))

View File

@@ -67,7 +67,7 @@ class CFGContext:
def __init__(self, sink:UOp):
# there are 3 relationships between ranges:
# nested, meaning endrange y is a dependency of endrange x and range x is a dependency of endrange y
# dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y (i.e. load in range x depends on store in range y)
# dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y
# independent, endrange y is not a dependency of endrange x
# ifs are always independent
deps: dict[UOp, set[UOp]] = {}