* WIP

* Cleanup

* Cleanup

* Fix variable, refactor to use set

* right shift should be signed/unsigned

* Test for bitshifts

* Allow a neg
This commit is contained in:
Szymon Ożóg
2024-06-03 21:16:01 +02:00
committed by GitHub
parent e78a9bf3f2
commit bb7b031c5c
4 changed files with 21 additions and 6 deletions

View File

@@ -2,7 +2,7 @@ from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI
from tinygrad.helpers import CI, getenv
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
@@ -81,12 +81,12 @@ class TestUOps(unittest.TestCase):
a = dtypes.as_const(a, dts[0])
self._equal(f([a], op, dts), fxn(a))
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False):
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0.0, 1.0]:
for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
a = dtypes.as_const(a, dts[0])
b = dtypes.as_const(b, dts[1])
b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
self._equal(f([a,b], op, dts), fxn(a,b))
def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
@@ -122,6 +122,10 @@ class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (dtypes.int32, dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shr_int32(self): self._test_bop_fxn(BinaryOps.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
def test_div_int32(self):
self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_mod_int32(self):

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast
import functools, itertools, heapq
from collections import defaultdict
from enum import Enum, auto
@@ -62,6 +62,7 @@ def uop_alu_resolve(u:UOp) -> sint:
if u.uop is UOps.DEFINE_VAR: return u.arg
if u.uop is UOps.SPECIAL: return u.arg[2]-1
if u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
if u.uop is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.vin[0]) * (2**cast(int, uop_alu_resolve(u.vin[1])))
if u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
raise RuntimeError(f"ALU resolve fail @ {u.uop}")
@@ -94,7 +95,8 @@ class UPat:
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
if pat.name in store and store[pat.name] != uop: return False
if pat.name is not None: store[pat.name] = uop
if pat.arg is not None and uop.arg != pat.arg: return False
if isinstance(pat.arg, set) and uop.arg not in pat.arg: return False
elif pat.arg is not None and uop.arg != pat.arg: return False
if isinstance(pat.dtype, set) and uop.dtype not in pat.dtype: return False
if isinstance(pat.dtype, DType) and uop.dtype != pat.dtype: return False
if isinstance(pat.uop, set) and uop.uop not in pat.uop: return False

View File

@@ -18,6 +18,7 @@ class UnaryOps(Enum):
class BinaryOps(Enum):
"""A + A -> A (elementwise)"""
ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
SHR = auto(); SHL = auto() # noqa: E702
class TernaryOps(Enum):
"""A + A + A -> A (elementwise)"""
WHERE = auto(); MULACC = auto() # noqa: E702
@@ -119,6 +120,7 @@ python_alu = {
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],

View File

@@ -1,5 +1,5 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct
import struct, math
from collections import defaultdict
from tinygrad.helpers import DEBUG
from tinygrad.codegen.linearizer import UOps, UOp
@@ -40,6 +40,7 @@ class PTXRenderer(Renderer):
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
@@ -236,6 +237,12 @@ class PTXRenderer(Renderer):
return self.render_kernel(kernel, name, bufs, c.items())
ptx_matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.MUL, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "mul"}]},
lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.DIV, "dtype": set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
"vin": [{"__name__": "const", "uop": UOps.CONST, "arg": set([2**i for i in range(64)])}, {"__name__": "div"}]},
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},