mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Remote beam (#10357)
* Use renderer properties instead of `.device` * Remote beam
This commit is contained in:
@@ -18,7 +18,7 @@ base_rewrite = PatternMatcher([
|
||||
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
|
||||
(UPat(Ops.VECTORIZE, name="x"),
|
||||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
|
||||
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
|
||||
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
|
||||
@@ -53,8 +53,7 @@ base_rewrite = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
|
||||
f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
# custom passes through with format
|
||||
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
|
||||
])
|
||||
@@ -87,6 +86,8 @@ class CStyleLanguage(Renderer):
|
||||
code_for_workitem: dict[Literal["g", "l", "i"], Callable] = {}
|
||||
extra_args: list[str] = []
|
||||
float4: str|None = None
|
||||
float4_style: tuple[str, str] = ('(', ')')
|
||||
gep_arr_threshold: int = 4
|
||||
type_map: dict[DType, str] = {}
|
||||
infinity: str = "INFINITY"
|
||||
nan: str = "NAN"
|
||||
@@ -179,6 +180,8 @@ class CStyleLanguage(Renderer):
|
||||
class ClangRenderer(CStyleLanguage):
|
||||
device = "CPU"
|
||||
float4 = "(float4)"
|
||||
float4_style = ('{', '}')
|
||||
gep_arr_threshold = 0
|
||||
has_local = False
|
||||
global_max = None
|
||||
infinity = "__builtin_inff()"
|
||||
@@ -347,6 +350,7 @@ class CUDARenderer(CStyleLanguage):
|
||||
smem_prefix_for_cast = False
|
||||
barrier = "__syncthreads();"
|
||||
float4 = "make_float4"
|
||||
gep_arr_threshold = 8
|
||||
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
|
||||
"i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
|
||||
code_for_op = { **CStyleLanguage.code_for_op,
|
||||
|
||||
@@ -313,9 +313,11 @@ class RemoteDevice(Compiled):
|
||||
if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
|
||||
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
|
||||
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
|
||||
renderer_instance = renderer_class(*renderer[2])
|
||||
renderer_instance.device = device
|
||||
graph_supported, graph_multi = self.properties.graph_supported, self.properties.graph_supports_multi
|
||||
graph = fromimport('tinygrad.runtime.graph.remote', f"Remote{'Multi' if graph_multi else ''}Graph") if graph_supported else None
|
||||
super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph)
|
||||
super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph)
|
||||
|
||||
def finalize(self):
|
||||
with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
|
||||
|
||||
Reference in New Issue
Block a user