Revert "decompose dtypes.long to ints where unsupported (#14261)" (#14362)

This commit is contained in:
Christopher Milan
2026-01-26 23:04:59 -08:00
committed by GitHub
parent f866b2a513
commit 2e72625652
6 changed files with 15 additions and 94 deletions

View File

@@ -18,7 +18,7 @@ settings.register_profile("my_profile", max_examples=200, deadline=None, derando
settings.load_profile("my_profile")
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
if not is_dtype_supported(dtype) and dtype not in (dtypes.long, dtypes.ulong): return []
if not is_dtype_supported(dtype): return []
# dont cast internal dtypes
return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
@@ -333,14 +333,8 @@ class TestUint16DType(TestDType):
class TestInt32DType(TestDType): DTYPE = dtypes.int32
class TestUint32DType(TestDType): DTYPE = dtypes.uint32
class TestInt64DType(TestDType):
DTYPE = dtypes.int64
@classmethod
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
class TestInt64DType(TestDType): DTYPE = dtypes.int64
class TestUint64DType(TestDType):
@classmethod
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
DTYPE = dtypes.uint64
def test_uint64_load(self):
assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1

View File

@@ -165,6 +165,7 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
@@ -177,6 +178,7 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
@@ -191,6 +193,7 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint32, strat.sampled_from(integer_unary_operations))
def test_uint32_unary(self, a, op): universal_test_unary(a, dtypes.uint32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
@@ -203,6 +206,7 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(integer_unary_operations))
def test_int32_unary(self, a, op): universal_test_unary(a, dtypes.int32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
@given(ht.int64, strat.sampled_from(integer_unary_operations))
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)

View File

@@ -26,7 +26,7 @@ import unittest
import numpy as np
import torch
from tinygrad import Tensor, dtypes, nn
from tinygrad.device import Device
from tinygrad.device import Device, is_dtype_supported
from tinygrad.helpers import getenv
from tinygrad.renderer.nir import NIRRenderer
@@ -207,7 +207,8 @@ class TestUOpValidationIssue(unittest.TestCase):
# these fail with UOp verification error.
# we want more of these with diverse errors!
@unittest.skipIf(MOCKGPU or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer), "hangs gpuocelot, NIR cannot render")
@unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer),
"hangs gpuocelot, NIR cannot render")
def test_tensor_index_overflow(self):
val = Tensor([1])
big = val.expand(2**31 + 3)

View File

@@ -339,7 +339,7 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(output, input_list)
'''
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long indexing: #13624")
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"long dtype not supported on {Device.DEFAULT}")
def test_index_ind_dtype(self):
x = Tensor.randn(4, 4)
# ind_long = torch.randint(4, (4,), dtype=torch.long)

View File

@@ -95,7 +95,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, TRANSCENDENTAL>=2)
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
# final rules for the renderer (without sym)

View File

@@ -2,8 +2,7 @@ from typing import Callable
import math, functools
from tinygrad.dtype import dtypes, DType, promo_lattice
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import flatten, polyN, DISABLE_FAST_IDIV
from tinygrad.uop import GroupOp
from tinygrad.helpers import polyN, DISABLE_FAST_IDIV
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher
TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
@@ -315,70 +314,11 @@ def threefry2x32(x: UOp, key: UOp):
return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
# ***** long as 2 ints *****
l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
def unpack32(v):
  """Split `v` into two 16-bit pieces after reinterpreting it as `dtypes.uint`.

  Returns a `(low, high)` tuple: `low` is the bitcast value masked with 0xFFFF,
  `high` is the bitcast value shifted right by 16.
  """
  low_half = v.bitcast(dtypes.uint) & 0xFFFF
  high_half = v.bitcast(dtypes.uint) >> 16
  return low_half, high_half
def l2i_idx(idx, off):
  """Return a copy of `idx` whose second source is rescaled to address one half of a pair.

  The first source is kept as is; the second source becomes `second*2 + off`,
  so `off` selects which of the two consecutive slots is addressed.
  """
  first, second = idx.src[0], idx.src[1]
  doubled = second*2
  return idx.replace(src=(first, doubled+off))
# 4.3.1 is the relevant section in TAOCP
# l2i rewrites one operation on a 64-bit integer as operations on a (low, high) pair of 32-bit halves.
#   op:   the operation being decomposed.
#   dt:   for Ops.CAST, the destination dtype of the cast; for the other ops the callers visible in this
#         file pass the 32-bit half dtype (l2i_dt[...]).
#   uops: the operands, already split into halves: (a0, a1) for unary ops, (a0, a1, b0, b1) for binary ops,
#         (cond, a0, a1, b0, b1) for Ops.WHERE, and the single unsplit source for a CAST to long/ulong.
# Returns a (low, high) tuple of UOps, except comparisons and casts to a non-long dtype, which return one UOp.
# Raises NotImplementedError for any op without a case below.
def l2i(op: Ops, dt: DType, *uops:UOp):
zero = UOp.const(dt, 0)
# name the halves; with any other operand count (e.g. 5 for WHERE) the cases index uops directly
if len(uops) == 2: a0, a1 = uops
elif len(uops) == 4: a0, a1, b0, b1 = uops
match op:
# negation is computed as 0 - x
case Ops.NEG: return l2i(Ops.SUB, dt, zero, zero, *uops)
# non-float -> long/ulong: low half is the cast value, high half is -1 when the source is negative, else 0
case Ops.CAST if dt in (dtypes.long, dtypes.ulong) and uops[0].dtype not in dtypes.floats:
return uops[0].cast(l2i_dt[dt]), (uops[0] < 0).where(UOp.const(l2i_dt[dt], -1), UOp.const(l2i_dt[dt], 0))
# float -> long/ulong: low half is the cast value, high half is (x / 2**32) cast, minus 1 when x < 0 and the low half is nonzero
case Ops.CAST if dt in (dtypes.long, dtypes.ulong):
return (lo:=uops[0].cast(l2i_dt[dt])), (uops[0] / 2**32).cast(l2i_dt[dt]) - ((uops[0] < 0) & lo.ne(0)).cast(l2i_dt[dt])
# pair -> float: when the high half is only the sign extension of the low half, cast the low half directly;
# otherwise combine high * 2**32 + unsigned(low), computed in float32, then cast to dt
case Ops.CAST if dt in dtypes.floats:
small = (a1.eq(0) & (a0 >= 0)) | (a1.eq(-1) & (a0 < 0))
return small.where(a0.cast(dt), ((a1.cast(dtypes.float32) * (2**32)) + a0.bitcast(dtypes.uint).cast(dtypes.float32)).cast(dt))
# pair -> any other dtype: only the low half is used
case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt)
case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt)
# shifts use only the low half of the shift amount (b0); b0 & 31 is the shift within a half.
# (a0 >> 1) >> (31 - m) equals a0 >> (32 - m) but keeps each individual shift amount below 32 when m == 0.
# when b0 >= 32 the result halves are moved across: SHL puts the shifted low half in the high slot and zeros the low slot
case Ops.SHL:
lo, hi = a0 << (b0_mod:=b0 & 31), (a1 << b0_mod) | ((a0 >> 1) >> (31 - b0_mod))
return (b0 >= 32).where(zero, lo), (b0 >= 32).where(lo, hi)
# SHR mirrors SHL: bits leaving the high half enter the low half; for b0 >= 32 the shifted high half becomes the low half
case Ops.SHR:
lo, hi = (a0 >> (b0_mod:=b0 & 31)) | ((a1 << 1) << (31 - b0_mod)), a1 >> b0_mod
return (b0 >= 32).where(hi, lo), (b0 >= 32).where(zero, hi)
# carry into the high half is detected as: unsigned(low sum) < unsigned(a0)
case Ops.ADD: return (low:=a0+b0), (a1 + b1).replace(dtype=dt) + (low.bitcast(dtypes.uint) < a0.bitcast(dtypes.uint)).cast(dt)
# borrow from the high half is detected as: unsigned(a0) < unsigned(b0)
case Ops.SUB: return a0 - b0, a1 - b1 - (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)).cast(dt)
# the low halves are split into 16-bit pieces with unpack32; the two cross products are placed at bit 16
# (as (x << 16, x >> 16) pairs) and summed, then added to (a00*b00, a01*b01 + a0*b1 + a1*b0)
case Ops.MUL:
(a00, a01), (b00, b01) = unpack32(a0), unpack32(b0)
mid = l2i(Ops.ADD, dt, ((a00*b01)<<16).bitcast(dt), ((a00*b01)>>16).bitcast(dt), ((a01*b00)<<16).bitcast(dt), ((a01*b00)>>16).bitcast(dt))
return l2i(Ops.ADD, dt, *mid, (a00*b00).bitcast(dt), (a01*b01).bitcast(dt) + a0*b1 + a1*b0)
case Ops.IDIV | Ops.MOD:
# TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
# signed case: replace each operand by its negation when its high half is negative, then work on uint halves
if dt == dtypes.int:
a0, a1 = (a_neg:=a1 < zero).where((n:=l2i(Ops.NEG, dt, a0, a1))[0], a0).bitcast(dtypes.uint), a_neg.where(n[1], a1).bitcast(dtypes.uint)
b0, b1 = (b_neg:=b1 < zero).where((n:=l2i(Ops.NEG, dt, b0, b1))[0], b0).bitcast(dtypes.uint), b_neg.where(n[1], b1).bitcast(dtypes.uint)
# quotient q and remainder r both start as the zero pair
q, r = (z:=UOp.const(dtypes.uint, 0), z), (z, z)
# one iteration per bit, from bit 63 down to bit 0 of the dividend
for i in range(63, -1, -1):
# r = (r << 1) | bit i of a
r = l2i(Ops.SHL, dtypes.uint, *r, UOp.const(dtypes.uint, 1), z)
r = (r[0] | l2i(Ops.SHR, dtypes.uint, a0, a1, UOp.const(dtypes.uint, i), z)[0] & 1), r[1]
# cond is true when r >= b; in that case bit i of q is set and b is subtracted from r
cond = l2i(Ops.CMPLT, dtypes.uint, *r, b0, b1).logical_not()
diff = l2i(Ops.SUB, dtypes.uint, *r, b0, b1)
q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32)))
r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r)
# signed case: the remainder is negated when the dividend was negative, the quotient when the operand signs differ
if dt == dtypes.int:
nq, nr = l2i(Ops.NEG, dt, q0:=q[0].bitcast(dt), q1:=q[1].bitcast(dt)), l2i(Ops.NEG, dt, r0:=r[0].bitcast(dt), r1:=r[1].bitcast(dt))
return (a_neg.where(nr[0], r0), a_neg.where(nr[1], r1)) if op == Ops.MOD else ((a_neg^b_neg).where(nq[0], q0), (a_neg^b_neg).where(nq[1], q1))
return (r[0].bitcast(dt), r[1].bitcast(dt)) if op == Ops.MOD else (q[0].bitcast(dt), q[1].bitcast(dt))
# compare high halves first; when they are equal, compare the low halves as unsigned
case Ops.CMPLT: return (a1 < b1) | ((a1.eq(b1)) & (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)))
case Ops.CMPEQ: return a0.eq(b0) & a1.eq(b1)
case Ops.CMPNE: return a0.ne(b0) | a1.ne(b1)
# bitwise ops apply to each half independently
case Ops.XOR | Ops.OR | Ops.AND: return UOp(op, dt, src=(a0, b0)), UOp(op, dt, src=(a1, b1))
# uops is (cond, a0, a1, b0, b1): select the a pair where cond holds, else the b pair
case Ops.WHERE: return uops[0].where(uops[1], uops[3]), uops[0].where(uops[2], uops[4])
# max(a, b) = b where a < b, else a
case Ops.MAX: return l2i(Ops.WHERE, dt, l2i(Ops.CMPLT, dt, *uops), b0, b1, a0, a1)
case _: raise NotImplementedError(f"long decomposition of {op} unsupported")
# ***** decomposition patterns *****
powers_of_two = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental):
def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental):
pat: list[tuple[UPat, Callable]] = []
for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
if op not in ops or force_transcendental:
@@ -406,8 +346,8 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))]
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
if Ops.CMPLT in ops:
# These are late rewrites because simplex expects equalities to be a certain format
pat += [
@@ -424,22 +364,4 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
if Ops.FDIV in ops:
pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))]
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
if not is_dtype_supported(dtypes.long, device):
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if x.dtype.base in l2i_dt else None)]
pat += [(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
st.replace(src=(l2i_idx(idx, 0), val.rtag(0))).group(st.replace(src=(l2i_idx(idx, 1), val.rtag(1)))) if val.tag is None else None)]
pat += [(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt)))]
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None else None)]
pat += [(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)))]
pat += [(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
None if x.tag is None else l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag])]
pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
None if x.tag is None else x.replace(dtype=l2i_dt[x.dtype], src=(l2i_idx(idx, x.tag),)))]
pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
None if x.tag is None else UOp.const(l2i_dt[x.dtype], (x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF)))]
return PatternMatcher(pat)