mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
dtype decomps don't require bitshifts (#14542)
* dtype decomps don't require bitshifts * simplify shr/shl * ruff
This commit is contained in:
committed by
GitHub
parent
b47397ab17
commit
aa9dc50577
@@ -8,7 +8,6 @@ from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float,
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
from tinygrad import Context, Device, Tensor, dtypes
|
||||
from tinygrad.uop import Ops
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import rand_for_dtype
|
||||
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
|
||||
@@ -247,7 +246,6 @@ class TestBFloat16DTypeCast(unittest.TestCase):
|
||||
|
||||
class TestHalfDType(TestDType): DTYPE = dtypes.half
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "half decomp requires bitshift")
|
||||
class TestEmulatedHalf(TestHalfDType):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -351,7 +349,6 @@ class TestUint32DType(TestDType): DTYPE = dtypes.uint32
|
||||
|
||||
class TestInt64DType(TestDType): DTYPE = dtypes.int64
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
class TestEmulatedInt64DType(TestInt64DType):
|
||||
@classmethod
|
||||
@@ -368,7 +365,6 @@ class TestUint64DType(TestDType):
|
||||
def test_uint64_load(self):
|
||||
assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
class TestEmulatedUInt64DType(TestUint64DType):
|
||||
@classmethod
|
||||
|
||||
@@ -7,7 +7,6 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.runtime.ops_python import from_storage_scalar
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
from tinygrad.uop import Ops
|
||||
import numpy as np
|
||||
import pytest
|
||||
from hypothesis import assume, given, strategies as strat, settings
|
||||
@@ -181,7 +180,6 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
@@ -200,7 +198,6 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
|
||||
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
@@ -221,7 +218,6 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
@@ -240,7 +236,6 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.int64, strat.sampled_from(integer_unary_operations))
|
||||
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.int64, strat.sampled_from(integer_unary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
|
||||
@@ -18,8 +18,8 @@ def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32:
|
||||
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
|
||||
|
||||
# **** utils ****
|
||||
def shr(x:UOp|int, y:int) -> UOp: return x // (2**y)
|
||||
def shl(x:UOp|int, y:int) -> UOp: return x * (2**y)
|
||||
def shr(x:UOp|int, y:UOp|int) -> UOp: return x // (2**(y.simplify().arg) if isinstance(y, UOp) else 2**y)
|
||||
def shl(x:UOp|int, y:UOp|int) -> UOp: return x * (2**(y.simplify().arg) if isinstance(y, UOp) else 2**y)
|
||||
|
||||
def rintk(d:UOp) -> UOp:
|
||||
"""round d:float to int away from 0"""
|
||||
@@ -318,7 +318,7 @@ def threefry2x32(x: UOp, key: UOp):
|
||||
# ***** long as 2 ints *****
|
||||
|
||||
l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
|
||||
def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, v.bitcast(dtypes.uint) >> 16
|
||||
def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, shr(v.bitcast(dtypes.uint), 16)
|
||||
def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off))
|
||||
|
||||
# 4.3.1 is the relevant section in TAOCP
|
||||
@@ -339,16 +339,16 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
|
||||
case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt)
|
||||
case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt)
|
||||
case Ops.SHL:
|
||||
lo, hi = a0 << (b0_mod:=b0 & 31), (a1 << b0_mod) | ((a0 >> 1) >> (31 - b0_mod))
|
||||
lo, hi = shl(a0, b0_mod:=b0 & 31), shl(a1, b0_mod) | shr(shr(a0, 1), 31 - b0_mod)
|
||||
return (b0 >= 32).where(zero, lo), (b0 >= 32).where(lo, hi)
|
||||
case Ops.SHR:
|
||||
lo, hi = (a0 >> (b0_mod:=b0 & 31)) | ((a1 << 1) << (31 - b0_mod)), a1 >> b0_mod
|
||||
lo, hi = shr(a0, b0_mod:=b0 & 31) | shl(shl(a1, 1), 31 - b0_mod), shr(a1, b0_mod)
|
||||
return (b0 >= 32).where(hi, lo), (b0 >= 32).where(zero, hi)
|
||||
case Ops.ADD: return (low:=a0+b0), (a1 + b1).replace(dtype=dt) + (low.bitcast(dtypes.uint) < a0.bitcast(dtypes.uint)).cast(dt)
|
||||
case Ops.SUB: return a0 - b0, a1 - b1 - (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)).cast(dt)
|
||||
case Ops.MUL:
|
||||
(a00, a01), (b00, b01) = unpack32(a0), unpack32(b0)
|
||||
mid = l2i(Ops.ADD, dt, ((a00*b01)<<16).bitcast(dt), ((a00*b01)>>16).bitcast(dt), ((a01*b00)<<16).bitcast(dt), ((a01*b00)>>16).bitcast(dt))
|
||||
mid = l2i(Ops.ADD, dt, shl(a00*b01, 16).bitcast(dt), shr(a00*b01, 16).bitcast(dt), shl(a01*b00, 16).bitcast(dt), shr(a01*b00, 16).bitcast(dt))
|
||||
return l2i(Ops.ADD, dt, *mid, (a00*b00).bitcast(dt), (a01*b01).bitcast(dt) + a0*b1 + a1*b0)
|
||||
case Ops.IDIV | Ops.MOD:
|
||||
# TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
|
||||
@@ -362,7 +362,7 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
|
||||
r = (r[0] | l2i(Ops.SHR, dtypes.uint, a0, a1, UOp.const(dtypes.uint, i), z)[0] & 1), r[1]
|
||||
cond = l2i(Ops.CMPLT, dtypes.uint, *r, b0, b1).logical_not()
|
||||
diff = l2i(Ops.SUB, dtypes.uint, *r, b0, b1)
|
||||
q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32)))
|
||||
q = ((q[0] | shl(cond.cast(dtypes.uint), i % 32), q[1]) if i < 32 else (q[0], q[1] | shl(cond.cast(dtypes.uint), i % 32)))
|
||||
r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r)
|
||||
if dt == dtypes.int:
|
||||
(nq0, nq1), (nr0, nr1) = l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *q)), l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *r))
|
||||
|
||||
Reference in New Issue
Block a user