mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-17 18:11:49 -05:00
Make vectorization of CONST explicit (#5322)
* remove test_const_vectorize_fold * remove const folding UPat for VECTORIZE * refactor cstyle render_const * remove calls to dtype.scalar() in render_const * add assert * add vectorized const to UOp.const * add UPat GEP-VECTORIZE-CONST -> CONST * render_vectorize for DEFINE_ACC in cstyle * add back missing render_cast in render_const * generate vectorized consts as UOps for DEFINE_ACC * update asserts for DEFINE_ACC with VECTORIZE src * add UPats for PHI with VECTORIZE src * use prev rendered vectorize in DEFINE_ACC render * update DEFINE_ACC in python runtime * update vectorized DEFINE_ACC in PTXRenderer * rebase DEFINE_ACC changes on lowerer * verbose rewrite of bad UPats * simplify UOps.CONST implementation in ops_python * update sum_collapse UPats for DEFINE_ACC-VECTORIZE * revert linearizer to TOT * fix DEFINE_ACC implementation in ops_python * simplify DEFINE_ACC in cstyle * Fix linter error * support VECTORIZE in fold gated load/store UPat * support VECTORIZE in other fold gated load UPats * rewrite VECTORIZE in UPat for no input DEFINE_ACC * simplify DEFINE_ACC render in cstyle * make VECTORIZE rules more concise * add more vectorize fold tests * inline VECTORIZE-CONSTs in cstyle render * revert VECTORIZE/GEP rule refactor * revert cstyle render_const refactor * inline VECTORIZE-CONSTs in cstyle render * implicitly vectorized const rendering -> explicit * WMMA VECTORIZE CONST process replay hacks * VECTORIZE CONST NAN process_replay hacks * more VECTORIZE CONST NAN hacks * cleanup process_replay hacks * isnan() -> not isfinite() cstyle VECTORIZE CONST * tweak isnan and isfinite checks VECTORIZE CONST * tweak for positive vs negative infinity VECTORIZE CONST * add assert to PTX CONST render * process_replay VECTORIZE CONST render parity for PTX STORE * vmin/vmax for VECTORIZE'd CONST * update WMMA folding rules * add tests for WMMA VECTORIZE fold * hack for cstyle half4 CONST zero process_replay parity * revert PTX backend changes * add back minimal DEFINE_ACC PTX change * remove cstyle process_replay hacks * remove dead code in PTX CONST render * cleanup vmin/vmax logic for VECTORIZE'd CONSTs * update vectorize fold tests to use DEFINE_VAR * fix long line formatting in test * remove unwanted merge artifact * more vmin/vmax cleanup * remove unnecessary asserts * yet more vmin/vmax cleanup * get rid of explicit VECTORIZE CONST logic in _min_max * reuse CONST instead of creating a new one * remove unneeded cast * handle DType correctly in sconst * improve readability of tests * save a line * save another line * tuplize pats in src * remove GEP-VECTORIZE pats * add vec +0 fold * HACK: fold only vec8 +0 * remove vectorized ALU fold hack --------- Co-authored-by: qazal <qazal.software@gmail.com> Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
@@ -133,15 +133,6 @@ class TestUOpGraph(TestUOps):
|
||||
self.assertEqual(out.op, UOps.CONST)
|
||||
self.assertEqual(out.arg, 0)
|
||||
|
||||
def test_const_vectorize_fold(self):
|
||||
c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
|
||||
out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.op, UOps.CONST)
|
||||
self.assertEqual(out.arg, 0.0)
|
||||
|
||||
def test_noop_vectorize_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
@@ -192,6 +183,73 @@ class TestUOpGraph(TestUOps):
|
||||
xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), i) for i in range(2))
|
||||
self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
|
||||
|
||||
def test_gep_vec_const_fold(self):
|
||||
for vec_size in [2, 4, 8]:
|
||||
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
|
||||
geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
|
||||
g = UOpGraph(geps)
|
||||
for uop, const in zip(g.uops, consts):
|
||||
self.assert_equiv_uops(uop, const)
|
||||
|
||||
def test_wmma_vectorize_fold(self):
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assert_equiv_uops(g.uops[0], acc)
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assert_equiv_uops(g.uops[0], acc)
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
|
||||
def test_wmma_vectorize_no_fold(self):
|
||||
for i in [4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
for i in [4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
|
||||
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)
|
||||
|
||||
Reference in New Issue
Block a user