mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
assert in lowerer for dtype support
lint
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
import functools, itertools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from typing import cast, Optional
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
|
||||
@@ -131,4 +132,17 @@ pm_lowerer = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
upcast_support = PatternMatcher([(UPat(Ops.CAST, name="x"), lambda ctx, x: is_dtype_supported(x.dtype, ctx))])
|
||||
def verify_upcast_support(uop: UOp, device: str, ast: Optional[UOp]=None):
|
||||
ast = uop if ast is None else ast
|
||||
if upcast_support.rewrite(uop, device) is False:
|
||||
print(f"Upcast is not supported on device {device} for dtype {uop.dtype}. ast:")
|
||||
print(ast)
|
||||
raise RuntimeError("Upcast not supported")
|
||||
for u in uop.src: verify_upcast_support(u, device, ast)
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
lowered = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
verify_upcast_support(lowered, opts.device)
|
||||
return lowered
|
||||
|
||||
|
||||
@@ -72,7 +72,6 @@ class WGSLRenderer(CStyleLanguage):
|
||||
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if dt.itemsize < 4 else x
|
||||
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 else self.type_map[dt.base]
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
assert all(u.dtype is not dtypes.long for u in uops), "WEBGPU does not support int64"
|
||||
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
|
||||
if not local_size: local_size = [1]
|
||||
bind_it = iter(range(len(bufs)))
|
||||
|
||||
Reference in New Issue
Block a user