mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Ops.TRUNC (#11659)
This commit is contained in:
@@ -393,7 +393,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_trunc(self):
|
||||
helper_test_op([()], lambda x: x.trunc(), forward_only=True)
|
||||
helper_test_op([(45,35)], lambda x: x.trunc(), forward_only=True)
|
||||
helper_test_op(None, lambda x: x.trunc(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x.trunc(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501, 1e12, -1e12]], forward_only=True)
|
||||
def test_floor(self):
|
||||
helper_test_op([()], lambda x: x.floor(), forward_only=True)
|
||||
helper_test_op([(45,35)], lambda x: x.floor(), forward_only=True)
|
||||
|
||||
@@ -679,6 +679,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
# TODO: should z3 work?
|
||||
self.helper_test_variable(2*(2*a).reciprocal(), -math.inf, math.inf, "(1/a)", test_z3=False)
|
||||
|
||||
def test_trunc_noop(self):
|
||||
a = Variable("a", 1, 10, dtypes.int)
|
||||
self.helper_test_variable(a.trunc(), 1, 10, "a", test_z3=False)
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
@@ -99,6 +99,7 @@ class CStyleLanguage(Renderer):
|
||||
code_for_op: dict = {
|
||||
Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
|
||||
Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
|
||||
Ops.TRUNC: lambda x,dtype: f"trunc({x})",
|
||||
Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
|
||||
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
|
||||
@@ -200,11 +201,12 @@ class ClangRenderer(CStyleLanguage):
|
||||
# language options
|
||||
buffer_suffix = " restrict"
|
||||
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
|
||||
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
|
||||
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
|
||||
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC]}),
|
||||
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
|
||||
Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})"}
|
||||
# LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
|
||||
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
|
||||
(UPat(Ops.SQRT, name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher
|
||||
(UPat((Ops.SQRT, Ops.TRUNC), name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher
|
||||
|
||||
if sys.platform == 'win32':
|
||||
kernel_typedef = "__attribute__((ms_abi)) void"
|
||||
@@ -414,7 +416,7 @@ class AMDRenderer(CStyleLanguage):
|
||||
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
|
||||
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
|
||||
for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
|
||||
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
|
||||
for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", ""), ("trunc", "")]]
|
||||
|
||||
kernel_typedef = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
|
||||
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
|
||||
@@ -423,6 +425,7 @@ class AMDRenderer(CStyleLanguage):
|
||||
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
|
||||
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
|
||||
code_for_op = { **CStyleLanguage.code_for_op,
|
||||
Ops.TRUNC: lambda x,dtype: f"__ocml_trunc_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
|
||||
@@ -92,6 +92,8 @@ base_rewrite = PatternMatcher([
|
||||
# unary/binary/ternary ops
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.TRUNC, name="x"),
|
||||
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
|
||||
(UPat(GroupOp.Binary, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.WHERE, name="x"), lambda ctx,x:
|
||||
|
||||
@@ -19,6 +19,7 @@ asm_for_op: dict[Ops, Callable] = {
|
||||
Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
|
||||
Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
Ops.TRUNC: lambda d,a,dt,name: f"cvt.rzi.{name}.{name} {d}, {a};",
|
||||
Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
|
||||
Ops.ADD: lambda d,a,b,dt,name: f"{'or' if dt == dtypes.bool else 'add'}.{name} {d}, {a}, {b};",
|
||||
Ops.MUL: lambda d,a,b,dt,name: f"{'and' if dt == dtypes.bool else 'mul'}{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
@@ -33,7 +34,7 @@ asm_for_op: dict[Ops, Callable] = {
|
||||
f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
|
||||
supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE)
|
||||
supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE, Ops.TRUNC)
|
||||
doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
|
||||
ptx_matcher = PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
|
||||
@@ -3143,7 +3143,7 @@ class Tensor(MathTrait):
|
||||
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.int32).cast(self.dtype)
|
||||
return self._apply_uop(UOp.trunc)
|
||||
|
||||
def ceil(self: Tensor) -> Tensor:
|
||||
"""
|
||||
|
||||
@@ -49,7 +49,7 @@ class Ops(FastEnum):
|
||||
UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
|
||||
|
||||
# load/store before math
|
||||
LOAD = auto(); STORE = auto() # noqa: E702
|
||||
@@ -80,7 +80,7 @@ class Ops(FastEnum):
|
||||
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
|
||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
|
||||
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
|
||||
@@ -161,6 +161,7 @@ class MathTrait:
|
||||
raise RuntimeError("where needs at least one UOp arg")
|
||||
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
|
||||
def reciprocal(self): return self.alu(Ops.RECIP)
|
||||
def trunc(self): return self.alu(Ops.TRUNC)
|
||||
def sqrt(self): return self.alu(Ops.SQRT)
|
||||
def sin(self): return self.alu(Ops.SIN)
|
||||
def log2(self): return self.alu(Ops.LOG2)
|
||||
|
||||
@@ -573,7 +573,7 @@ def safe_pow(x, y):
|
||||
python_alu: dict[Ops, Callable] = {
|
||||
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
|
||||
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
||||
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
|
||||
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
|
||||
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
|
||||
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
|
||||
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
|
||||
|
||||
@@ -41,6 +41,7 @@ symbolic_simple = PatternMatcher([
|
||||
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x),
|
||||
# ** zero folding **
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
|
||||
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
|
||||
|
||||
Reference in New Issue
Block a user