mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
This commit is contained in:
committed by
GitHub
parent
f866b2a513
commit
2e72625652
@@ -18,7 +18,7 @@ settings.register_profile("my_profile", max_examples=200, deadline=None, derando
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
|
||||
if not is_dtype_supported(dtype) and dtype not in (dtypes.long, dtypes.ulong): return []
|
||||
if not is_dtype_supported(dtype): return []
|
||||
# don't cast internal dtypes
|
||||
return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
|
||||
|
||||
@@ -333,14 +333,8 @@ class TestUint16DType(TestDType):
|
||||
class TestInt32DType(TestDType): DTYPE = dtypes.int32
|
||||
class TestUint32DType(TestDType): DTYPE = dtypes.uint32
|
||||
|
||||
class TestInt64DType(TestDType):
|
||||
DTYPE = dtypes.int64
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
class TestInt64DType(TestDType): DTYPE = dtypes.int64
|
||||
class TestUint64DType(TestDType):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
DTYPE = dtypes.uint64
|
||||
def test_uint64_load(self):
|
||||
assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1
|
||||
|
||||
@@ -165,6 +165,7 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
|
||||
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
|
||||
|
||||
@@ -177,6 +178,7 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
|
||||
def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
|
||||
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
|
||||
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
|
||||
|
||||
@@ -191,6 +193,7 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.uint32, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint32_unary(self, a, op): universal_test_unary(a, dtypes.uint32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
|
||||
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
|
||||
|
||||
@@ -203,6 +206,7 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.int32, strat.sampled_from(integer_unary_operations))
|
||||
def test_int32_unary(self, a, op): universal_test_unary(a, dtypes.int32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
|
||||
@given(ht.int64, strat.sampled_from(integer_unary_operations))
|
||||
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from tinygrad import Tensor, dtypes, nn
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.device import Device, is_dtype_supported
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
||||
@@ -207,7 +207,8 @@ class TestUOpValidationIssue(unittest.TestCase):
|
||||
# these fail with UOp verification error.
|
||||
# we want more of these with diverse errors!
|
||||
|
||||
@unittest.skipIf(MOCKGPU or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer), "hangs gpuocelot, NIR cannot render")
|
||||
@unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer),
|
||||
"hangs gpuocelot, NIR cannot render")
|
||||
def test_tensor_index_overflow(self):
|
||||
val = Tensor([1])
|
||||
big = val.expand(2**31 + 3)
|
||||
|
||||
@@ -339,7 +339,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(output, input_list)
|
||||
'''
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long indexing: #13624")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"long dtype not supported on {Device.DEFAULT}")
|
||||
def test_index_ind_dtype(self):
|
||||
x = Tensor.randn(4, 4)
|
||||
# ind_long = torch.randint(4, (4,), dtype=torch.long)
|
||||
|
||||
@@ -95,7 +95,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
|
||||
# decompositions
|
||||
supported_ops = tuple(ren.code_for_op.keys())
|
||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, TRANSCENDENTAL>=2)
|
||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
|
||||
@@ -2,8 +2,7 @@ from typing import Callable
|
||||
import math, functools
|
||||
from tinygrad.dtype import dtypes, DType, promo_lattice
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import flatten, polyN, DISABLE_FAST_IDIV
|
||||
from tinygrad.uop import GroupOp
|
||||
from tinygrad.helpers import polyN, DISABLE_FAST_IDIV
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher
|
||||
|
||||
TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
@@ -315,70 +314,11 @@ def threefry2x32(x: UOp, key: UOp):
|
||||
|
||||
return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
|
||||
|
||||
# ***** long as 2 ints *****
|
||||
|
||||
l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
|
||||
def unpack32(v):
  # Split a 32-bit value into its (low 16 bits, high 16 bits), both computed on the unsigned reinterpretation of v.
  as_uint = v.bitcast(dtypes.uint)
  low16 = as_uint & 0xFFFF
  high16 = as_uint >> 16
  return low16, high16
|
||||
def l2i_idx(idx,off):
  # Rebuild idx so its second source is doubled and shifted by off: element i of a 64-bit view maps to
  # slot 2*i+off of the 32-bit view. Only the first two sources of idx are carried over.
  base, position = idx.src[0], idx.src[1]
  half_position = position*2+off
  return idx.replace(src=(base, half_position))
|
||||
|
||||
# Section 4.3.1 of TAOCP (Knuth, The Art of Computer Programming, Vol. 2: multiple-precision arithmetic) is the relevant reference
|
||||
def l2i(op: Ops, dt: DType, *uops:UOp):
|
||||
zero = UOp.const(dt, 0)
|
||||
if len(uops) == 2: a0, a1 = uops
|
||||
elif len(uops) == 4: a0, a1, b0, b1 = uops
|
||||
match op:
|
||||
case Ops.NEG: return l2i(Ops.SUB, dt, zero, zero, *uops)
|
||||
case Ops.CAST if dt in (dtypes.long, dtypes.ulong) and uops[0].dtype not in dtypes.floats:
|
||||
return uops[0].cast(l2i_dt[dt]), (uops[0] < 0).where(UOp.const(l2i_dt[dt], -1), UOp.const(l2i_dt[dt], 0))
|
||||
case Ops.CAST if dt in (dtypes.long, dtypes.ulong):
|
||||
return (lo:=uops[0].cast(l2i_dt[dt])), (uops[0] / 2**32).cast(l2i_dt[dt]) - ((uops[0] < 0) & lo.ne(0)).cast(l2i_dt[dt])
|
||||
case Ops.CAST if dt in dtypes.floats:
|
||||
small = (a1.eq(0) & (a0 >= 0)) | (a1.eq(-1) & (a0 < 0))
|
||||
return small.where(a0.cast(dt), ((a1.cast(dtypes.float32) * (2**32)) + a0.bitcast(dtypes.uint).cast(dtypes.float32)).cast(dt))
|
||||
case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt)
|
||||
case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt)
|
||||
case Ops.SHL:
|
||||
lo, hi = a0 << (b0_mod:=b0 & 31), (a1 << b0_mod) | ((a0 >> 1) >> (31 - b0_mod))
|
||||
return (b0 >= 32).where(zero, lo), (b0 >= 32).where(lo, hi)
|
||||
case Ops.SHR:
|
||||
lo, hi = (a0 >> (b0_mod:=b0 & 31)) | ((a1 << 1) << (31 - b0_mod)), a1 >> b0_mod
|
||||
return (b0 >= 32).where(hi, lo), (b0 >= 32).where(zero, hi)
|
||||
case Ops.ADD: return (low:=a0+b0), (a1 + b1).replace(dtype=dt) + (low.bitcast(dtypes.uint) < a0.bitcast(dtypes.uint)).cast(dt)
|
||||
case Ops.SUB: return a0 - b0, a1 - b1 - (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)).cast(dt)
|
||||
case Ops.MUL:
|
||||
(a00, a01), (b00, b01) = unpack32(a0), unpack32(b0)
|
||||
mid = l2i(Ops.ADD, dt, ((a00*b01)<<16).bitcast(dt), ((a00*b01)>>16).bitcast(dt), ((a01*b00)<<16).bitcast(dt), ((a01*b00)>>16).bitcast(dt))
|
||||
return l2i(Ops.ADD, dt, *mid, (a00*b00).bitcast(dt), (a01*b01).bitcast(dt) + a0*b1 + a1*b0)
|
||||
case Ops.IDIV | Ops.MOD:
|
||||
# TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
|
||||
if dt == dtypes.int:
|
||||
a0, a1 = (a_neg:=a1 < zero).where((n:=l2i(Ops.NEG, dt, a0, a1))[0], a0).bitcast(dtypes.uint), a_neg.where(n[1], a1).bitcast(dtypes.uint)
|
||||
b0, b1 = (b_neg:=b1 < zero).where((n:=l2i(Ops.NEG, dt, b0, b1))[0], b0).bitcast(dtypes.uint), b_neg.where(n[1], b1).bitcast(dtypes.uint)
|
||||
q, r = (z:=UOp.const(dtypes.uint, 0), z), (z, z)
|
||||
for i in range(63, -1, -1):
|
||||
r = l2i(Ops.SHL, dtypes.uint, *r, UOp.const(dtypes.uint, 1), z)
|
||||
r = (r[0] | l2i(Ops.SHR, dtypes.uint, a0, a1, UOp.const(dtypes.uint, i), z)[0] & 1), r[1]
|
||||
cond = l2i(Ops.CMPLT, dtypes.uint, *r, b0, b1).logical_not()
|
||||
diff = l2i(Ops.SUB, dtypes.uint, *r, b0, b1)
|
||||
q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32)))
|
||||
r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r)
|
||||
if dt == dtypes.int:
|
||||
nq, nr = l2i(Ops.NEG, dt, q0:=q[0].bitcast(dt), q1:=q[1].bitcast(dt)), l2i(Ops.NEG, dt, r0:=r[0].bitcast(dt), r1:=r[1].bitcast(dt))
|
||||
return (a_neg.where(nr[0], r0), a_neg.where(nr[1], r1)) if op == Ops.MOD else ((a_neg^b_neg).where(nq[0], q0), (a_neg^b_neg).where(nq[1], q1))
|
||||
return (r[0].bitcast(dt), r[1].bitcast(dt)) if op == Ops.MOD else (q[0].bitcast(dt), q[1].bitcast(dt))
|
||||
case Ops.CMPLT: return (a1 < b1) | ((a1.eq(b1)) & (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)))
|
||||
case Ops.CMPEQ: return a0.eq(b0) & a1.eq(b1)
|
||||
case Ops.CMPNE: return a0.ne(b0) | a1.ne(b1)
|
||||
case Ops.XOR | Ops.OR | Ops.AND: return UOp(op, dt, src=(a0, b0)), UOp(op, dt, src=(a1, b1))
|
||||
case Ops.WHERE: return uops[0].where(uops[1], uops[3]), uops[0].where(uops[2], uops[4])
|
||||
case Ops.MAX: return l2i(Ops.WHERE, dt, l2i(Ops.CMPLT, dt, *uops), b0, b1, a0, a1)
|
||||
case _: raise NotImplementedError(f"long decomposition of {op} unsupported")
|
||||
|
||||
# ***** decomposition patterns *****
|
||||
|
||||
powers_of_two = {2**i:i for i in range(64)}
|
||||
@functools.cache
|
||||
def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental):
|
||||
def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental):
|
||||
pat: list[tuple[UPat, Callable]] = []
|
||||
for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
|
||||
if op not in ops or force_transcendental:
|
||||
@@ -406,8 +346,8 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
|
||||
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
|
||||
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
|
||||
if Ops.NEG in ops:
|
||||
pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))]
|
||||
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))]
|
||||
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
|
||||
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
|
||||
if Ops.CMPLT in ops:
|
||||
# These are late rewrites because simplex expects equalities to be in a certain format
|
||||
pat += [
|
||||
@@ -424,22 +364,4 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
|
||||
if Ops.FDIV in ops:
|
||||
pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))]
|
||||
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
|
||||
if not is_dtype_supported(dtypes.long, device):
|
||||
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if x.dtype.base in l2i_dt else None)]
|
||||
pat += [(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
|
||||
st.replace(src=(l2i_idx(idx, 0), val.rtag(0))).group(st.replace(src=(l2i_idx(idx, 1), val.rtag(1)))) if val.tag is None else None)]
|
||||
pat += [(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
|
||||
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt)))]
|
||||
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None else None)]
|
||||
pat += [(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)))]
|
||||
pat += [(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
|
||||
None if x.tag is None else l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag])]
|
||||
pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
|
||||
None if x.tag is None else x.replace(dtype=l2i_dt[x.dtype], src=(l2i_idx(idx, x.tag),)))]
|
||||
pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
None if x.tag is None else UOp.const(l2i_dt[x.dtype], (x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF)))]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
Reference in New Issue
Block a user