From 29c81db0f7880848b001c2728aa555a1ef17e7d3 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sat, 12 Oct 2024 22:14:33 +0800
Subject: [PATCH] dont truncate in const fold

---
 tinygrad/ops.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

Reviewer notes (free-form text here is not part of the diff):
* exec_alu gains a trailing do_truncate parameter that defaults to True, so
  calls that do not pass it still send the result through `truncate`.
* For dtype.count > 1 the flag is forwarded to each per-element exec_alu call.
* The constant-folding rule in `symbolic` now calls exec_alu with
  do_truncate=False, so the folded value is the raw python_alu result.

diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 114552d4b7..d223e1b920 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -417,10 +417,11 @@ python_alu: Dict[Op, Callable] = {
   BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
   TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
 
-def exec_alu(op:Op, dtype:DType, operands):
+def exec_alu(op:Op, dtype:DType, operands, do_truncate=True):
   if dtype.count > 1:
-    return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
-  return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
+    return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands], do_truncate) for i in range(dtype.count)])
+  ret = python_alu[op](*operands)
+  return truncate.get(dtype, lambda x: x)(ret) if do_truncate else ret
 
 # ***** uop helpers *****
 
@@ -907,7 +908,7 @@ symbolic = PatternMatcher([
   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
   # ** constant folding **
   (UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
-   lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
+   lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], do_truncate=False))),
   # ALU min==max -> CONST (slow!)
   (UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
   # max folding