mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix neg logical_not inconsistencies (#3222)
* try * test: add logical_not tests * gah im retarded, but this doesn't match types for const() * fix: can't we jsut do this? * big change: I don't actually know what I'm doing * WOOO IM JUST CHANGING EVERYTHING WOW probably gon revert later * BYE BYE noqa: E501 * fix: less lines and add test * fix: rm 2 redundant tests * fix: eq with False so we don't unintentionally implicit upcast, but it's bool anyways so w/e
This commit is contained in:
@@ -63,7 +63,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
|
||||
UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
|
||||
UnaryOps.NEG: lambda x,dtype: f"-{x}",
|
||||
BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
|
||||
BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
|
||||
BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
|
||||
|
||||
Reference in New Issue
Block a user