assigns are no longer used [pr] (#11333)

This commit is contained in:
George Hotz
2025-07-22 15:35:07 -07:00
committed by GitHub
parent 09431d4ad1
commit fcbd0e4de3
6 changed files with 6 additions and 12 deletions

View File

@@ -293,7 +293,7 @@ def no_vectorized_acc(acc:UOp, c:UOp):
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc),
])
@@ -314,7 +314,7 @@ pm_render = PatternMatcher([
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
@dataclass
class ReduceContext:

View File

@@ -83,7 +83,7 @@ expander = PatternMatcher([
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC

View File

@@ -10,7 +10,6 @@ from tinygrad.codegen.devectorizer import no_vectorized_alu
base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}] = {{{ctx[x.src[0]]}}};"),
(UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
@@ -67,7 +66,7 @@ extra_pm = PatternMatcher([
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# devectorize any bools
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
# CAST (from bool) can't be vectorized
(UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
# WHERE can't be vectorized
@@ -167,7 +166,7 @@ class CStyleLanguage(Renderer):
(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l
else:
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void:
if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void:
if u.op is Ops.STORE: r[u] = r[u.src[0]]
else:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")

View File

@@ -215,6 +215,5 @@ class PTXRenderer(Renderer):
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
kernel.extend([l] if isinstance(l, str) else l)
if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
return self.render_kernel(kernel, name, bufs, c.items())

View File

@@ -180,7 +180,6 @@ spec = PatternMatcher([
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_REG, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
# WMMA has a <a, b, acc>

View File

@@ -411,8 +411,6 @@ sym = symbolic_flat+PatternMatcher([
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load()), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
# self ASSIGN is just self
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
@@ -440,7 +438,6 @@ sym = symbolic_flat+PatternMatcher([
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
# ** self folding **
(UPat(Ops.DEFINE_REG, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
# x!=0 -> (bool)x
(UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** where **