update variable names around jit [pr] (#14049)
lbs, st_vars_dtype_device and rawbuffers no more
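For context, a minimal sketch (not part of this commit) of how downstream code adapts to the rename when loading a pickled jit. "model.pkl" is a hypothetical path; the field names and the (view, variables, dtype, device) tuple layout come from the diff below.

import pickle

# load a previously captured TinyJit (hypothetical pickle; the real example script takes the path from sys.argv[1])
with open("model.pkl", "rb") as f:
  run_onnx_jit = pickle.load(f)

captured = run_onnx_jit.captured
input_name = captured.expected_names[0]                            # unchanged field
# old field name: captured.expected_st_vars_dtype_device[0]
view, variables, dtype, device = captured.expected_input_info[0]   # renamed field, same (view, variables, dtype, device) layout
print(f"input goes into {input_name=} on {device=}")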
@@ -7,7 +7,7 @@ if __name__ == "__main__":
 with open(fetch(sys.argv[1]), "rb") as f:
 run_onnx_jit = pickle.load(f)
 input_name = run_onnx_jit.captured.expected_names[0]
-device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
+device = run_onnx_jit.captured.expected_input_info[0][-1]
 print(f"input goes into {input_name=} on {device=}")
 hit = 0
 for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
@@ -24,7 +24,7 @@ def _check_no_non_tensor_return(ret):
 
 def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
 
-def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
+def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
 # Split JIT cache into batches for faster graph execution.
 # This allows the accelerator to run some batches while subsequent graphs are still being updated.
 graphed_jit_cache: list[ExecItem] = []
@@ -36,10 +36,10 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
 try:
 if len(current_batch_devs) == 0: raise GraphException("no device for graph")
 if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
-graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
+graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals)
 # clear jit inputs to allow their memory to be freed/reused
 for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
-graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_rawbuffers), prg=graph_runner))
+graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner))
 max_batch_size *= 2
 if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
 except GraphException as e:
@@ -72,18 +72,18 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
 if len(current_batch) > 0: flush_batch()
 return graphed_jit_cache
 
-def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
+def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]:
 input_replace: dict[tuple[int, int], int] = {}
 for j,ji in enumerate(jit_cache):
 for i,a in enumerate(ji.bufs):
-if a in input_rawbuffers:
-input_replace[(j,i)] = input_rawbuffers.index(a)
+if a in input_buffers:
+input_replace[(j,i)] = input_buffers.index(a)
 return input_replace
 
 class GraphRunner(Runner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
 self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
-self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers)
 self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
 self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
 self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
@@ -125,19 +125,19 @@ class GraphRunner(Runner):
 for j, (gl, lc) in self.launch_dims_replace.items():
 yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
 
-def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
+def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
 # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
 # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
 wait_nodes = []
 
-for i,rawbuf in enumerate(rawbufs):
-if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
+for i,buf in enumerate(bufs):
+if id(buf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(buf.base._buf)])
 if i in write:
-if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
+if id(buf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(buf.base._buf)))
 
-for i,rawbuf in enumerate(rawbufs):
-if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
-else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
+for i,buf in enumerate(bufs):
+if i in write: self.w_dependency_map[id(buf.base._buf)] = new_dependency
+else: self.r_dependency_map[id(buf.base._buf)].append(new_dependency)
 
 return list({id(x):x for x in wait_nodes}.values())
@@ -168,12 +168,11 @@ class CapturedJit(Generic[ReturnType]):
 input_replace: dict[tuple[int, int], int]
 extra_view_inputs: list[tuple[int, int, str, int, DType]]
 expected_names: list[int|str]
-expected_st_vars_dtype_device: list[tuple[UOp, tuple[Variable, ...], DType, str]]
+expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
 
 def __reduce__(self):
 # TODO: free_intermediates here? replan_buffers_memory_layout here?
-return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
-self.expected_names, self.expected_st_vars_dtype_device)
+return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
 
 def __post_init__(self):
 self._jit_cache: list[ExecItem] = self.jit_cache
@@ -233,17 +232,17 @@ def _prepare_jit_inputs(args, kwargs):
 it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
 tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
 if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
-lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
-if any(lb.base.op is Ops.CONST for lb in lbs):
+input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
+if any(u.base.op is Ops.CONST for u in input_uops):
 raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
-input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
-for lb in lbs if lb.base.realized is not None])
+input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b:=u.base.realized, MultiBuffer) else [b]
+for u in input_uops if u.base.realized is not None])
 if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
-st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
-_var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops]
+_var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
 var_vals = {k.expr:v for k,v in _var_vals.items()}
-st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
-return input_buffers, var_vals, names, st_vars_dtype_device
+expected_input_info = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in inputs]
+return input_buffers, var_vals, names, expected_input_info
 
 class TinyJit(Generic[ReturnType]):
 def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
@@ -284,7 +283,7 @@ class TinyJit(Generic[ReturnType]):
 def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
 
 def __call__(self, *args, **kwargs) -> ReturnType:
-input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
+input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
 if not JIT or self.cnt == 0:
 # jit ignore
 assert self.fxn is not None
@@ -342,14 +341,14 @@ class TinyJit(Generic[ReturnType]):
 if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
 
 # set this for next run
-self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
+self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
 if self.optimize: self.captured.replan_buffers_memory_layout()
 elif self.cnt >= 2:
 # jit exec
 assert self.captured is not None
 if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}")
-if self.captured.expected_st_vars_dtype_device != st_vars_dtype_device:
-raise JitError(f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}")
+if self.captured.expected_input_info != expected_input_info:
+raise JitError(f"args mismatch in JIT: {self.captured.expected_input_info=} != {expected_input_info=}")
 ret = self.captured(input_buffers, var_vals)
 
 self.cnt += 1
@@ -8,13 +8,13 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException
 
 class CUDAGraph(MultiGraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
-super().__init__(jit_cache, input_rawbuffers, var_vals)
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+super().__init__(jit_cache, input_buffers, var_vals)
 
 # Check all jit items are compatible.
 if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
 
-self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
+self.jc_idx_with_updatable_bufs = dedup([x[0] for x in self.input_replace.keys()])
 self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
 
 self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
@@ -31,29 +31,29 @@ class CUDAGraph(MultiGraphRunner):
 kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
 check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
 
-if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
+if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_bufs:
 self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
 elif isinstance(ji.prg, BufferXfer):
 dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
 src_dev = cast(CUDADevice, Device[src.device])
 node_from = cuda.CUgraphNode()
-deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
+deps = self._access_resources(bufs=[dest.base, src.base], write=[0], new_dependency=node_from)
 c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
 cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
 dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
 WidthInBytes=dest.nbytes, Height=1, Depth=1)
 check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
-if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
+if j in self.jc_idx_with_updatable_bufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
 
 self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
 
-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
-# Update rawbuffers in the c_args struct.
+def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
+# Update buffers in the c_args struct.
 for (j,i),input_idx in self.input_replace.items():
-if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
+if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_buffers[input_idx]._buf)
 else:
-if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf
-elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
+if i == 0: self.updatable_nodes[j][1].destDevice = input_buffers[input_idx]._buf
+elif i == 1: self.updatable_nodes[j][1].srcDevice = input_buffers[input_idx]._buf
 
 # Update var_vals in the c_args struct.
 for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)
@@ -9,8 +9,8 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
 from tinygrad.engine.jit import MultiGraphRunner
 
 class HCQGraph(MultiGraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
-super().__init__(jit_cache, input_rawbuffers, var_vals)
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+super().__init__(jit_cache, input_buffers, var_vals)
 self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
 
 # CPU Device is always last
@@ -189,10 +189,10 @@ class HCQGraph(MultiGraphRunner):
 
 def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
 
-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
-# Map input rawbuffers
+def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
+# Map input buffers
 for dev in self.devices:
-for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
+for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_buffers[idx_to_map]._buf)
 
 # Wait and restore signals
 self.kickoff_value += 1
@@ -203,9 +203,9 @@ class HCQGraph(MultiGraphRunner):
 **{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
 **{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
 
-# Update rawbuffers
+# Update buffers
 for (j,i),input_idx in self.input_replace.items():
-hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr
+hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_buffers[input_idx]._buf.va_addr
 
 for dev in self.devices:
 self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
@@ -10,8 +10,8 @@ from tinygrad.runtime.autogen import metal
 from tinygrad.runtime.support import objc
 
 class MetalGraph(GraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
-super().__init__(jit_cache, input_rawbuffers, var_vals)
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+super().__init__(jit_cache, input_buffers, var_vals)
 if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
 # create metal batch exec
@@ -39,7 +39,7 @@ class MetalGraph(GraphRunner):
 all_pipelines.append(prg._prg.pipeline_state)
 icb_command.setComputePipelineState(prg._prg.pipeline_state)
 for i,b in enumerate(ji.bufs):
-if b is not None and b not in input_rawbuffers:
+if b is not None and b not in input_buffers:
 icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
 all_resources.append(b._buf.buf)
 for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
@@ -55,15 +55,15 @@ class MetalGraph(GraphRunner):
 for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
 self.range = metal.NSRange(0, len(jit_cache))
 
-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
+def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
 if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
 # NOTE: old command buffer may not be inflight anymore
 if self.command_buffer is not None and PROFILE: self.collect_timestamps()
 
-all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
+all_resources = dedup(self.all_resources + [input_buffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
 for (j,i),input_idx in self.input_replace.items():
 computeCommand = self.icb.indirectComputeCommandAtIndex(j)
-computeCommand.setKernelBuffer_offset_atIndex(input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
+computeCommand.setKernelBuffer_offset_atIndex(input_buffers[input_idx]._buf.buf, input_buffers[input_idx]._buf.offset, i)
 
 for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
 computeCommand = self.icb.indirectComputeCommandAtIndex(j)
@@ -28,7 +28,7 @@ class NullAllocator(Allocator['NullDevice']):
 def _offset(self, buf, offset:int, size:int): pass
 
 class NullGraph(MultiGraphRunner):
-def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-1
+def __call__(self, input_buffers, var_vals, wait=False) -> float|None: return 1e-1
 
 class NullDevice(Compiled):
 def __init__(self, device:str):
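As a closing note, a standalone sketch (not tinygrad code) of the input_replace mechanism that the renamed input_buffers parameters feed: capture records which (jit item, buffer slot) positions held a JIT input, and each graph __call__ re-points those slots at the buffers passed on the new call. Buffer and ExecItem below are simplified stand-ins for the real classes; the real helper and patch loops appear in the diff above.

from dataclasses import dataclass

@dataclass(eq=False)
class Buffer:            # stand-in: identity-compared, like a device allocation
  name: str

@dataclass
class ExecItem:          # stand-in: one cached launch and the buffers it touches
  bufs: list

def get_input_replace(jit_cache: list, input_buffers: list) -> dict[tuple[int, int], int]:
  # same shape as the real helper: (item index, buffer slot) -> index into input_buffers
  input_replace: dict[tuple[int, int], int] = {}
  for j, ji in enumerate(jit_cache):
    for i, a in enumerate(ji.bufs):
      if a in input_buffers: input_replace[(j, i)] = input_buffers.index(a)
  return input_replace

x, w, out = Buffer("x"), Buffer("w"), Buffer("out")
cache = [ExecItem([x, w, out])]
replace = get_input_replace(cache, [x])          # {(0, 0): 0}: slot 0 of item 0 is the jit input
new_inputs = [Buffer("x_run2")]                  # buffers passed on a later call
for (j, i), input_idx in replace.items(): cache[j].bufs[i] = new_inputs[input_idx]
print(replace, cache[0].bufs[0].name)            # {(0, 0): 0} x_run2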