mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
no const/view in schedule sink after sym [pr]
This commit is contained in:
@@ -106,6 +106,9 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# put CAST after expanding BUFFER
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
|
||||
and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
|
||||
# remove CONST/BIND/VIEW from SINK
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
|
||||
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
|
||||
])
|
||||
|
||||
# **** UOp realization
|
||||
@@ -259,9 +262,8 @@ create_kernels = merge_views+PatternMatcher([
|
||||
lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
|
||||
# walk back the local graph until we reach a buffer/assign parent
|
||||
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
|
||||
# remove CONST/BIND from SINK
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
|
||||
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
|
||||
# remove downstream reshapes from SINK
|
||||
(UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
|
||||
])
|
||||
|
||||
# **** fix kernel AST
|
||||
|
||||
Reference in New Issue
Block a user