assert in lowerer for dtype support

lint
This commit is contained in:
Mesozoic Egg
2025-01-07 07:49:20 +08:00
parent 49cd6bbfc1
commit 8e9b1b79bf
2 changed files with 16 additions and 3 deletions

View File

@@ -1,11 +1,12 @@
# the job of the lowerer is to do indexing
import functools, itertools, operator
from dataclasses import dataclass
from typing import cast
from typing import cast, Optional
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten
from tinygrad.device import is_dtype_supported
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
@@ -131,4 +132,17 @@ pm_lowerer = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  """Rewrite `ast` with `pm_lowerer`, using the context built by `get_index(ast, opts)`."""
  index_ctx = get_index(ast, opts)
  return graph_rewrite(ast, pm_lowerer, ctx=index_ctx)
# Single-rule matcher: for a uop whose op is Ops.CAST (bound as "x") the rule returns
# is_dtype_supported(x.dtype, ctx). ctx is whatever the caller hands to upcast_support.rewrite;
# verify_upcast_support passes the device string and treats a result of exactly False as "unsupported".
# NOTE(review): for uops that are not CAST, rewrite presumably returns None (no rule matched) -- confirm against PatternMatcher.
upcast_support = PatternMatcher([(UPat(Ops.CAST, name="x"), lambda ctx, x: is_dtype_supported(x.dtype, ctx))])
def verify_upcast_support(uop: UOp, device: str, ast: Optional[UOp]=None):
  """Check `uop` and everything reachable through `.src` against `upcast_support`.

  Each node is handed to `upcast_support.rewrite(node, device)`; a result of exactly
  `False` is treated as "this cast is not supported on `device`".

  The walk is iterative and remembers visited nodes by identity, so a node that is
  reachable through several parents is checked once instead of once per path, and
  deep graphs do not run into Python's recursion limit. Nodes are still visited in
  the same pre-order as a plain recursive descent over `.src`, so the first failure
  reported is the same one.

  Args:
    uop: root of the graph to check.
    device: device name forwarded to `upcast_support.rewrite` as its context.
    ast: object printed alongside the error; defaults to `uop` when None.

  Raises:
    RuntimeError: when a node is reported as unsupported. The offending dtype and
      the ast are printed before raising.
  """
  root = uop if ast is None else ast
  seen: set[int] = set()
  pending = [uop]
  while pending:
    node = pending.pop()
    # shared nodes were already checked the first time they were reached
    if id(node) in seen: continue
    seen.add(id(node))
    if upcast_support.rewrite(node, device) is False:
      print(f"Upcast is not supported on device {device} for dtype {node.dtype}. ast:")
      print(root)
      raise RuntimeError("Upcast not supported")
    # push sources reversed so they are popped (and checked) in their original order
    pending.extend(reversed(tuple(node.src)))
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  """Rewrite `ast` with `pm_lowerer`, then check the result with `verify_upcast_support`.

  The rewrite context comes from `get_index(ast, opts)`. The rewritten graph is verified
  against `opts.device` before it is returned, so an unsupported cast surfaces here as
  the RuntimeError raised by `verify_upcast_support`.
  """
  index_ctx = get_index(ast, opts)
  result = graph_rewrite(ast, pm_lowerer, ctx=index_ctx)
  verify_upcast_support(result, opts.device)
  return result

View File

@@ -72,7 +72,6 @@ class WGSLRenderer(CStyleLanguage):
def render_load(self, x:str, dt:DType) -> str:
  """Return the expression used to load `x`: wrapped in atomicLoad when dt.itemsize < 4, otherwise `x` unchanged."""
  if dt.itemsize >= 4: return x
  return f"atomicLoad(&{x})"
def buf_map(self, dt:DType) -> str:
  """Return the buffer element type name: "atomic<u32>" when dt.itemsize < 4, otherwise self.type_map[dt.base]."""
  if dt.itemsize < 4: return "atomic<u32>"
  return self.type_map[dt.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
assert all(u.dtype is not dtypes.long for u in uops), "WEBGPU does not support int64"
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
if not local_size: local_size = [1]
bind_it = iter(range(len(bufs)))