cpu threading (#11951)

* start cpu threading

* fix

* fix2

* fix

* hacks?

* threads

* minor

* no dsp

* dsp 2

* n

* more

* test

* xm

* cleaner

* readable

* f

* reorder

* when no threads

* rangeify

* typos

* not needed

* reapply

* remove this

* linter

* fixed cpu count in ci

* fix

* fixes

* rm

* typo

* sort based on speed

* test if test works in ci

* Revert "test if test works in ci"

This reverts commit 1f05edb531.

* do not pad thread

Authored by nimlgen on 2025-09-06 16:13:43 +03:00; committed by GitHub.
parent 2b1844da27
commit 10ac427aaa
15 changed files with 63 additions and 18 deletions


@@ -730,7 +730,7 @@ jobs:
opencl: ${{ matrix.backend == 'gpu' && 'true' }}
llvm: ${{ matrix.backend == 'llvm' && 'true' }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'gpu' && 'GPU=1' }}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'gpu' && 'GPU=1' }}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CPU','GPU'], Device.DEFAULT"
@@ -954,7 +954,7 @@ jobs:
pydeps: "capstone"
llvm: ${{ matrix.backend == 'llvm' && 'true' }}
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'metal' && 'METAL=1'}}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'metal' && 'METAL=1'}}" >> $GITHUB_ENV
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == '${{ matrix.backend }}'.upper(), Device.DEFAULT"
@@ -991,7 +991,7 @@ jobs:
pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
- name: Set env
shell: bash
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV
- name: Run unit tests
if: matrix.backend=='llvm'
# test_newton_schulz hits RecursionError


@@ -67,11 +67,12 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs))
if not wasm:
thread_id = 0 # NOTE: export does not support threading, thread_id is always 0
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(to_mv(cl._buf.va_addr, cl._buf.size))])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)}, {thread_id});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(headers + cprog)
else:
if bufs_to_save:
@@ -239,7 +240,9 @@ export default {model_name};
def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"only {', '.join(EXPORT_SUPPORTED_DEVICE)} are supported"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
# NOTE: CPU_COUNT=1, since export does not support threading
with Context(JIT=2, CPU_COUNT=1): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
state = get_state_dict(model)
weight_names = {id(x.uop.base.realized): name for name, x in state.items()}
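
Since the exported net() is always single-threaded, the export path pins CPU_COUNT to 1 while tracing. A minimal sketch of what that override does, using the Context helper and the CPU_COUNT context variable added to tinygrad.helpers further down in this commit:

from tinygrad.helpers import Context, CPU_COUNT

# sketch: trace the model with the thread count pinned to 1, matching the single-threaded net() that gets emitted
with Context(CPU_COUNT=1):
  # anything that consults CPU_COUNT inside this block sees a single core
  assert CPU_COUNT.value == 1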


@@ -336,5 +336,19 @@ class TestKernelOpts(unittest.TestCase):
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501
])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_threads, "test requires threads")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.global_max is not None and
Device[Device.DEFAULT].renderer.global_max[0] > 1, "test requires multicore")
def test_thread_opts(self):
a = Tensor.rand(4, 4, 4, 4)
b = Tensor.rand(4, 4, 4)
r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
helper_linearizer_opt(r, [
[Opt(OptOps.THREAD, 0, 2)],
[Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.THREAD, 0, 2)],
[Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.THREAD, 0, 2), Opt(OptOps.UNROLL, 0, 2)],
] + [[Opt(OptOps.THREAD, 0, 4)] if Device[Device.DEFAULT].renderer.global_max[0] >= 4 else []]
+ [[Opt(OptOps.THREAD, 0, 8)] if Device[Device.DEFAULT].renderer.global_max[0] >= 8 else []])
if __name__ == '__main__':
unittest.main()


@@ -59,7 +59,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
all_ranges = {x.arg[0:-1]:x for x in s_topo if x.op is Ops.RANGE}
# extract global/local dims
global_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] is AxisType.GLOBAL]))
global_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.GLOBAL, AxisType.THREAD)]))
local_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)]))
if not global_dims and not local_dims: return None


@@ -5,7 +5,7 @@ from dataclasses import dataclass
from tinygrad.uop.ops import AxisType
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@@ -16,10 +16,10 @@ class Opt:
arg: int|tuple|None = None
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow",
AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
class KernelOptError(Exception): pass
def check(cond:bool, msg:str=""):


@@ -171,4 +171,16 @@ def hand_coded_optimizations(k:Scheduler) -> list[Opt]:
k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
if will_delete_shape: deleted_shape += 1
# **** threading ****
if k.opts.has_threads and k.opts.global_max is not None:
for threads in [32,16,12,8,6,5,4,3,2]:
# Skip if too many threads. Heuristic: use about 128K ops per thread
if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
for axis in k.axes_of(AxisType.LOOP):
if k.full_shape[axis] % threads == 0:
k.apply_opt(Opt(OptOps.THREAD, axis, threads))
break
if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break
return k.applied_opts
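
Restated as a standalone function, the heuristic tries larger thread counts first, keeps roughly 128K ops per thread, and threads the first loop axis that divides evenly. This is only an illustrative sketch with plain integers (pick_threads is a hypothetical name; the real code goes through k.apply_opt and resolve for symbolic shapes):

from math import prod

def pick_threads(full_shape: tuple[int, ...], loop_axes: list[int], max_threads: int) -> tuple[int, int] | None:
  for threads in [32, 16, 12, 8, 6, 5, 4, 3, 2]:
    # skip if too many threads: aim for at least ~128K ops per thread and never exceed the core limit
    if threads > max_threads or prod(full_shape) // (128 << 10) < threads: continue
    for axis in loop_axes:
      if full_shape[axis] % threads == 0:
        return axis, threads  # the first evenly divisible loop axis gets Opt(OptOps.THREAD, axis, threads)
  return None

print(pick_threads((1024, 1024), loop_axes=[0, 1], max_threads=8))  # -> (0, 8): ~1M ops split 8 ways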


@@ -12,7 +12,7 @@ from tinygrad.renderer import Renderer
from tinygrad.schedule.rangeify import remove_tags
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler:
@@ -127,7 +127,7 @@ class Scheduler:
opt_to_at = {
OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
OptOps.GROUPTOP: AxisType.GROUP_REDUCE}
OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}
ret = None
if opt.op in opt_to_at:
@@ -149,11 +149,16 @@ class Scheduler:
if opt.op is OptOps.LOCAL:
check(not self.dont_use_locals, "can't use locals")
check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
if opt.op is OptOps.THREAD:
check(self.opts is not None and self.opts.has_threads, "target does not support threads")
check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads")
check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
check(rng in self._globalizable_rngs(), "can't apply range to this dim")
if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong?
check(not self.dont_use_locals, "can't use locals")
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op==OptOps.GROUPTOP)
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
elif opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
check(opt.axis is not None, "tensor core opts must have an axis")
@@ -166,6 +171,7 @@ class Scheduler:
elif opt.op is OptOps.PADTO:
check(rng.src[0].op is Ops.CONST, "only pad const axes")
check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")


@@ -22,6 +22,7 @@ actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))]
# covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)]
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
actions += [Opt(op=OptOps.THREAD, axis=axis, arg=amt) for amt in [2,3,4,5,8,12,16,24,32,64] for axis in range(3)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
def get_test_global_size(global_size, max_global_size, var_vals):


@@ -47,7 +47,8 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
src = renderer.render(uops)
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
local_size=[1,1,1] if renderer.has_local else None)
# **************** Runners ****************


@@ -142,6 +142,7 @@ CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), Co
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
RANGEIFY, FUSE_ATTENTION = ContextVar("RANGEIFY", 0), ContextVar("FUSE_ATTENTION", 0)
EMULATE = ContextVar("EMULATE", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, (os.cpu_count() or 1) // 4))
@dataclass(frozen=True)
class Metadata:
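
CPU_COUNT defaults to a quarter of the visible cores and is what the CI changes above pin to 2. It behaves like any other tinygrad ContextVar, so it can be set from the environment or overridden in a Context block:

from tinygrad.helpers import CPU_COUNT, Context

print(CPU_COUNT.value)    # max(1, os.cpu_count() // 4) unless the CPU_COUNT env var is set
with Context(CPU_COUNT=2):
  print(CPU_COUNT.value)  # 2, the value the CI jobs above use for reproducible threading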


@@ -112,6 +112,7 @@ class Renderer:
# TODO: make this generic with a list of supported types
supports_float4: bool = True
has_local: bool = True
has_threads: bool = False
has_shared: bool = True
# NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now


@@ -3,7 +3,7 @@ import os, math, sys
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, sint_to_uop
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@@ -192,9 +192,12 @@ class ClangRenderer(CStyleLanguage):
float4_style = ('{', '}')
gep_arr_threshold = 0
has_local = False
global_max = None
has_threads = True
global_max = (CPU_COUNT.value, 0, 0)
infinity = "__builtin_inff()"
nan = '__builtin_nanf("")'
code_for_workitem = {"g": lambda _: "core_id"}
extra_args = ['int core_id']
if AMX: tensor_cores = tc.amx
# language options
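
With these attributes the CPU renderer reports threading support, exposes the thread count through global_max, and appends a core_id parameter to every rendered kernel. A quick way to check (the tinygrad.renderer.cstyle import path is my assumption; the attribute values come from the diff above):

from tinygrad.renderer.cstyle import ClangRenderer  # import path assumed

r = ClangRenderer()
print(r.has_threads)  # True
print(r.global_max)   # (CPU_COUNT.value, 0, 0): only the x dimension carries threads
print(r.extra_args)   # ['int core_id'], appended to every kernel signature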


@@ -51,7 +51,7 @@ class CPUWorker(threading.Thread):
class CPUComputeQueue(HWQueue):
def _exec(self, tid, prg, bufs, *args):
prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:]))
prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:]), tid)
def _signal(self, tid, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
def _wait(self, tid, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
def _timestamp(self, tid, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
@@ -61,7 +61,7 @@ class CPUComputeQueue(HWQueue):
def memory_barrier(self): return self
def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size):
return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals)
return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0])
def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value)
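
Each exec command now records its launch width (threads=...) and the compiled function receives the worker's tid as a trailing argument, matching the int core_id parameter in the rendered C. Conceptually the dispatch looks like the sketch below; this is a simplified illustration, not the actual CPUWorker/HWQueue machinery:

import threading

def run_kernel_threaded(fxn, buf_ptrs, vals, n_threads):
  # the same compiled kernel body runs once per thread; the trailing argument is that thread's core_id
  workers = [threading.Thread(target=fxn, args=(*buf_ptrs, *vals, tid)) for tid in range(n_threads)]
  for w in workers: w.start()
  for w in workers: w.join()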


@@ -36,8 +36,10 @@ dsp_string = PatternMatcher([
class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
has_threads = False
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_typedef = "__attribute__((noinline)) void"
extra_args = []
pre_matcher = dsp_pm
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
string_rewrite = dsp_string+ClangRenderer.string_rewrite


@@ -14,6 +14,7 @@ if TYPE_CHECKING:
class AxisType(Enum):
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto()
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)