rangeify: fix test_where_fold (llvm) (#12416)

* rangeify: fix test_where_fold (AMD_LLVM)

* rm comment
This commit is contained in:
b1tg
2025-10-02 14:57:49 +08:00
committed by GitHub
parent 13a25b2e67
commit ec177c80c2
2 changed files with 5 additions and 2 deletions

View File

@@ -722,6 +722,8 @@ jobs:
run: |
VIZ=1 SQTT=1 DEBUG=5 python3 test/test_ops.py TestOps.test_add
extra/sqtt/rgptool.py create "/tmp/profile.pkl.$USER" -o /tmp/gpu0.rgp
- name: Run pytest (amd) with RANGEIFY
run: RANGEIFY=1 python -m pytest test/test_linearizer.py::TestLinearizer::test_where_fold
- name: Run process replay tests
uses: ./.github/actions/process-replay

View File

@@ -92,6 +92,9 @@ base_rewrite = PatternMatcher([
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
# unary/binary/ternary ops
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
# rewrite cast to bool to CMPNE 0
(UPat(Ops.CAST, name="x", dtype=dtypes.bool),
lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][Ops.CMPNE]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, zeroinitializer"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.TRUNC, name="x"),
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
@@ -129,8 +132,6 @@ class LLVMRenderer(Renderer):
if AMX: tensor_cores = tc.amx
extra_matcher = PatternMatcher([
# rewrite cast to bool to CMPNE 0
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
# rewrite MAX to CMPLT + WHERE
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16