mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
rangeify: fix test_where_fold (llvm) (#12416)
* rangeify: fix test_where_fold (AMD_LLVM) * rm comment
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -722,6 +722,8 @@ jobs:
|
||||
run: |
|
||||
VIZ=1 SQTT=1 DEBUG=5 python3 test/test_ops.py TestOps.test_add
|
||||
extra/sqtt/rgptool.py create "/tmp/profile.pkl.$USER" -o /tmp/gpu0.rgp
|
||||
- name: Run pytest (amd) with RANGEIFY
|
||||
run: RANGEIFY=1 python -m pytest test/test_linearizer.py::TestLinearizer::test_where_fold
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
||||
@@ -92,6 +92,9 @@ base_rewrite = PatternMatcher([
|
||||
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
|
||||
# unary/binary/ternary ops
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
# rewrite cast to bool to CMPNE 0
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.bool),
|
||||
lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][Ops.CMPNE]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, zeroinitializer"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.TRUNC, name="x"),
|
||||
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
|
||||
@@ -129,8 +132,6 @@ class LLVMRenderer(Renderer):
|
||||
if AMX: tensor_cores = tc.amx
|
||||
|
||||
extra_matcher = PatternMatcher([
|
||||
# rewrite cast to bool to CMPNE 0
|
||||
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
|
||||
# rewrite MAX to CMPLT + WHERE
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
|
||||
|
||||
Reference in New Issue
Block a user