dtype decomps don't require bitshifts (#14542)

* dtype decomps don't require bitshifts

* simplify shr/shl

* ruff
This commit is contained in:
Christopher Milan
2026-02-05 11:42:30 -08:00
committed by GitHub
parent b47397ab17
commit aa9dc50577
3 changed files with 7 additions and 16 deletions

View File

@@ -8,7 +8,6 @@ from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float,
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from tinygrad import Context, Device, Tensor, dtypes
from tinygrad.uop import Ops
from hypothesis import given, settings, strategies as strat
from test.helpers import rand_for_dtype
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
@@ -247,7 +246,6 @@ class TestBFloat16DTypeCast(unittest.TestCase):
class TestHalfDType(TestDType): DTYPE = dtypes.half
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift")
class TestEmulatedHalf(TestHalfDType):
@classmethod
def setUpClass(cls):
@@ -351,7 +349,6 @@ class TestUint32DType(TestDType): DTYPE = dtypes.uint32
class TestInt64DType(TestDType): DTYPE = dtypes.int64
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
class TestEmulatedInt64DType(TestInt64DType):
@classmethod
@@ -368,7 +365,6 @@ class TestUint64DType(TestDType):
def test_uint64_load(self):
assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
class TestEmulatedUInt64DType(TestUint64DType):
@classmethod

View File

@@ -7,7 +7,6 @@ from tinygrad.device import is_dtype_supported
from tinygrad.runtime.ops_python import from_storage_scalar
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from tinygrad.uop import Ops
import numpy as np
import pytest
from hypothesis import assume, given, strategies as strat, settings
@@ -181,7 +180,6 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
@Context(EMULATED_DTYPES="long")
@@ -200,7 +198,6 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
@Context(EMULATED_DTYPES="long")
@@ -221,7 +218,6 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
@Context(EMULATED_DTYPES="long")
@@ -240,7 +236,6 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int64, strat.sampled_from(integer_unary_operations))
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.int64, strat.sampled_from(integer_unary_operations))
@Context(EMULATED_DTYPES="long")

View File

@@ -18,8 +18,8 @@ def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32:
def exponent_mask(d:DType) -> int:
  """Look up the exponent mask constant for the scalar type of d.

  Only float64, float32 and float16 have entries; any other scalar type raises KeyError.
  """
  # each value is an all-ones pattern: 2047 == 2**11-1, 255 == 2**8-1, 31 == 2**5-1
  masks = {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}
  return masks[d.scalar()]
# **** utils ****
def shr(x:UOp|int, y:int) -> UOp: return x // (2**y)
def shl(x:UOp|int, y:int) -> UOp: return x * (2**y)
def shr(x:UOp|int, y:UOp|int) -> UOp:
  """Right-shift stand-in that avoids a shift op: returns x // 2**y.

  y may be a plain int, or a UOp whose simplified form carries the shift amount in .arg.
  """
  # NOTE(review): presumably a UOp y always simplifies to a constant here; if it does not,
  # .arg would not be a usable exponent -- confirm against the callers.
  amount = y.simplify().arg if isinstance(y, UOp) else y
  return x // (2**amount)
def shl(x:UOp|int, y:UOp|int) -> UOp:
  """Left-shift stand-in that avoids a shift op: returns x * 2**y.

  y may be a plain int, or a UOp whose simplified form carries the shift amount in .arg.
  """
  # NOTE(review): presumably a UOp y always simplifies to a constant here; if it does not,
  # .arg would not be a usable exponent -- confirm against the callers.
  amount = y.simplify().arg if isinstance(y, UOp) else y
  return x * (2**amount)
def rintk(d:UOp) -> UOp:
"""round d:float to int away from 0"""
@@ -318,7 +318,7 @@ def threefry2x32(x: UOp, key: UOp):
# ***** long as 2 ints *****
l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, v.bitcast(dtypes.uint) >> 16
def unpack32(v:UOp) -> tuple[UOp, UOp]:
  """Split v, bitcast to uint, into a (low, high) pair.

  low keeps the bottom 16 bits (mask 0xFFFF); high is the value moved down by 16 bits via shr.
  """
  low = v.bitcast(dtypes.uint) & 0xFFFF
  high = shr(v.bitcast(dtypes.uint), 16)
  return low, high
def reindex(idx:UOp, off:int, mul=2) -> UOp:
  """Rebuild idx via idx.replace with its second source scaled by mul and offset by off.

  The first source is passed through unchanged; the new src tuple has exactly these two entries.
  """
  base, index = idx.src[0], idx.src[1]
  return idx.replace(src=(base, index*mul+off))
# 4.3.1 is the relevant section in TAOCP
@@ -339,16 +339,16 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt)
case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt)
case Ops.SHL:
lo, hi = a0 << (b0_mod:=b0 & 31), (a1 << b0_mod) | ((a0 >> 1) >> (31 - b0_mod))
lo, hi = shl(a0, b0_mod:=b0 & 31), shl(a1, b0_mod) | shr(shr(a0, 1), 31 - b0_mod)
return (b0 >= 32).where(zero, lo), (b0 >= 32).where(lo, hi)
case Ops.SHR:
lo, hi = (a0 >> (b0_mod:=b0 & 31)) | ((a1 << 1) << (31 - b0_mod)), a1 >> b0_mod
lo, hi = shr(a0, b0_mod:=b0 & 31) | shl(shl(a1, 1), 31 - b0_mod), shr(a1, b0_mod)
return (b0 >= 32).where(hi, lo), (b0 >= 32).where(zero, hi)
case Ops.ADD: return (low:=a0+b0), (a1 + b1).replace(dtype=dt) + (low.bitcast(dtypes.uint) < a0.bitcast(dtypes.uint)).cast(dt)
case Ops.SUB: return a0 - b0, a1 - b1 - (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)).cast(dt)
case Ops.MUL:
(a00, a01), (b00, b01) = unpack32(a0), unpack32(b0)
mid = l2i(Ops.ADD, dt, ((a00*b01)<<16).bitcast(dt), ((a00*b01)>>16).bitcast(dt), ((a01*b00)<<16).bitcast(dt), ((a01*b00)>>16).bitcast(dt))
mid = l2i(Ops.ADD, dt, shl(a00*b01, 16).bitcast(dt), shr(a00*b01, 16).bitcast(dt), shl(a01*b00, 16).bitcast(dt), shr(a01*b00, 16).bitcast(dt))
return l2i(Ops.ADD, dt, *mid, (a00*b00).bitcast(dt), (a01*b01).bitcast(dt) + a0*b1 + a1*b0)
case Ops.IDIV | Ops.MOD:
# TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
@@ -362,7 +362,7 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
r = (r[0] | l2i(Ops.SHR, dtypes.uint, a0, a1, UOp.const(dtypes.uint, i), z)[0] & 1), r[1]
cond = l2i(Ops.CMPLT, dtypes.uint, *r, b0, b1).logical_not()
diff = l2i(Ops.SUB, dtypes.uint, *r, b0, b1)
q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32)))
q = ((q[0] | shl(cond.cast(dtypes.uint), i % 32), q[1]) if i < 32 else (q[0], q[1] | shl(cond.cast(dtypes.uint), i % 32)))
r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r)
if dt == dtypes.int:
(nq0, nq1), (nr0, nr1) = l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *q)), l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *r))