From 0fb4ff30c81576d925e2fcde0cf90459e5fa61c9 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:10:36 -0500 Subject: [PATCH] share duplicate renders with cstyle (#2538) --- tinygrad/renderer/wgsl.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index feb5215d0b..ff754c1152 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -1,7 +1,7 @@ from tinygrad.helpers import dtypes, DType from tinygrad.renderer.cstyle import CStyleLanguage from typing import List, Union -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps +from tinygrad.ops import BinaryOps, TernaryOps import math from typing import Tuple @@ -13,15 +13,7 @@ class WGSLLanguage(CStyleLanguage): barrier="workgroupBarrier();" generic_var_prefix = "var " external_local_bufs = True - code_for_op = { - UnaryOps.NEG: lambda x: f"(-{x})", - UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", - UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})", - BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", - BinaryOps.DIV: lambda x,y: f"({x}/{y})", BinaryOps.MOD: lambda x,y: f"({x}%{y})", - BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", - TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" - } + code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" } def render_local(self, name: str, size: int): return f"var {name}: array;" @@ -59,4 +51,4 @@ class WGSLLanguage(CStyleLanguage): def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str: if buf_dtype != var_dtype: var_name = f"{type_map[buf_dtype]}({var_name})" - return f"{buf_name}[{idx}] = {var_name};" \ No newline at end of file + return f"{buf_name}[{idx}] = {var_name};"