Merge branch 'master' into nak

This commit is contained in:
Christopher Milan
2025-10-12 08:38:29 -07:00
21 changed files with 251 additions and 55 deletions

View File

@@ -310,9 +310,9 @@ jobs:
- name: Fuzz Test fast idiv
run: python test/external/fuzz_fast_idiv.py
- name: Fuzz Test shapetracker
run: |
python test/external/fuzz_shapetracker.py
python test/external/fuzz_shapetracker_math.py
run: CNT=50 python test/external/fuzz_shapetracker.py
- name: Fuzz Test shapetracker math
run: CNT=200 python test/external/fuzz_shapetracker_math.py
- name: Fuzz Test shape ops
run: python test/external/fuzz_shape_ops.py
@@ -377,7 +377,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=41 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2092 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot alt model correctness (float32)
run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot fastvits model correctness (float32)

View File

@@ -42,13 +42,13 @@ class ProcessReplayWarning(Warning): pass
# *** replay the function and convert return values to string
def replay_kernelize(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
def to_str(ret:UOp) -> str:
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.KERNEL]
return "\n".join([f"{len(asts)} kernels", *asts])
return to_str(new_sink), to_str(ret[big_sink]), (big_sink,)
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
# NOTE: this always uses the opts_to_apply path
@@ -65,7 +65,7 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts
ast_repr = codecs.decode(str(input_ast), "unicode_escape")
return to_str(p2), to_str(p), (ast_repr, renderer)
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_kernelize_map":replay_kernelize, "get_program":replay_get_program}
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_rangeify_map":replay_get_rangeify_map, "get_program":replay_get_program}
# *** run replayers on captured rows and print diffs

View File

@@ -8,9 +8,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import CUDARenderer
MOCKGPU = getenv("MOCKGPU")
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
@@ -314,7 +316,7 @@ class TestLinearizer(unittest.TestCase):
a.realize()
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
@unittest.skipIf(MOCKGPU and isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, CUDARenderer)), "PTX indexes differently. might be ok?")
def test_where_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink(((1, 2), None)).pad(((1, 2), None))

View File

@@ -30,6 +30,7 @@ class BaseTestViz(unittest.TestCase):
# clear the global context
for lst in [tracked_keys, tracked_ctxs, active_rewrites, _name_cnt]: lst.clear()
Buffer.profile_events.clear()
cpu_events.clear()
self.tms = TRACK_MATCH_STATS.value
self.profile = PROFILE.value
TRACK_MATCH_STATS.value = 2
@@ -462,5 +463,21 @@ class TestVizMemoryLayout(BaseTestViz):
self.assertEqual(ret["peak"], 2)
self.assertEqual(len(ret["events"]), 4)
def test_free_last(self):
bufs = []
for _ in range(3):
bufs.append(_alloc(1))
profile_marker("alloc")
device = bufs[0].device
while bufs:
b = bufs.pop()
del b
profile_marker("free")
profile = load_profile(cpu_events+Buffer.profile_events)
ret = profile["layout"][f"{device} Memory"]
self.assertEqual(ret["peak"], 3)
self.assertEqual(len(ret["events"]), 6)
self.assertEqual(len(profile["markers"]), 6)
if __name__ == "__main__":
unittest.main()

View File

@@ -42,7 +42,7 @@ class TestWinograd(unittest.TestCase):
out = Tensor.conv2d(x,w, padding=1)
out.mean().backward()
backward_schedule = Tensor.schedule(x.grad, w.grad)
self.assertEqual(len(backward_schedule), 4)
self.assertEqual(len(backward_schedule), 5)
def test_counters(self):
IC, OC, X, Y = 4,4,9,9

View File

@@ -178,7 +178,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
if k.opts.has_threads and k.opts.global_max is not None:
for threads in [32,16,12,8,6,5,4,3,2]:
# Skip is too many threads. Heuristic: use about 128K ops per thread
# Skip if too many threads. Heuristic: use about 128K ops per thread
if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
for axis in k.axes_of(AxisType.LOOP):
if k.full_shape[axis] % threads == 0:

View File

@@ -165,15 +165,22 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
if isinstance(e, RuntimeError): continue
raise
timed_lins.append((acted_lins[i], min(tms)))
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
if BEAM_DEBUG > 1:
print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops",
f"{time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run",
f" {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}")
elif DEBUG >= 2:
print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)}",
f" {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
# done
opts = sorted(timed_lins, key=lambda x: x[1])
exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
if not exiting: beam = opts[:amt]
elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
if DEBUG >= 2:
print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None),
f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
except KeyboardInterrupt as e:
if beam_pool is not None: beam_pool.terminate()
raise e

View File

@@ -23,7 +23,7 @@ class _Device:
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
base = (__package__ or __name__).split('.')[0] # tinygrad
x = ix.split(":")[0].lower()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \
@@ -39,7 +39,7 @@ class _Device:
@functools.cached_property
def DEFAULT(self) -> str:
dev = [dev] if (dev:=getenv("DEV", "").upper()) else []
from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1])
from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "TINYFS", "NPY"] and getenv(d) == 1])
assert len(from_env) < 2, f"multiple devices set in env: {from_env}"
if len(from_env) == 1: return from_env[0]
try:

View File

@@ -121,7 +121,7 @@ class BufferCopy(Runner):
getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
# fast(ish) path, uses readinto in diskbuffers
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
else:

View File

@@ -23,10 +23,13 @@ def argfix(*x):
if len(x) != 1: raise ValueError(f"bad arg {x}")
return tuple(x[0])
return x
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
# https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__))
def all_same(items:tuple[T, ...]|list[T]): return all(x == items[0] for x in items)
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color:str|None, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
def colored(st, color:str|None, background=False): # replace the termcolor library
colors = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
return f"\u001b[{10*background+60*(color.upper() == color)+30+colors.index(color.lower())}m{st}\u001b[0m" if color is not None else st
def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
@@ -150,7 +153,7 @@ CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), Co
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
FUSE_ATTENTION = ContextVar("FUSE_ATTENTION", 0)
EMULATE = ContextVar("EMULATE", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, (os.cpu_count() or 1) // (4 if ARCH_X86 else 2))) # take 1/2 of the cores, accounting HT
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(aff(0)) if (aff:=getattr(os, "sched_getaffinity", None)) else (os.cpu_count() or 1)))
CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 1)
VIZ = PROFILE = ContextVar("VIZ", 0)
SPEC = ContextVar("SPEC", 0)
@@ -218,11 +221,12 @@ class TracingKey:
class ProfileEvent: pass
@dataclass
class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
class ProfileRangeEvent(ProfileEvent):
device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
@dataclass(frozen=True)
class ProfilePointEvent(ProfileEvent): device:str; name:str; key:Any; arg:dict=field(default_factory=dict); \
ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
class ProfilePointEvent(ProfileEvent):
device:str; name:str; key:Any; arg:dict=field(default_factory=dict); ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
cpu_events:list[ProfileEvent] = []
@contextlib.contextmanager
@@ -281,7 +285,8 @@ def diskcache_put(table:str, key:dict|str|int, val:Any, prepickled=False):
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
_db_tables.add(table)
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)",
tuple(key.values()) + (val if prepickled else pickle.dumps(val),))
conn.commit()
cur.close()
return val

View File

@@ -108,7 +108,9 @@ class CStyleLanguage(Renderer):
extra_matcher = extra_pm
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
tmp = ""
if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs):
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
@@ -229,10 +231,12 @@ class ClangRenderer(CStyleLanguage):
# 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
# to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
# wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
out, dt1, dt2 = self.render_dtype(dtype_in.vec(N*N)), self.render_dtype(dtype_in.vec(N)), self.render_dtype(dtype_in.vec(M))
prefix += [f"""static {out} __{name}({dt1} data1, {dt2} data2, {out} data0){{
AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
AMX_SET(1);\n return data0;\n}}"""]
return prefix
def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""

View File

@@ -23,10 +23,12 @@ class CLCompiler(Compiler):
build_status: int = cl.clBuildProgram(program, 1, self.dev.device_id, None, cl.clBuildProgram.argtypes[4](), None)
if build_status != 0:
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, log_size := ctypes.c_size_t())
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) # noqa: E501
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG,
log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None)
raise CompileError(f"OpenCL Compile Error\n\n{mstr.value.decode()}")
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(ctypes.c_size_t), binary_sizes := (ctypes.c_size_t * 1)(), None))
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), (ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None)) # noqa: E501
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p),
(ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None))
check(cl.clReleaseProgram(program))
return bytes(binary)
@@ -97,16 +99,22 @@ class CLDevice(Compiled):
err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, num_devices := ctypes.c_uint32())
if err == 0 and num_devices.value != 0: break
if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices")
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) # noqa: E501
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(),
lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))
self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256,
buf:=ctypes.create_string_buffer(256), None), buf.value.decode())[1]
self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256,
buf:=ctypes.create_string_buffer(256), None), buf.value.decode())[1]
if DEBUG >= 1: print(f"CLDevice: opening {self.device_name} with version {self.driver_version}")
self.context = checked(cl.clCreateContext(None, 1, self.device_id, cl.clCreateContext.argtypes[3](), None, status := ctypes.c_int32()), status)
self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, status), status)
self.pending_copyin: list[memoryview] = []
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096, ctypes.byref(buf := ctypes.create_string_buffer(4096)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096,
ctypes.byref(buf := ctypes.create_string_buffer(4096)),
ctypes.byref(total := ctypes.c_size_t())),
ctypes.string_at(buf, size=total.value).decode())[1]
compilers = [(IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer,
functools.partial(CLCompiler, self, f"compile_cl_{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}"))]

View File

@@ -10,7 +10,9 @@ if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint:
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported
def check(status):
if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501
if status != 0:
error = ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()
raise RuntimeError(f"CUDA Error {status}, {error}")
def encode_args(args, vals) -> tuple[ctypes.Structure, ctypes.Array]:
c_args = init_c_struct_t(tuple([(f'f{i}', cuda.CUdeviceptr_v2) for i in range(len(args))] +

View File

@@ -0,0 +1,137 @@
import socket, uuid, json, asyncio, threading
from contextlib import asynccontextmanager
from tinygrad.device import Compiled, Allocator
from tinygrad.helpers import DEBUG, getenv
TINYFS_ENDPOINT = getenv("TINYFS_ENDPOINT", "localhost:6767")
CHUNK_SIZE = 2**20
class TinyFSDevice(Compiled):
def __init__(self, device:str):
self.op = device[len("tinyfs:"):].upper()
super().__init__(device, TinyFSAllocator(self), None, None, None)
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((TINYFS_ENDPOINT.rsplit(":", 1)[0], int(TINYFS_ENDPOINT.rsplit(":", 1)[1])))
self.sfile = self.sock.makefile("rwb")
# fetch node info
self.sfile.write(b"INFO\r\n")
self.sfile.flush()
info = self.sfile.readline()
self.node_info = json.loads(info)
if DEBUG >= 2: print(f"nodes: {self.node_info}")
# spawn thread for async copyout
self.start_event = threading.Event()
self.t = threading.Thread(target=self._start_thread, daemon=True)
self.t.start()
self.start_event.wait()
# connection pools
self.conn_pools: dict[str, asyncio.Queue] = {}
self.conn_pools_lock = asyncio.Lock()
def finalize(self):
self.sfile.close()
for pool in self.conn_pools.values():
while not pool.empty():
_, w = pool.get_nowait()
w.close()
asyncio.run_coroutine_threadsafe(w.wait_closed(), self.loop).result()
if hasattr(self, "loop"):
self.loop.call_soon_threadsafe(self.loop.stop)
self.t.join()
def _start_thread(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.start_event.set()
self.loop.run_forever()
self.loop.close()
@asynccontextmanager
async def connection(self, loc):
if loc not in self.conn_pools:
await self.conn_pools_lock.acquire()
if loc not in self.conn_pools:
self.conn_pools[loc] = asyncio.Queue(nw:=getenv("ASYNC_COPY_WORKERS", 4))
conn_tasks = [asyncio.open_connection(*self.node_info[loc][-1].rsplit(":", 1)) for _ in range(nw)]
connections = await asyncio.gather(*conn_tasks)
for reader, writer in connections: self.conn_pools[loc].put_nowait((reader, writer))
self.conn_pools_lock.release()
reader, writer = await self.conn_pools[loc].get()
try:
yield reader, writer
finally:
await self.conn_pools[loc].put((reader, writer))
class TinyFSBuffer:
def __init__(self, device:TinyFSDevice, size:int, offset=0, request_id=None, copyout_queue=None):
self.device, self.size, self.offset = device, size, offset
self.request_id: uuid.UUID|None = request_id
self.copyout_queue = copyout_queue or []
def __repr__(self): return f"<TinyFSBuffer size={self.size} offset={self.offset}>"
class TinyFSAllocator(Allocator[TinyFSDevice]):
def _alloc(self, size, options):
return TinyFSBuffer(self.dev, size)
def _copyin(self, dest:TinyFSBuffer, src:memoryview):
if DEBUG >= 2: print(f"Copying in {dest.size} bytes to TINYFS:{dest.device.op}")
self.dev.sfile.write(f"{dest.device.op}_IN {dest.size}\r\n".encode())
if dest.device.op == "STORE":
self.dev.sfile.flush()
dest.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
if DEBUG >= 2: print(f"Request ID: {dest.request_id}")
self.dev.sfile.write(src)
self.dev.sfile.flush()
if dest.device.op == "LOAD":
locs = self.dev.sfile.readline()
locs = json.loads(locs)
dest.copyout_queue = []
for i, loc in enumerate(locs):
dest.copyout_queue.append((i, loc, src[i*16:(i+1)*16]))
def _copyout(self, dest:memoryview, src:TinyFSBuffer):
if DEBUG >= 2: print(f"Copying out {src.size} bytes from TINYFS:{src.device.op}")
if src.device.op == "LOAD":
asyncio.run_coroutine_threadsafe(self._copyout_async(dest, src), src.device.loop).result()
else:
self.dev.sfile.write(f"{src.device.op}_OUT {src.size} {src.request_id}\r\n".encode())
self.dev.sfile.flush()
src.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
if DEBUG >= 2: print(f"Request ID: {src.request_id}")
self.dev.sfile.readinto(dest)
async def _copyout_async(self, dest:memoryview, src:TinyFSBuffer):
async def _worker(item):
i, loc, h = item
async with self.dev.connection(loc) as (reader, writer):
ptr = i * CHUNK_SIZE
size = min(len(dest[ptr:ptr+CHUNK_SIZE]), CHUNK_SIZE)
writer.write(f"CHUNK_OUT {size}\r\n".encode())
writer.write(h)
await writer.drain()
chunk = await reader.readexactly(size)
view = dest[ptr:ptr+len(chunk)]
view[:] = chunk
del view
workers = [asyncio.create_task(_worker(item)) for item in src.copyout_queue]
await asyncio.gather(*workers)
src.copyout_queue.clear()
def _offset(self, buf:TinyFSBuffer, size:int, offset:int):
return TinyFSBuffer(buf.device, size, offset, buf.request_id, buf.copyout_queue)

View File

@@ -60,7 +60,11 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
# -include hiprtc_runtime.h was removed
check(set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes -Xclang -aux-triple -Xclang x86_64-unknown-linux-gnu".encode())) # noqa: E501
options = [
"-O3", "-mcumode", "--hip-version=6.0.32830", "-DHIP_VERSION_MAJOR=6", "-DHIP_VERSION_MINOR=0", "-DHIP_VERSION_PATCH=32830",
"-D__HIPCC_RTC__", "-std=c++14", "-nogpuinc", "-Wno-gnu-line-marker", "-Wno-missing-prototypes", f"--offload-arch={arch}",
"-I/opt/rocm/include", "-Xclang -disable-llvm-passes", "-Xclang -aux-triple", "-Xclang x86_64-unknown-linux-gnu"]
check(set_options(action_info, ' '.join(options).encode()))
status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
if status != 0:
print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())

View File

@@ -22,10 +22,12 @@ def jitlink_check(status, ctx=None):
def pretty_ptx(s):
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])',
lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])',
lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
return s

View File

@@ -3,7 +3,7 @@ import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
from tinygrad.uop.symbolic import sym, symbolic
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid
from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
@@ -112,8 +112,8 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
case Ops.PAD:
# TODO: why is multiple graph_rewrites faster than one here?
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), sym, name="pad")
for r,sh,(s,e) in zip(rngs, in_shape, arg))
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()),
symbolic+pm_simplify_valid, name="pad") for r,sh,(s,e) in zip(rngs, in_shape, arg))
case Ops.RESHAPE:
acc = 1
axes_in:list[UOp] = []
@@ -126,7 +126,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
axes_out.append(combined_axes % s)
combined_axes //= s
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape").src
case _: raise RuntimeError(f"{op} is not a MovementOp")
return rngs

View File

@@ -207,7 +207,7 @@ pm_cleanups = pm_mops+PatternMatcher([
])
def late_buffer_view(t:UOp, b:UOp):
if isinstance(b.device, str) and b.device.startswith("DISK"):
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
rngs = b.src[1:]
size = prod(shape := [int(r.vmax+1) for r in rngs])

View File

@@ -2484,17 +2484,20 @@ class Tensor(MathTrait):
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
padding_ = self._resolve_pool_pads(padding, len(HW))
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
# conv2d is a pooling op (with padding)
x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
# normal conv
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW)\
.permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx) # noqa: E501
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))\
.sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
@@ -2505,7 +2508,8 @@ class Tensor(MathTrait):
# TODO: stride == dilation
# use padding to round up to 4x4 output tiles
# (bs, cin_, tyx, HWI)
d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
pads = [[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])]
d = self.pad(sum(pads, []))._pool(HWI, HWO)
# move HW to the front: # (HWI, bs, cin_, tyx)
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
tyx = d.shape[-len(HWI):] # dim of tiling

View File

@@ -177,8 +177,8 @@ def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
x,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c)
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c)
return (y2-y1)*(v-v.vmin) + y1
return None
@@ -437,7 +437,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
# try all the valids together (but only the whole expressions)
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
uop = s_uop.simplify(tracked=True).substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
# put the loads back in
uop = uop.substitute({v:k for k,v in load_subs.items()})
return uop
@@ -470,13 +470,16 @@ def reduce_mul_chain(r:UOp):
if len(outside) == 0: return None
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)
# this is symbolic 2.0
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+PatternMatcher([
pm_simplify_valid = PatternMatcher([
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
(UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)),
])
# this is symbolic 2.0
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),

View File

@@ -257,8 +257,8 @@ async function renderProfiler() {
x += 1; y += nbytes; valueMap.set(ts, y);
} else {
const free = buf_shapes.get(key);
timestamps.push(ts);
x += 1; y -= free.nbytes; valueMap.set(ts, y);
timestamps.push(ts); valueMap.set(ts, y);
x += 1; y -= free.nbytes;
free.x.push(x);
free.y.push(free.y.at(-1));
temp.delete(key);
@@ -401,6 +401,7 @@ async function renderProfiler() {
}
}
// draw markers
ctx.textBaseline = "top";
for (const m of markers) {
const x = xscale(m.ts);
drawLine(ctx, [x, x], [0, canvas.clientHeight], { color:m.color });