Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
cpu threading (#11951)
* start cpu threading
* fix
* fix2
* fix
* hacks?
* threads
* minor
* no dsp
* dsp 2
* n
* more
* test
* xm
* cleaner
* readable
* f
* reorder
* when no threads
* rangeify
* typos
* not needed
* reapply
* remove this
* linter
* fixed cpu count in ci
* fix
* fixes
* rm
* typo
* sort based on speed
* test if test works in ci
* Revert "test if test works in ci"
This reverts commit 1f05edb531.
* do not pad thread
.github/workflows/test.yml (vendored), 6 changed lines:

@@ -730,7 +730,7 @@ jobs:
     opencl: ${{ matrix.backend == 'gpu' && 'true' }}
     llvm: ${{ matrix.backend == 'llvm' && 'true' }}
 - name: Set env
-  run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'gpu' && 'GPU=1' }}" >> $GITHUB_ENV
+  run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'gpu' && 'GPU=1' }}" >> $GITHUB_ENV
 - name: Check Device.DEFAULT and print some source
   run: |
     python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CPU','GPU'], Device.DEFAULT"
@@ -954,7 +954,7 @@ jobs:
     pydeps: "capstone"
     llvm: ${{ matrix.backend == 'llvm' && 'true' }}
 - name: Set env
-  run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'metal' && 'METAL=1'}}" >> $GITHUB_ENV
+  run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'metal' && 'METAL=1'}}" >> $GITHUB_ENV
 - name: Check Device.DEFAULT and print some source
   run: |
     python -c "from tinygrad import Device; assert Device.DEFAULT == '${{ matrix.backend }}'.upper(), Device.DEFAULT"
@@ -991,7 +991,7 @@ jobs:
     pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
 - name: Set env
   shell: bash
-  run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV
+  run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV
 - name: Run unit tests
   if: matrix.backend=='llvm'
   # test_newton_schulz hits RecursionError
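Note: the CI change pins CPU_COUNT=2 for the cpu backend so the thread count does not depend on the runner's core count (the "fixed cpu count in ci" commit above). A minimal sketch of the same knob outside CI, assuming the CPU_COUNT ContextVar that a later hunk adds to tinygrad.helpers:

# Minimal sketch of pinning the worker count; set the env var before importing tinygrad
# so the ContextVar picks it up. CPU=1 in CI additionally selects the CPU backend.
import os
os.environ["CPU_COUNT"] = "2"          # what the CI step exports via $GITHUB_ENV

from tinygrad import Tensor
from tinygrad.helpers import CPU_COUNT

print(CPU_COUNT.value)                 # 2: the env var overrides the cpu_count()//4 default
out = (Tensor.rand(256, 256) + 1).sum().item()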
@@ -67,11 +67,12 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
   forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs))

   if not wasm:
+    thread_id = 0 # NOTE: export does not support threading, thread_id is always 0
     for name,cl in bufs_to_save.items():
       weight = ''.join(["\\x%02X"%x for x in bytes(to_mv(cl._buf.va_addr, cl._buf.size))])
       cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
     cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
-    cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
+    cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)}, {thread_id});" for (name, args, _global_size, _local_size) in statements] + ["}"]
     return '\n'.join(headers + cprog)
   else:
     if bufs_to_save:
@@ -239,7 +240,9 @@ export default {model_name};

 def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
   assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"only {', '.join(EXPORT_SUPPORTED_DEVICE)} are supported"
-  with Context(JIT=2): run,special_names = jit_model(model, *inputs)
+
+  # NOTE: CPU_COUNT=1, since export does not support threading
+  with Context(JIT=2, CPU_COUNT=1): run,special_names = jit_model(model, *inputs)
   functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
   state = get_state_dict(model)
   weight_names = {id(x.uop.base.realized): name for name, x in state.items()}
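Note: because every rendered kernel now takes a trailing thread id, the single-threaded C export passes a literal 0 and pins CPU_COUNT=1 around the JIT. A rough usage sketch under those assumptions; the module path, toy model, and return tuple are guesses, not shown in the diff:

# Hypothetical usage sketch; only export_model(model, "clang", *inputs) comes from the hunk above.
from tinygrad import Tensor
from extra.export_model import export_model   # assumed module path for the code shown above

class TinyModel:
  def __init__(self): self.w = Tensor.rand(16, 4)
  def __call__(self, x): return (x @ self.w).relu()

# assumes Device.DEFAULT is an EXPORT_SUPPORTED_DEVICE (e.g. CPU); the JIT runs with
# CPU_COUNT=1, so the emitted C calls every kernel with a constant thread_id of 0
prg, inp_sizes, out_sizes, state = export_model(TinyModel(), "clang", Tensor.rand(1, 16))
print(prg[:200])   # start of the generated C source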
@@ -336,5 +336,19 @@ class TestKernelOpts(unittest.TestCase):
       [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501
     ])

+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_threads, "test requires threads")
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.global_max is not None and
+                       Device[Device.DEFAULT].renderer.global_max[0] > 1, "test requires multicore")
+  def test_thread_opts(self):
+    a = Tensor.rand(4, 4, 4, 4)
+    b = Tensor.rand(4, 4, 4)
+    r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
+    helper_linearizer_opt(r, [
+      [Opt(OptOps.THREAD, 0, 2)],
+      [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.THREAD, 0, 2)],
+      [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.THREAD, 0, 2), Opt(OptOps.UNROLL, 0, 2)],
+    ] + [[Opt(OptOps.THREAD, 0, 4)] if Device[Device.DEFAULT].renderer.global_max[0] >= 4 else []]
+      + [[Opt(OptOps.THREAD, 0, 8)] if Device[Device.DEFAULT].renderer.global_max[0] >= 8 else []])
+
 if __name__ == '__main__':
   unittest.main()
@@ -59,7 +59,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
   all_ranges = {x.arg[0:-1]:x for x in s_topo if x.op is Ops.RANGE}

   # extract global/local dims
-  global_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] is AxisType.GLOBAL]))
+  global_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.GLOBAL, AxisType.THREAD)]))
   local_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)]))
   if not global_dims and not local_dims: return None

@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from tinygrad.uop.ops import AxisType

 class OptOps(Enum):
-  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value

@@ -16,10 +16,10 @@ class Opt:
   arg: int|tuple|None = None
   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"

-axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
+axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
                 AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
-axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow",
-               AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
+axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
+               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}

 class KernelOptError(Exception): pass
 def check(cond:bool, msg:str=""):
@@ -171,4 +171,16 @@ def hand_coded_optimizations(k:Scheduler) -> list[Opt]:
         k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
         if will_delete_shape: deleted_shape += 1

+  # **** threading ****
+
+  if k.opts.has_threads and k.opts.global_max is not None:
+    for threads in [32,16,12,8,6,5,4,3,2]:
+      # Skip is too many threads. Heuristic: use about 128K ops per thread
+      if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
+      for axis in k.axes_of(AxisType.LOOP):
+        if k.full_shape[axis] % threads == 0:
+          k.apply_opt(Opt(OptOps.THREAD, axis, threads))
+          break
+      if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break
+
   return k.applied_opts
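Note: to make the 128K-ops-per-thread heuristic concrete with made-up numbers: a kernel whose full_shape multiplies out to 1,048,576 elements allows at most 1,048,576 // 131,072 = 8 threads, so the loop above settles on the first candidate <= 8 that also divides a LOOP axis evenly.

# Illustrative arithmetic only; mirrors the heuristic above with a hypothetical shape.
from math import prod

full_shape = (64, 128, 128)                              # prod = 1_048_576
max_threads_by_work = prod(full_shape) // (128 << 10)    # 1_048_576 // 131_072 = 8
for threads in [32, 16, 12, 8, 6, 5, 4, 3, 2]:
    if threads > max_threads_by_work: continue           # too little work per thread
    if any(s % threads == 0 for s in full_shape):        # needs an axis it divides evenly
        print("would thread by", threads)                # -> 8
        break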
@@ -12,7 +12,7 @@ from tinygrad.renderer import Renderer
 from tinygrad.schedule.rangeify import remove_tags

 # NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
-axis_to_pos = {AxisType.LOOP: -1, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
+axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
                AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}

 class Scheduler:
@@ -127,7 +127,7 @@ class Scheduler:
     opt_to_at = {
       OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
       OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
-      OptOps.GROUPTOP: AxisType.GROUP_REDUCE}
+      OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}

     ret = None
     if opt.op in opt_to_at:
@@ -149,11 +149,16 @@
       if opt.op is OptOps.LOCAL:
         check(not self.dont_use_locals, "can't use locals")
         check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
+      if opt.op is OptOps.THREAD:
+        check(self.opts is not None and self.opts.has_threads, "target does not support threads")
+        check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads")
+        check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
+        check(rng in self._globalizable_rngs(), "can't apply range to this dim")
       if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
         check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong?
         check(not self.dont_use_locals, "can't use locals")
         check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
-      ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op==OptOps.GROUPTOP)
+      ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
     elif opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
       check(opt.axis is not None, "tensor core opts must have an axis")
@@ -166,6 +171,7 @@
     elif opt.op is OptOps.PADTO:
       check(rng.src[0].op is Ops.CONST, "only pad const axes")
       check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
+      check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
        check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
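Note: the new checks reject a THREAD opt on renderers without thread support, past global_max[0], on an already-threaded kernel, or on a non-globalizable range. A hedged sketch of how that surfaces through apply_opt; the import path is an assumption, while apply_opt, Opt, OptOps and KernelOptError all appear in the hunks above:

# Hedged sketch: the module path is assumed (it is not shown in the diff hunks above).
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError   # assumed import path

def try_thread(sched, axis, amt):
  # sched is a Scheduler for some kernel; check() raises KernelOptError on a bad opt
  try:
    sched.apply_opt(Opt(OptOps.THREAD, axis, amt))
    return True
  except KernelOptError as e:
    # e.g. "target does not support threads", "too many threads", "already threaded"
    print(f"THREAD({axis},{amt}) rejected: {e}")
    return False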
@@ -22,6 +22,7 @@ actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))]
 # covers resnet kernels (3 global * 3 reduce)
 actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)]
 actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
+actions += [Opt(op=OptOps.THREAD, axis=axis, arg=amt) for amt in [2,3,4,5,8,12,16,24,32,64] for axis in range(3)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

 def get_test_global_size(global_size, max_global_size, var_vals):
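Note: this adds 10 thread amounts x 3 axes = 30 THREAD candidates to the BEAM action space; ones that do not fit a given kernel are filtered at apply time by the checks above. A quick count:

# Just counting the new search actions added above.
new_thread_actions = [(axis, amt) for amt in [2,3,4,5,8,12,16,24,32,64] for axis in range(3)]
assert len(new_thread_actions) == 30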
@@ -47,7 +47,8 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
   src = renderer.render(uops)

   return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
-                     global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
+                     global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
+                     local_size=[1,1,1] if renderer.has_local else None)

 # **************** Runners ****************

@@ -142,6 +142,7 @@ CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), Co
 ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
 RANGEIFY, FUSE_ATTENTION = ContextVar("RANGEIFY", 0), ContextVar("FUSE_ATTENTION", 0)
 EMULATE = ContextVar("EMULATE", "")
+CPU_COUNT = ContextVar("CPU_COUNT", max(1, (os.cpu_count() or 1) // 4))

 @dataclass(frozen=True)
 class Metadata:
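Note: the default worker count is a quarter of the reported cores, floored at 1, unless CPU_COUNT is set in the environment; e.g. a 16-core machine defaults to 4 threads and a 2-core machine to 1. A tiny check of that expression with made-up core counts:

# Illustrative only: reproduces the default expression above with made-up core counts.
for cores in (None, 1, 2, 8, 16, 64):
    default = max(1, (cores or 1) // 4)
    print(cores, "->", default)   # None->1, 1->1, 2->1, 8->2, 16->4, 64->16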
@@ -112,6 +112,7 @@ class Renderer:
   # TODO: make this generic with a list of supported types
   supports_float4: bool = True
   has_local: bool = True
+  has_threads: bool = False
   has_shared: bool = True
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
   global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
@@ -3,7 +3,7 @@ import os, math, sys
 from collections import defaultdict, Counter
 from tinygrad.codegen.opt import tc
 from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, sint_to_uop
-from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
+from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
 from tinygrad.renderer import Renderer
 from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@@ -192,9 +192,12 @@ class ClangRenderer(CStyleLanguage):
   float4_style = ('{', '}')
   gep_arr_threshold = 0
   has_local = False
-  global_max = None
+  has_threads = True
+  global_max = (CPU_COUNT.value, 0, 0)
   infinity = "__builtin_inff()"
   nan = '__builtin_nanf("")'
+  code_for_workitem = {"g": lambda _: "core_id"}
+  extra_args = ['int core_id']
   if AMX: tensor_cores = tc.amx

   # language options
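Note: with has_threads, code_for_workitem and extra_args set, the Clang renderer is expected to emit kernels whose thread dimension becomes a trailing int core_id parameter rather than a loop over it. An invented example of what such a kernel could look like; only the extra argument and its use as the "g" work-item index come from the diff, while the name, body and shapes are made up:

# Invented example of what a threaded CPU kernel is expected to render to.
expected_src = """
void E_64_4(float* restrict data0, float* restrict data1, int core_id) {
  /* core_id plays the role of gidx0: each of the CPU_COUNT workers gets one slice */
  for (int ridx0 = 0; ridx0 < 4; ridx0++) {
    data0[core_id*4+ridx0] = data1[core_id*4+ridx0] + 1.0f;
  }
}
"""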
@@ -51,7 +51,7 @@ class CPUWorker(threading.Thread):

 class CPUComputeQueue(HWQueue):
   def _exec(self, tid, prg, bufs, *args):
-    prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:]))
+    prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:]), tid)
   def _signal(self, tid, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
   def _wait(self, tid, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
   def _timestamp(self, tid, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
@@ -61,7 +61,7 @@ class CPUComputeQueue(HWQueue):

   def memory_barrier(self): return self
   def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size):
-    return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals)
+    return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0])
   def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
   def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
   def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value)
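Note: exec now passes threads=(global_size or (1,))[0], and each worker runs _exec with its own tid, which the kernel receives as core_id. A stripped-down sketch of that fan-out pattern, independent of tinygrad's actual CPUWorker/HWQueue classes (only partially visible in this diff):

# Simplified fan-out sketch, not tinygrad's implementation: N workers call the same
# compiled kernel with the buffer pointers plus their own thread id as the last argument.
import threading

def run_threaded(fxn, buf_ptrs, vals, threads):
  def worker(tid): fxn(*buf_ptrs, *vals, tid)   # mirrors _exec(..., tid) above
  ts = [threading.Thread(target=worker, args=(tid,)) for tid in range(threads)]
  for t in ts: t.start()
  for t in ts: t.join()

# usage: run_threaded(compiled_kernel, [ptr0, ptr1], [], threads=4)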
@@ -36,8 +36,10 @@ dsp_string = PatternMatcher([
 class DSPRenderer(ClangRenderer):
   device = "DSP"
   supports_float4 = True
+  has_threads = False
   buffer_suffix = " restrict __attribute__((align_value(128)))"
   kernel_typedef = "__attribute__((noinline)) void"
+  extra_args = []
   pre_matcher = dsp_pm
   extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
   string_rewrite = dsp_string+ClangRenderer.string_rewrite
@@ -14,6 +14,7 @@ if TYPE_CHECKING:

 class AxisType(Enum):
   GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
+  THREAD = auto()

 # https://en.wikipedia.org/wiki/Identity_element
 def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)