mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
remove vectorized alu in expander [run_process_replay] (#5561)
This commit is contained in:
@@ -121,12 +121,6 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype
|
||||
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg[0][0],))
|
||||
return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:])
|
||||
|
||||
def no_float4_alu(alu):
  """Rewrite an ALU op with a multi-element dtype as per-element scalar ALU ops.

  Returns None when ``alu.dtype.count`` is 1. Otherwise, for every element
  index below ``alu.dtype.count``, each source of ``alu`` is wrapped in a GEP
  carrying that index, the GEPs feed an ALU built with ``alu.arg`` and the
  scalar form of ``alu.dtype``, and the resulting ALUs become the sources of a
  VECTORIZE that keeps ``alu.dtype``.
  """
  if alu.dtype.count == 1:
    return None
  lanes = []
  for lane in range(alu.dtype.count):
    # one scalar ALU per element index, reading the matching element of every source
    lanes.append(UOp(UOps.ALU, alu.dtype.scalar(),
                     tuple(UOp(UOps.GEP, operand.dtype.scalar(), (operand,), lane) for operand in alu.src),
                     alu.arg))
  return UOp(UOps.VECTORIZE, alu.dtype, tuple(lanes))
|
||||
|
||||
float4_folding = PatternMatcher([
|
||||
# reorder index to bring const closer to store
|
||||
(UOp(UOps.STORE, src=(UOp.var("buf"), UOp.var("idx")+
|
||||
@@ -148,8 +142,6 @@ float4_folding = PatternMatcher([
|
||||
UOp(UOps.EXPAND).name("ex")+UOp.var("idx"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
|
||||
(UOp(UOps.STORE, src=(UOp.var("buf"),
|
||||
UOp(UOps.EXPAND).name("ex"), UOp.var("var"))).name("store_allow_any_len"), float4_contract_store),
|
||||
# no ALU on float4 (float4 constructor doesn't work in METAL/GPU)
|
||||
(UOp(UOps.ALU).name("alu"), no_float4_alu),
|
||||
])
|
||||
|
||||
# ***** transcendental *****
|
||||
@@ -450,6 +442,12 @@ def do_contract(con:UOp):
|
||||
srcs += [UOp(UOps.VECTORIZE, con.dtype, tuple(src)) for src in zip(*to_join[i:i+con.dtype.count])]
|
||||
return UOp(UOps.EXPAND, con.dtype, tuple(srcs), tuple(x for x in ex.arg if x[0] != con.arg[0]))
|
||||
|
||||
def no_vectorized_alu(alu):
  """Rewrite an ALU op with a multi-element dtype as per-element scalar ALU ops.

  Returns None when ``alu.dtype.count`` is 1. Otherwise, for every element
  index below ``alu.dtype.count``, each source of ``alu`` is wrapped in a GEP
  carrying that index, the GEPs feed an ALU built with ``alu.arg`` and the
  scalar form of ``alu.dtype``, and the resulting ALUs become the sources of a
  VECTORIZE that keeps ``alu.dtype``.
  """
  if alu.dtype.count == 1:
    return None
  lanes = []
  for lane in range(alu.dtype.count):
    # one scalar ALU per element index, reading the matching element of every source
    lanes.append(UOp(UOps.ALU, alu.dtype.scalar(),
                     tuple(UOp(UOps.GEP, operand.dtype.scalar(), (operand,), lane) for operand in alu.src),
                     alu.arg))
  return UOp(UOps.VECTORIZE, alu.dtype, tuple(lanes))
|
||||
|
||||
expander = PatternMatcher([
|
||||
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
|
||||
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
|
||||
@@ -465,6 +463,8 @@ expander = PatternMatcher([
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
|
||||
# empty EXPAND is NOOP
|
||||
(UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
|
||||
# no ALU on vectorized dtypes
|
||||
(UOp(UOps.ALU).name("alu"), no_vectorized_alu),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
Reference in New Issue
Block a user