Remote beam (#10357)

* Use renderer properties instead of `.device`

* Remote beam
This commit is contained in:
uuuvn
2025-05-17 06:59:22 +05:00
committed by GitHub
parent 7cc35a031b
commit 64409a8bda
2 changed files with 10 additions and 4 deletions

View File

@@ -18,7 +18,7 @@ base_rewrite = PatternMatcher([
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x:
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
@@ -53,8 +53,7 @@ base_rewrite = PatternMatcher([
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
f".{'xyzwabcd'[x.arg[0]]}")),
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
# custom passes through with format
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
])
@@ -87,6 +86,8 @@ class CStyleLanguage(Renderer):
code_for_workitem: dict[Literal["g", "l", "i"], Callable] = {}
extra_args: list[str] = []
float4: str|None = None
float4_style: tuple[str, str] = ('(', ')')
gep_arr_threshold: int = 4
type_map: dict[DType, str] = {}
infinity: str = "INFINITY"
nan: str = "NAN"
@@ -179,6 +180,8 @@ class CStyleLanguage(Renderer):
class ClangRenderer(CStyleLanguage):
device = "CPU"
float4 = "(float4)"
float4_style = ('{', '}')
gep_arr_threshold = 0
has_local = False
global_max = None
infinity = "__builtin_inff()"
@@ -347,6 +350,7 @@ class CUDARenderer(CStyleLanguage):
smem_prefix_for_cast = False
barrier = "__syncthreads();"
float4 = "make_float4"
gep_arr_threshold = 8
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
"i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
code_for_op = { **CStyleLanguage.code_for_op,

View File

@@ -313,9 +313,11 @@ class RemoteDevice(Compiled):
if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
renderer_instance = renderer_class(*renderer[2])
renderer_instance.device = device
graph_supported, graph_multi = self.properties.graph_supported, self.properties.graph_supports_multi
graph = fromimport('tinygrad.runtime.graph.remote', f"Remote{'Multi' if graph_multi else ''}Graph") if graph_supported else None
super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph)
super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph)
def finalize(self):
with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)