mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
do not use sint_to_uop in renderer [pr] (#12601)
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Literal, Callable, cast
|
||||
import os, math, sys
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.opt import tc
|
||||
-from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, sint_to_uop, range_str
|
||||
+from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -112,7 +112,7 @@ class CStyleLanguage(Renderer):
|
||||
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
-launch_bounds = sint_to_uop(prod(local_dims)).vmax
|
||||
+launch_bounds = prod([d.vmax for d in local_dims])
|
||||
prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.codegen.opt import tc
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import AMDRenderer
|
||||
from tinygrad.uop.decompositions import xexp2, xlog2
|
||||
-from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, sint_to_uop, range_str
|
||||
+from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
|
||||
from tinygrad.helpers import prod, AMX
|
||||
|
||||
@@ -226,7 +226,7 @@ class AMDLLVMRenderer(LLVMRenderer):
|
||||
def _render_footer(self, uops: list[UOp]) -> str:
|
||||
# TODO: this is copied from cstyle
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
-requiredMaxThreadsPerBlock = sint_to_uop(prod(local_dims)).vmax
|
||||
+requiredMaxThreadsPerBlock = prod([d.vmax for d in local_dims])
|
||||
attributes = ["alwaysinline", "nounwind", '"no-builtins"',
|
||||
f'"amdgpu-flat-work-group-size"="1,{requiredMaxThreadsPerBlock}"', '"no-trapping-math"="true"']
|
||||
return 'attributes #0 = { ' + ' '.join(attributes) + ' }'
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import cast, Callable
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.opt import tc
|
||||
-from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, sint_to_uop
|
||||
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
@@ -157,7 +157,7 @@ class PTXRenderer(Renderer):
|
||||
def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
|
||||
kernel = '\n'.join(map(fmt, [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]))
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
-launch_bounds = sint_to_uop(prod(local_dims)).vmax
|
||||
+launch_bounds = prod([d.vmax for d in local_dims])
|
||||
params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
|
||||
return f"{self.kernel_prefix.format(launch_bounds=launch_bounds)} {function_name} (\n\t{params}\n)\n.maxntid {launch_bounds}\n{{\n{kernel}\n}}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user