fix neg logical_not inconsistencies (#3222)

* try

* test: add logical_not tests

* gah im retarded, but this doesn't match types for const()

* fix: can't we jsut do this?

* big change: I don't actually know what I'm doing

* WOOO IM JUST CHANGING EVERYTHING WOW probably gon revert later

* BYE BYE noqa: E501

* fix: less lines and add test

* fix: rm 2 redundant tests

* fix: eq with False so we don't unintentionally implicit upcast, but it's bool anyways so w/e
This commit is contained in:
geohotstan
2024-01-25 00:48:40 +08:00
committed by GitHub
parent e2e4632aea
commit 842053873d
9 changed files with 14 additions and 12 deletions

View File

@@ -63,7 +63,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
UnaryOps.NEG: lambda x,dtype: f"-{x}",
BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",