From aa9dc5057736b426b05fe9136f69a9ccb4967f61 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Thu, 5 Feb 2026 11:42:30 -0800 Subject: [PATCH] dtype decomps don't require bitshifts (#14542) * dtype decomps don't require bitshifts * simplify shr/shl * ruff --- test/test_dtype.py | 4 ---- test/test_dtype_alu.py | 5 ----- tinygrad/uop/decompositions.py | 14 +++++++------- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 803b65cdb6..6f5022a7a3 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -8,7 +8,6 @@ from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.nir import NIRRenderer from tinygrad import Context, Device, Tensor, dtypes -from tinygrad.uop import Ops from hypothesis import given, settings, strategies as strat from test.helpers import rand_for_dtype from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX @@ -247,7 +246,6 @@ class TestBFloat16DTypeCast(unittest.TestCase): class TestHalfDType(TestDType): DTYPE = dtypes.half -@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift") class TestEmulatedHalf(TestHalfDType): @classmethod def setUpClass(cls): @@ -351,7 +349,6 @@ class TestUint32DType(TestDType): DTYPE = dtypes.uint32 class TestInt64DType(TestDType): DTYPE = dtypes.int64 -@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs") class TestEmulatedInt64DType(TestInt64DType): @classmethod @@ -368,7 +365,6 @@ class TestUint64DType(TestDType): def test_uint64_load(self): assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1 -@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs") class TestEmulatedUInt64DType(TestUint64DType): @classmethod diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index b9d0b987fc..813948ccdd 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -7,7 +7,6 @@ from tinygrad.device import is_dtype_supported from tinygrad.runtime.ops_python import from_storage_scalar from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.nir import NIRRenderer -from tinygrad.uop import Ops import numpy as np import pytest from hypothesis import assume, given, strategies as strat, settings @@ -181,7 +180,6 @@ class TestDTypeALU(unittest.TestCase): @given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations)) def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op) - @unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs") @given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations)) @Context(EMULATED_DTYPES="long") @@ -200,7 +198,6 @@ class TestDTypeALU(unittest.TestCase): @given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations)) def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op) - @unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs") @given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations)) @Context(EMULATED_DTYPES="long") @@ -221,7 +218,6 @@ class TestDTypeALU(unittest.TestCase): @given(ht.uint64, strat.sampled_from(integer_unary_operations)) def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op) - @unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs") @given(ht.uint64, strat.sampled_from(integer_unary_operations)) @Context(EMULATED_DTYPES="long") @@ -240,7 +236,6 @@ class TestDTypeALU(unittest.TestCase): @given(ht.int64, strat.sampled_from(integer_unary_operations)) def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op) - @unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs") @given(ht.int64, strat.sampled_from(integer_unary_operations)) @Context(EMULATED_DTYPES="long") diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index aa0677eaac..9629c9eae4 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -18,8 +18,8 @@ def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()] # **** utils **** -def shr(x:UOp|int, y:int) -> UOp: return x // (2**y) -def shl(x:UOp|int, y:int) -> UOp: return x * (2**y) +def shr(x:UOp|int, y:UOp|int) -> UOp: return x // (2**(y.simplify().arg) if isinstance(y, UOp) else 2**y) +def shl(x:UOp|int, y:UOp|int) -> UOp: return x * (2**(y.simplify().arg) if isinstance(y, UOp) else 2**y) def rintk(d:UOp) -> UOp: """round d:float to int away from 0""" @@ -318,7 +318,7 @@ def threefry2x32(x: UOp, key: UOp): # ***** long as 2 ints ***** l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint} -def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, v.bitcast(dtypes.uint) >> 16 +def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, shr(v.bitcast(dtypes.uint), 16) def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off)) # 4.3.1 is the relevant section in TAOCP @@ -339,16 +339,16 @@ def l2i(op: Ops, dt: DType, *uops:UOp): case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt) case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt) case Ops.SHL: - lo, hi = a0 << (b0_mod:=b0 & 31), (a1 << b0_mod) | ((a0 >> 1) >> (31 - b0_mod)) + lo, hi = shl(a0, b0_mod:=b0 & 31), shl(a1, b0_mod) | shr(shr(a0, 1), 31 - b0_mod) return (b0 >= 32).where(zero, lo), (b0 >= 32).where(lo, hi) case Ops.SHR: - lo, hi = (a0 >> (b0_mod:=b0 & 31)) | ((a1 << 1) << (31 - b0_mod)), a1 >> b0_mod + lo, hi = shr(a0, b0_mod:=b0 & 31) | shl(shl(a1, 1), 31 - b0_mod), shr(a1, b0_mod) return (b0 >= 32).where(hi, lo), (b0 >= 32).where(zero, hi) case Ops.ADD: return (low:=a0+b0), (a1 + b1).replace(dtype=dt) + (low.bitcast(dtypes.uint) < a0.bitcast(dtypes.uint)).cast(dt) case Ops.SUB: return a0 - b0, a1 - b1 - (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)).cast(dt) case Ops.MUL: (a00, a01), (b00, b01) = unpack32(a0), unpack32(b0) - mid = l2i(Ops.ADD, dt, ((a00*b01)<<16).bitcast(dt), ((a00*b01)>>16).bitcast(dt), ((a01*b00)<<16).bitcast(dt), ((a01*b00)>>16).bitcast(dt)) + mid = l2i(Ops.ADD, dt, shl(a00*b01, 16).bitcast(dt), shr(a00*b01, 16).bitcast(dt), shl(a01*b00, 16).bitcast(dt), shr(a01*b00, 16).bitcast(dt)) return l2i(Ops.ADD, dt, *mid, (a00*b00).bitcast(dt), (a01*b01).bitcast(dt) + a0*b1 + a1*b0) case Ops.IDIV | Ops.MOD: # TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b @@ -362,7 +362,7 @@ def l2i(op: Ops, dt: DType, *uops:UOp): r = (r[0] | l2i(Ops.SHR, dtypes.uint, a0, a1, UOp.const(dtypes.uint, i), z)[0] & 1), r[1] cond = l2i(Ops.CMPLT, dtypes.uint, *r, b0, b1).logical_not() diff = l2i(Ops.SUB, dtypes.uint, *r, b0, b1) - q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32))) + q = ((q[0] | shl(cond.cast(dtypes.uint), i % 32), q[1]) if i < 32 else (q[0], q[1] | shl(cond.cast(dtypes.uint), i % 32))) r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r) if dt == dtypes.int: (nq0, nq1), (nr0, nr1) = l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *q)), l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *r))