mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Bitshift (#4728)
* WIP * Cleanup * Cleanup * Fix variable, refactor to use set * right shift should be signed/unsigned * Test for bitshifts * Allow a neg
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Optional, Tuple, Any, List
|
||||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
||||
@@ -81,12 +81,12 @@ class TestUOps(unittest.TestCase):
|
||||
a = dtypes.as_const(a, dts[0])
|
||||
self._equal(f([a], op, dts), fxn(a))
|
||||
|
||||
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False):
|
||||
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [-2.0, 0.0, 1.0]:
|
||||
for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
|
||||
a = dtypes.as_const(a, dts[0])
|
||||
b = dtypes.as_const(b, dts[1])
|
||||
b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
|
||||
self._equal(f([a,b], op, dts), fxn(a,b))
|
||||
|
||||
def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
|
||||
@@ -122,6 +122,10 @@ class TestNonFloatUOps(TestUOps):
|
||||
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
|
||||
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
|
||||
def test_shr_int32(self): self._test_bop_fxn(BinaryOps.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
|
||||
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
|
||||
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
|
||||
def test_div_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_mod_int32(self):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast
|
||||
import functools, itertools, heapq
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
@@ -62,6 +62,7 @@ def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.uop is UOps.DEFINE_VAR: return u.arg
|
||||
if u.uop is UOps.SPECIAL: return u.arg[2]-1
|
||||
if u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
|
||||
if u.uop is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.vin[0]) * (2**cast(int, uop_alu_resolve(u.vin[1])))
|
||||
if u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
|
||||
raise RuntimeError(f"ALU resolve fail @ {u.uop}")
|
||||
|
||||
@@ -94,7 +95,8 @@ class UPat:
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
if pat.name in store and store[pat.name] != uop: return False
|
||||
if pat.name is not None: store[pat.name] = uop
|
||||
if pat.arg is not None and uop.arg != pat.arg: return False
|
||||
if isinstance(pat.arg, set) and uop.arg not in pat.arg: return False
|
||||
elif pat.arg is not None and uop.arg != pat.arg: return False
|
||||
if isinstance(pat.dtype, set) and uop.dtype not in pat.dtype: return False
|
||||
if isinstance(pat.dtype, DType) and uop.dtype != pat.dtype: return False
|
||||
if isinstance(pat.uop, set) and uop.uop not in pat.uop: return False
|
||||
|
||||
@@ -18,6 +18,7 @@ class UnaryOps(Enum):
|
||||
class BinaryOps(Enum):
|
||||
"""A + A -> A (elementwise)"""
|
||||
ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
|
||||
SHR = auto(); SHL = auto() # noqa: E702
|
||||
class TernaryOps(Enum):
|
||||
"""A + A + A -> A (elementwise)"""
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
@@ -119,6 +120,7 @@ python_alu = {
|
||||
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
|
||||
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
|
||||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
|
||||
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
|
||||
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct
|
||||
import struct, math
|
||||
from collections import defaultdict
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
@@ -40,6 +40,7 @@ class PTXRenderer(Renderer):
|
||||
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
|
||||
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
@@ -236,6 +237,12 @@ class PTXRenderer(Renderer):
|
||||
return self.render_kernel(kernel, name, bufs, c.items())
|
||||
|
||||
ptx_matcher = PatternMatcher([
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.MUL, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "mul"}]},
|
||||
lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.DIV, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "div"}]},
|
||||
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
|
||||
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
|
||||
|
||||
Reference in New Issue
Block a user