diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 616fababbe..c1f859aa3f 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -6,7 +6,7 @@ from tinygrad.renderer.cstyle import AMDHIPRenderer, create_non_native_float_pat from tinygrad.uop.decompositions import xexp2, xlog2 from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate -from tinygrad.helpers import prod, AMX +from tinygrad.helpers import prod, AMX, CPU_COUNT, getenv def ldt(dt:DType): if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>" @@ -199,7 +199,8 @@ class LLVMRenderer(Renderer): class CPULLVMRenderer(LLVMRenderer): device = "CPU" has_local = False - global_max: tuple[int, ...] | None = None + has_threads = bool(getenv("THREADS", 1)) + global_max = (CPU_COUNT.value, 0, 0) abi = 'win64cc' if sys.platform == 'win32' else None string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)]) def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))