from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.helpers import dtypes, getenv, DType, PtrDType
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
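
# These tests hand-build tiny UOp programs and execute them on the default
# compiled backend (going through the Triton path when the TRITON env var is set).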

def _uops_to_prg(uops):
  src = Device[Device.DEFAULT].renderer("test", uops)
  return ASTRunner("test", src[0] if getenv("TRITON") else src, [1], [1], runtime_args={"binary": getenv("TRITON")}).build(Device[Device.DEFAULT].runtime)

def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
  uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
  return uops[-1]
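
# A rough sketch of how these helpers compose (it mirrors _test_single_value_const
# below; the variable names are illustrative only):
#
#   uops = []
#   out = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtypes.float32), (), ('data0', dtypes.float32))
#   a   = uop(uops, UOps.CONST, dtypes.float32, (), 2.0)
#   b   = uop(uops, UOps.CONST, dtypes.float32, (), 3.0)
#   alu = uop(uops, UOps.ALU, dtypes.float32, (a, b), BinaryOps.ADD)
#   uop(uops, UOps.STORE, None, (out, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
#   _uops_to_prg(uops)  # renders the uops and builds a runnable one-element kernel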

def _test_single_value(vals, op, dtype):
  uops = []
  buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype))
  buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (f'data{i+1}', dtype)) for i in range(len(vals))]
  loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i in range(len(vals)))
  alu = uop(uops, UOps.ALU, dtype, loads, op)
  uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
  buf = Device[Device.DEFAULT].buffer(1, dtype)
  buf2 = [Device[Device.DEFAULT].buffer.fromCPU(np.array([a], dtype=dtype.np)) for a in vals]
  prg = _uops_to_prg(uops)
  prg([buf]+buf2)
  return buf.toCPU()[0]

def _test_single_value_const(vals, op, dtype):
  uops = []
  buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype))
  loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals)
  alu = uop(uops, UOps.ALU, dtype, loads, op)
  uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
  buf = Device[Device.DEFAULT].buffer(1, dtype)
  prg = _uops_to_prg(uops)
  prg([buf])
  return buf.toCPU()[0]
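
# For reference, both helpers should agree for simple ops, e.g. (values taken
# from the sweeps in TestUOps below):
#   _test_single_value([2.0, 3.0], BinaryOps.ADD, dtypes.float32)       -> 5.0
#   _test_single_value_const([2.0, 3.0], BinaryOps.ADD, dtypes.float32) -> 5.0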

class TestUOps(unittest.TestCase):
  def _equal(self, v1, v2):
    if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)

  def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 0.0, 1.0, 2.0]:
        self._equal(f([a], bop, dt), fxn(a))

  def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 0.0, 1.0, 2.0]:
        for b in [-3.0, 1.0, 3.0] + ([] if no_b_zero else [0.0]):
          self._equal(f([a,b], bop, dt), fxn(a,b))

  def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 0, 1, 2.0]:
        for b in [-3.0, 3.0]:
          for c in [-4.0, 4.0]:
            self._equal(f([a,b,c], bop, dt), fxn(a,b,c))

@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestFloatUOps(TestUOps):
  def test_neg(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a)
  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
  # this is not on most backends
  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))

  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
  # MOD isn't tested on floats

  def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c)
  def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c)

# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
  def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, dtypes.int32)
  def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
  def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32)
  def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
  def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
  def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
  def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
  def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)
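
# To run these tests directly (assuming this file lives at test/test_uops.py in
# the repo), something like:
#   python test/test_uops.py
# or, to exercise the Triton path selected in _uops_to_prg:
#   TRITON=1 python test/test_uops.py TestFloatUOps.test_add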

if __name__ == '__main__':
  unittest.main(verbosity=2)