mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
cpu: support several threads in runtime (#12055)
This commit is contained in:
@@ -30,24 +30,33 @@ class ClangJITCompiler(Compiler):
|
||||
class CPUWorker(threading.Thread):
|
||||
def __init__(self, dev, tasks, thread_id):
|
||||
super().__init__()
|
||||
self.dev, self.tasks, self.thread_id, self.daemon = dev, tasks, thread_id, True
|
||||
self.dev, self.tasks, self.thread_id, self.pool, self.daemon = dev, tasks, thread_id, [], True
|
||||
|
||||
def push_task(self, tid, cmd, args):
|
||||
if len(self.pool) <= tid:
|
||||
self.pool.append(queue.Queue())
|
||||
CPUWorker(self, self.pool[tid], thread_id=tid+1).start()
|
||||
self.pool[tid].put([cmd, 1, len(args)] + args)
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
cmd_iter = iter(self.tasks.get())
|
||||
for cmd in cmd_iter:
|
||||
args_cnt = next(cmd_iter)
|
||||
cmd(*[next(cmd_iter) for _ in range(args_cnt)])
|
||||
threads, args_cnt = next(cmd_iter), next(cmd_iter)
|
||||
args = [next(cmd_iter) for _ in range(args_cnt)]
|
||||
for th in range(threads - 1): self.push_task(th, cmd, args)
|
||||
cmd(self.thread_id, *args)
|
||||
for th in range(threads - 1): self.pool[th].join()
|
||||
self.tasks.task_done()
|
||||
|
||||
class CPUComputeQueue(HWQueue):
|
||||
def _exec(self, prg, bufs, *args):
|
||||
def _exec(self, tid, prg, bufs, *args):
|
||||
prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:]))
|
||||
def _signal(self, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
|
||||
def _wait(self, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
|
||||
def _timestamp(self, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
|
||||
def cmd(self, cmd, *args):
|
||||
self.q(cmd, len(args), *args)
|
||||
def _signal(self, tid, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
|
||||
def _wait(self, tid, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
|
||||
def _timestamp(self, tid, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
|
||||
def cmd(self, cmd, *args, threads=1):
|
||||
self.q(cmd, threads, len(args), *args)
|
||||
return self
|
||||
|
||||
def memory_barrier(self): return self
|
||||
|
||||
Reference in New Issue
Block a user