From 175c38d1b3621cc1bf55b4d82c0d4e0087275245 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 1 Feb 2023 12:00:33 -0800 Subject: [PATCH] triton: it already was GT0 --- accel/triton/ops_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel/triton/ops_triton.py b/accel/triton/ops_triton.py index a682d740df..d2a5733aab 100644 --- a/accel/triton/ops_triton.py +++ b/accel/triton/ops_triton.py @@ -23,7 +23,7 @@ stream = cuda.Stream() class TritonASTKernel(ASTKernel): code_for_op : Dict[Op, str] = { - UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "tl.maximum(A, 0.0)", UnaryOps.SIGN: "tl.where(A>0,1,0)", + UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "tl.maximum(A, 0.0)", UnaryOps.GT0: "tl.where(A>0,1,0)", UnaryOps.EXP: "tl.exp(A)", UnaryOps.LOG: "tl.log(A)", UnaryOps.RECIPROCAL: "(1.0/A)", BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "tl.exp(tl.log(A)*B)", BinaryOps.CMPEQ: "(A==B)",