This commit is contained in:
chenyu
2025-08-13 15:40:48 -07:00
committed by GitHub
parent eb10a9c76a
commit 4fe19eec72
10 changed files with 22 additions and 10 deletions

View File

@@ -393,7 +393,7 @@ class TestOps(unittest.TestCase):
def test_trunc(self):
helper_test_op([()], lambda x: x.trunc(), forward_only=True)
helper_test_op([(45,35)], lambda x: x.trunc(), forward_only=True)
helper_test_op(None, lambda x: x.trunc(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501]], forward_only=True)
helper_test_op(None, lambda x: x.trunc(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501, 1e12, -1e12]], forward_only=True)
def test_floor(self):
helper_test_op([()], lambda x: x.floor(), forward_only=True)
helper_test_op([(45,35)], lambda x: x.floor(), forward_only=True)

View File

@@ -679,6 +679,10 @@ class TestSymbolic(unittest.TestCase):
# TODO: should z3 work?
self.helper_test_variable(2*(2*a).reciprocal(), -math.inf, math.inf, "(1/a)", test_z3=False)
def test_trunc_noop(self):
a = Variable("a", 1, 10, dtypes.int)
self.helper_test_variable(a.trunc(), 1, 10, "a", test_z3=False)
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10

View File

@@ -99,6 +99,7 @@ class CStyleLanguage(Renderer):
code_for_op: dict = {
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
Ops.TRUNC: lambda x,dtype: f"trunc({x})",
Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
@@ -200,11 +201,12 @@ class ClangRenderer(CStyleLanguage):
# language options
buffer_suffix = " restrict"
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC]}),
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})"}
# LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
(UPat(Ops.SQRT, name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher
(UPat((Ops.SQRT, Ops.TRUNC), name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher
if sys.platform == 'win32':
kernel_typedef = "__attribute__((ms_abi)) void"
@@ -414,7 +416,7 @@ class AMDRenderer(CStyleLanguage):
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", ""), ("trunc", "")]]
kernel_typedef = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
@@ -423,6 +425,7 @@ class AMDRenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
code_for_op = { **CStyleLanguage.code_for_op,
Ops.TRUNC: lambda x,dtype: f"__ocml_trunc_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",

View File

@@ -92,6 +92,8 @@ base_rewrite = PatternMatcher([
# unary/binary/ternary ops
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.TRUNC, name="x"),
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
(UPat(GroupOp.Binary, name="x"), lambda ctx,x:
f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
(UPat(Ops.WHERE, name="x"), lambda ctx,x:

View File

@@ -19,6 +19,7 @@ asm_for_op: dict[Ops, Callable] = {
Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
Ops.TRUNC: lambda d,a,dt,name: f"cvt.rzi.{name}.{name} {d}, {a};",
Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
Ops.ADD: lambda d,a,b,dt,name: f"{'or' if dt == dtypes.bool else 'add'}.{name} {d}, {a}, {b};",
Ops.MUL: lambda d,a,b,dt,name: f"{'and' if dt == dtypes.bool else 'mul'}{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
@@ -33,7 +34,7 @@ asm_for_op: dict[Ops, Callable] = {
f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}
supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE)
supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE, Ops.TRUNC)
doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)

View File

@@ -3143,7 +3143,7 @@ class Tensor(MathTrait):
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
```
"""
return self.cast(dtypes.int32).cast(self.dtype)
return self._apply_uop(UOp.trunc)
def ceil(self: Tensor) -> Tensor:
"""

View File

@@ -49,7 +49,7 @@ class Ops(FastEnum):
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
# load/store before math
LOAD = auto(); STORE = auto() # noqa: E702
@@ -80,7 +80,7 @@ class Ops(FastEnum):
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
Ternary = {Ops.WHERE, Ops.MULACC}

View File

@@ -161,6 +161,7 @@ class MathTrait:
raise RuntimeError("where needs at least one UOp arg")
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
def reciprocal(self): return self.alu(Ops.RECIP)
def trunc(self): return self.alu(Ops.TRUNC)
def sqrt(self): return self.alu(Ops.SQRT)
def sin(self): return self.alu(Ops.SIN)
def log2(self): return self.alu(Ops.LOG2)

View File

@@ -573,7 +573,7 @@ def safe_pow(x, y):
python_alu: dict[Ops, Callable] = {
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}

View File

@@ -41,6 +41,7 @@ symbolic_simple = PatternMatcher([
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0