mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into nak
This commit is contained in:
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
@@ -310,9 +310,9 @@ jobs:
|
||||
- name: Fuzz Test fast idiv
|
||||
run: python test/external/fuzz_fast_idiv.py
|
||||
- name: Fuzz Test shapetracker
|
||||
run: |
|
||||
python test/external/fuzz_shapetracker.py
|
||||
python test/external/fuzz_shapetracker_math.py
|
||||
run: CNT=50 python test/external/fuzz_shapetracker.py
|
||||
- name: Fuzz Test shapetracker math
|
||||
run: CNT=200 python test/external/fuzz_shapetracker_math.py
|
||||
- name: Fuzz Test shape ops
|
||||
run: python test/external/fuzz_shape_ops.py
|
||||
|
||||
@@ -377,7 +377,7 @@ jobs:
|
||||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=41 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2092 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot alt model correctness (float32)
|
||||
run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot fastvits model correctness (float32)
|
||||
|
||||
@@ -42,13 +42,13 @@ class ProcessReplayWarning(Warning): pass
|
||||
|
||||
# *** replay the function and convert return values to string
|
||||
|
||||
def replay_kernelize(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
|
||||
def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
|
||||
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
|
||||
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
|
||||
def to_str(ret:UOp) -> str:
|
||||
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.KERNEL]
|
||||
return "\n".join([f"{len(asts)} kernels", *asts])
|
||||
return to_str(new_sink), to_str(ret[big_sink]), (big_sink,)
|
||||
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
|
||||
|
||||
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
|
||||
# NOTE: this always uses the opts_to_apply path
|
||||
@@ -65,7 +65,7 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts
|
||||
ast_repr = codecs.decode(str(input_ast), "unicode_escape")
|
||||
return to_str(p2), to_str(p), (ast_repr, renderer)
|
||||
|
||||
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_kernelize_map":replay_kernelize, "get_program":replay_get_program}
|
||||
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_rangeify_map":replay_get_rangeify_map, "get_program":replay_get_program}
|
||||
|
||||
# *** run replayers on captured rows and print diffs
|
||||
|
||||
|
||||
@@ -8,9 +8,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
|
||||
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
|
||||
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv
|
||||
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
MOCKGPU = getenv("MOCKGPU")
|
||||
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
def test_arg_dedup(self):
|
||||
@@ -314,7 +316,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a.realize()
|
||||
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
|
||||
@unittest.skipIf(MOCKGPU and isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, CUDARenderer)), "PTX indexes differently. might be ok?")
|
||||
def test_where_fold(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
|
||||
|
||||
@@ -30,6 +30,7 @@ class BaseTestViz(unittest.TestCase):
|
||||
# clear the global context
|
||||
for lst in [tracked_keys, tracked_ctxs, active_rewrites, _name_cnt]: lst.clear()
|
||||
Buffer.profile_events.clear()
|
||||
cpu_events.clear()
|
||||
self.tms = TRACK_MATCH_STATS.value
|
||||
self.profile = PROFILE.value
|
||||
TRACK_MATCH_STATS.value = 2
|
||||
@@ -462,5 +463,21 @@ class TestVizMemoryLayout(BaseTestViz):
|
||||
self.assertEqual(ret["peak"], 2)
|
||||
self.assertEqual(len(ret["events"]), 4)
|
||||
|
||||
def test_free_last(self):
|
||||
bufs = []
|
||||
for _ in range(3):
|
||||
bufs.append(_alloc(1))
|
||||
profile_marker("alloc")
|
||||
device = bufs[0].device
|
||||
while bufs:
|
||||
b = bufs.pop()
|
||||
del b
|
||||
profile_marker("free")
|
||||
profile = load_profile(cpu_events+Buffer.profile_events)
|
||||
ret = profile["layout"][f"{device} Memory"]
|
||||
self.assertEqual(ret["peak"], 3)
|
||||
self.assertEqual(len(ret["events"]), 6)
|
||||
self.assertEqual(len(profile["markers"]), 6)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestWinograd(unittest.TestCase):
|
||||
out = Tensor.conv2d(x,w, padding=1)
|
||||
out.mean().backward()
|
||||
backward_schedule = Tensor.schedule(x.grad, w.grad)
|
||||
self.assertEqual(len(backward_schedule), 4)
|
||||
self.assertEqual(len(backward_schedule), 5)
|
||||
|
||||
def test_counters(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
|
||||
@@ -178,7 +178,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
||||
|
||||
if k.opts.has_threads and k.opts.global_max is not None:
|
||||
for threads in [32,16,12,8,6,5,4,3,2]:
|
||||
# Skip is too many threads. Heuristic: use about 128K ops per thread
|
||||
# Skip if too many threads. Heuristic: use about 128K ops per thread
|
||||
if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue
|
||||
for axis in k.axes_of(AxisType.LOOP):
|
||||
if k.full_shape[axis] % threads == 0:
|
||||
|
||||
@@ -165,15 +165,22 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
|
||||
if isinstance(e, RuntimeError): continue
|
||||
raise
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
if BEAM_DEBUG > 1:
|
||||
print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops",
|
||||
f"{time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run",
|
||||
f" {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}")
|
||||
elif DEBUG >= 2:
|
||||
print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)}",
|
||||
f" {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
|
||||
|
||||
# done
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
|
||||
if not exiting: beam = opts[:amt]
|
||||
elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
|
||||
if DEBUG >= 2:
|
||||
print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None),
|
||||
f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
|
||||
except KeyboardInterrupt as e:
|
||||
if beam_pool is not None: beam_pool.terminate()
|
||||
raise e
|
||||
|
||||
@@ -23,7 +23,7 @@ class _Device:
|
||||
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
|
||||
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
||||
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
|
||||
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
|
||||
base = (__package__ or __name__).split('.')[0] # tinygrad
|
||||
x = ix.split(":")[0].lower()
|
||||
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \
|
||||
@@ -39,7 +39,7 @@ class _Device:
|
||||
@functools.cached_property
|
||||
def DEFAULT(self) -> str:
|
||||
dev = [dev] if (dev:=getenv("DEV", "").upper()) else []
|
||||
from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1])
|
||||
from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "TINYFS", "NPY"] and getenv(d) == 1])
|
||||
assert len(from_env) < 2, f"multiple devices set in env: {from_env}"
|
||||
if len(from_env) == 1: return from_env[0]
|
||||
try:
|
||||
|
||||
@@ -121,7 +121,7 @@ class BufferCopy(Runner):
|
||||
getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
|
||||
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
|
||||
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
|
||||
elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
|
||||
elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
|
||||
# fast(ish) path, uses readinto in diskbuffers
|
||||
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
|
||||
else:
|
||||
|
||||
@@ -23,10 +23,13 @@ def argfix(*x):
|
||||
if len(x) != 1: raise ValueError(f"bad arg {x}")
|
||||
return tuple(x[0])
|
||||
return x
|
||||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
# https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__))
|
||||
def all_same(items:tuple[T, ...]|list[T]): return all(x == items[0] for x in items)
|
||||
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
|
||||
def colored(st, color:str|None, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
|
||||
def colored(st, color:str|None, background=False): # replace the termcolor library
|
||||
colors = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
|
||||
return f"\u001b[{10*background+60*(color.upper() == color)+30+colors.index(color.lower())}m{st}\u001b[0m" if color is not None else st
|
||||
def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
|
||||
def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
|
||||
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
|
||||
@@ -150,7 +153,7 @@ CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), Co
|
||||
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
|
||||
FUSE_ATTENTION = ContextVar("FUSE_ATTENTION", 0)
|
||||
EMULATE = ContextVar("EMULATE", "")
|
||||
CPU_COUNT = ContextVar("CPU_COUNT", max(1, (os.cpu_count() or 1) // (4 if ARCH_X86 else 2))) # take 1/2 of the cores, accounting HT
|
||||
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(aff(0)) if (aff:=getattr(os, "sched_getaffinity", None)) else (os.cpu_count() or 1)))
|
||||
CPU_LLVM, CPU_LVP, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("AMD_LLVM", 1)
|
||||
VIZ = PROFILE = ContextVar("VIZ", 0)
|
||||
SPEC = ContextVar("SPEC", 0)
|
||||
@@ -218,11 +221,12 @@ class TracingKey:
|
||||
class ProfileEvent: pass
|
||||
|
||||
@dataclass
|
||||
class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
|
||||
class ProfileRangeEvent(ProfileEvent):
|
||||
device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfilePointEvent(ProfileEvent): device:str; name:str; key:Any; arg:dict=field(default_factory=dict); \
|
||||
ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
|
||||
class ProfilePointEvent(ProfileEvent):
|
||||
device:str; name:str; key:Any; arg:dict=field(default_factory=dict); ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702
|
||||
|
||||
cpu_events:list[ProfileEvent] = []
|
||||
@contextlib.contextmanager
|
||||
@@ -281,7 +285,8 @@ def diskcache_put(table:str, key:dict|str|int, val:Any, prepickled=False):
|
||||
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
|
||||
cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
|
||||
_db_tables.add(table)
|
||||
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
|
||||
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)",
|
||||
tuple(key.values()) + (val if prepickled else pickle.dumps(val),))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return val
|
||||
|
||||
@@ -108,7 +108,9 @@ class CStyleLanguage(Renderer):
|
||||
extra_matcher = extra_pm
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||
tmp = ""
|
||||
if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs):
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
|
||||
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
@@ -229,10 +231,12 @@ class ClangRenderer(CStyleLanguage):
|
||||
# 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
|
||||
# to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
|
||||
# wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
|
||||
prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
|
||||
out, dt1, dt2 = self.render_dtype(dtype_in.vec(N*N)), self.render_dtype(dtype_in.vec(N)), self.render_dtype(dtype_in.vec(M))
|
||||
prefix += [f"""static {out} __{name}({dt1} data1, {dt2} data2, {out} data0){{
|
||||
AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
|
||||
AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
|
||||
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
|
||||
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
|
||||
AMX_SET(1);\n return data0;\n}}"""]
|
||||
return prefix
|
||||
def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""
|
||||
|
||||
@@ -23,10 +23,12 @@ class CLCompiler(Compiler):
|
||||
build_status: int = cl.clBuildProgram(program, 1, self.dev.device_id, None, cl.clBuildProgram.argtypes[4](), None)
|
||||
if build_status != 0:
|
||||
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, log_size := ctypes.c_size_t())
|
||||
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) # noqa: E501
|
||||
cl.clGetProgramBuildInfo(program, self.dev.device_id, cl.CL_PROGRAM_BUILD_LOG,
|
||||
log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None)
|
||||
raise CompileError(f"OpenCL Compile Error\n\n{mstr.value.decode()}")
|
||||
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(ctypes.c_size_t), binary_sizes := (ctypes.c_size_t * 1)(), None))
|
||||
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), (ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None)) # noqa: E501
|
||||
check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p),
|
||||
(ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None))
|
||||
check(cl.clReleaseProgram(program))
|
||||
return bytes(binary)
|
||||
|
||||
@@ -97,16 +99,22 @@ class CLDevice(Compiled):
|
||||
err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, num_devices := ctypes.c_uint32())
|
||||
if err == 0 and num_devices.value != 0: break
|
||||
if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices")
|
||||
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) # noqa: E501
|
||||
CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(),
|
||||
lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))
|
||||
|
||||
self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
|
||||
self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
|
||||
self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
|
||||
self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256,
|
||||
buf:=ctypes.create_string_buffer(256), None), buf.value.decode())[1]
|
||||
self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256,
|
||||
buf:=ctypes.create_string_buffer(256), None), buf.value.decode())[1]
|
||||
if DEBUG >= 1: print(f"CLDevice: opening {self.device_name} with version {self.driver_version}")
|
||||
self.context = checked(cl.clCreateContext(None, 1, self.device_id, cl.clCreateContext.argtypes[3](), None, status := ctypes.c_int32()), status)
|
||||
self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, status), status)
|
||||
self.pending_copyin: list[memoryview] = []
|
||||
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096, ctypes.byref(buf := ctypes.create_string_buffer(4096)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501
|
||||
self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096,
|
||||
ctypes.byref(buf := ctypes.create_string_buffer(4096)),
|
||||
ctypes.byref(total := ctypes.c_size_t())),
|
||||
ctypes.string_at(buf, size=total.value).decode())[1]
|
||||
|
||||
compilers = [(IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer,
|
||||
functools.partial(CLCompiler, self, f"compile_cl_{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}"))]
|
||||
|
||||
@@ -10,7 +10,9 @@ if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint:
|
||||
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported
|
||||
|
||||
def check(status):
|
||||
if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501
|
||||
if status != 0:
|
||||
error = ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()
|
||||
raise RuntimeError(f"CUDA Error {status}, {error}")
|
||||
|
||||
def encode_args(args, vals) -> tuple[ctypes.Structure, ctypes.Array]:
|
||||
c_args = init_c_struct_t(tuple([(f'f{i}', cuda.CUdeviceptr_v2) for i in range(len(args))] +
|
||||
|
||||
137
tinygrad/runtime/ops_tinyfs.py
Normal file
137
tinygrad/runtime/ops_tinyfs.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import socket, uuid, json, asyncio, threading
|
||||
from contextlib import asynccontextmanager
|
||||
from tinygrad.device import Compiled, Allocator
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
|
||||
# "host:port" of the TinyFS server, looked up through getenv with a local default
TINYFS_ENDPOINT = getenv("TINYFS_ENDPOINT", "localhost:6767")
# chunk granularity (1 MiB) used to address data when copying out of a LOAD buffer
CHUNK_SIZE = 1 << 20
|
||||
|
||||
class TinyFSDevice(Compiled):
  """Device that talks to a TinyFS server over TCP.

  The text after the ``tinyfs:`` prefix of the device string is upper-cased into ``self.op``; the allocator uses it
  to build its ``{op}_IN`` / ``{op}_OUT`` commands and compares it against ``"STORE"`` and ``"LOAD"``.

  Two channels are kept:
    * ``self.sfile``: a blocking, buffered control connection to ``TINYFS_ENDPOINT``.
    * ``self.conn_pools``: per-location pools of asyncio stream connections, driven by a private event loop
      (``self.loop``) that runs in the daemon thread ``self.t``.
  """
  def __init__(self, device:str):
    self.op = device[len("tinyfs:"):].upper()
    super().__init__(device, TinyFSAllocator(self), None, None, None)

    # control connection; the endpoint is split on its last ':' into host and port
    endpoint = TINYFS_ENDPOINT.rsplit(":", 1)
    self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.sock.connect((endpoint[0], int(endpoint[1])))
    self.sfile = self.sock.makefile("rwb")

    # fetch node info: the server answers INFO with a single JSON line
    self.sfile.write(b"INFO\r\n")
    self.sfile.flush()
    self.node_info = json.loads(self.sfile.readline())
    if DEBUG >= 2: print(f"nodes: {self.node_info}")

    # spawn thread for async copyout and wait until its event loop object exists
    self.start_event = threading.Event()
    self.t = threading.Thread(target=self._start_thread, daemon=True)
    self.t.start()
    self.start_event.wait()

    # connection pools, keyed by location; only touched from coroutines running on self.loop
    self.conn_pools: dict[str, asyncio.Queue] = {}
    self.conn_pools_lock = asyncio.Lock()

  def finalize(self):
    """Close the control connection and every pooled connection, then stop the event loop thread."""
    self.sfile.close()
    # closing the makefile() wrapper does not close the underlying socket, so close it explicitly
    self.sock.close()

    if hasattr(self, "loop"):
      # asyncio streams are not thread-safe: close them from inside the loop thread, not from here
      asyncio.run_coroutine_threadsafe(self._close_pools(), self.loop).result()
      self.loop.call_soon_threadsafe(self.loop.stop)
      self.t.join()

  async def _close_pools(self):
    """Drain every connection pool and close the writers (must run on ``self.loop``)."""
    for pool in self.conn_pools.values():
      while not pool.empty():
        _, w = pool.get_nowait()
        w.close()
        await w.wait_closed()

  def _start_thread(self):
    """Thread target: create the event loop, signal readiness, and run it until ``finalize`` stops it."""
    self.loop = asyncio.new_event_loop()
    asyncio.set_event_loop(self.loop)

    self.start_event.set()
    self.loop.run_forever()
    self.loop.close()

  @asynccontextmanager
  async def connection(self, loc):
    """Borrow a ``(reader, writer)`` pair for location ``loc``; it is returned to the pool on exit.

    The pool for a location is created on first use with ``ASYNC_COPY_WORKERS`` (default 4) connections to the
    ``host:port`` string stored last in ``self.node_info[loc]``.
    """
    if loc not in self.conn_pools:
      # `async with` guarantees the lock is released even when opening a connection fails
      async with self.conn_pools_lock:
        # re-check under the lock: another task may have built the pool while this one was waiting
        if loc not in self.conn_pools:
          pool: asyncio.Queue = asyncio.Queue(nw:=getenv("ASYNC_COPY_WORKERS", 4))
          conn_tasks = [asyncio.open_connection(*self.node_info[loc][-1].rsplit(":", 1)) for _ in range(nw)]
          for reader, writer in await asyncio.gather(*conn_tasks): pool.put_nowait((reader, writer))
          # publish only a fully populated pool, so a failure above cannot leave an empty pool that blocks callers forever
          self.conn_pools[loc] = pool

    reader, writer = await self.conn_pools[loc].get()
    try:
      yield reader, writer
    finally:
      await self.conn_pools[loc].put((reader, writer))
|
||||
|
||||
class TinyFSBuffer:
  """Handle for a (possibly offset) range of bytes on a TinyFS device.

  ``request_id`` is filled in by the allocator once the server has assigned one, and ``copyout_queue`` holds the
  per-chunk work items the allocator prepares for a later copy-out.
  """
  def __init__(self, device:TinyFSDevice, size:int, offset=0, request_id=None, copyout_queue=None):
    self.device = device
    self.size = size
    self.offset = offset
    self.request_id: uuid.UUID|None = request_id
    # any falsy queue (None or an empty list) is replaced by a fresh list
    self.copyout_queue = copyout_queue if copyout_queue else []

  def __repr__(self):
    return f"<TinyFSBuffer size={self.size} offset={self.offset}>"
|
||||
|
||||
class TinyFSAllocator(Allocator[TinyFSDevice]):
  """Allocator for TinyFS devices.

  Buffers are lightweight ``TinyFSBuffer`` handles. Data moves over the device's control connection
  (``self.dev.sfile``), except copy-out on a ``LOAD`` device, which fetches chunks concurrently on the device's
  asyncio loop.
  """
  def _alloc(self, size, options):
    # only a local handle is created here; `options` is not used
    return TinyFSBuffer(self.dev, size)

  def _copyin(self, dest:TinyFSBuffer, src:memoryview):
    """Send ``src`` to the server with an ``{op}_IN`` command and record the server's reply on ``dest``."""
    if DEBUG >= 2: print(f"Copying in {dest.size} bytes to TINYFS:{dest.device.op}")
    self.dev.sfile.write(f"{dest.device.op}_IN {dest.size}\r\n".encode())

    if dest.device.op == "STORE":
      # STORE: the server answers the header with a 16-byte request id before the payload is sent
      self.dev.sfile.flush()
      dest.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
      if DEBUG >= 2: print(f"Request ID: {dest.request_id}")

    self.dev.sfile.write(src)
    self.dev.sfile.flush()

    if dest.device.op == "LOAD":
      # LOAD: the reply is one JSON line listing a location per chunk
      locs = json.loads(self.dev.sfile.readline())
      # src holds one 16-byte entry per chunk (presumably the chunk's hash, confirm against the server protocol).
      # The entries are copied to bytes because they are consumed later, in _copyout_async, when the caller's
      # memoryview may no longer be valid.
      dest.copyout_queue = [(i, loc, bytes(src[i*16:(i+1)*16])) for i, loc in enumerate(locs)]

  def _copyout(self, dest:memoryview, src:TinyFSBuffer):
    """Fill ``dest`` with the data behind ``src``.

    Raises:
      RuntimeError: if the control connection ends before ``dest`` has been filled (non-LOAD path).
    """
    if DEBUG >= 2: print(f"Copying out {src.size} bytes from TINYFS:{src.device.op}")
    if src.device.op == "LOAD":
      # run the chunked fetch on the device's loop thread and block until it is done
      asyncio.run_coroutine_threadsafe(self._copyout_async(dest, src), src.device.loop).result()
    else:
      self.dev.sfile.write(f"{src.device.op}_OUT {src.size} {src.request_id}\r\n".encode())
      self.dev.sfile.flush()
      src.request_id = uuid.UUID(bytes=self.dev.sfile.read(16))
      if DEBUG >= 2: print(f"Request ID: {src.request_id}")
      # readinto stops early only at end of stream; fail loudly instead of returning a partially filled buffer
      nread = self.dev.sfile.readinto(dest)
      if nread != dest.nbytes: raise RuntimeError(f"TINYFS short read: expected {dest.nbytes} bytes, got {nread}")

  async def _copyout_async(self, dest:memoryview, src:TinyFSBuffer):
    """Fetch every queued chunk of ``src`` concurrently into ``dest``, then empty the queue."""
    async def _worker(item):
      i, loc, h = item
      async with self.dev.connection(loc) as (reader, writer):
        ptr = i * CHUNK_SIZE
        # the slice is clamped to dest, so the final chunk may be shorter than CHUNK_SIZE
        size = len(dest[ptr:ptr+CHUNK_SIZE])

        writer.write(f"CHUNK_OUT {size}\r\n".encode())
        writer.write(h)
        await writer.drain()

        # readexactly raises on a truncated stream, so exactly `size` bytes are written
        dest[ptr:ptr+size] = await reader.readexactly(size)

    workers = [asyncio.create_task(_worker(item)) for item in src.copyout_queue]
    await asyncio.gather(*workers)
    src.copyout_queue.clear()

  def _offset(self, buf:TinyFSBuffer, size:int, offset:int):
    # the view carries the parent's request id and copy-out queue along
    return TinyFSBuffer(buf.device, size, offset, buf.request_id, buf.copyout_queue)
|
||||
@@ -60,7 +60,11 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
|
||||
check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
|
||||
check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
|
||||
# -include hiprtc_runtime.h was removed
|
||||
check(set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes -Xclang -aux-triple -Xclang x86_64-unknown-linux-gnu".encode())) # noqa: E501
|
||||
options = [
|
||||
"-O3", "-mcumode", "--hip-version=6.0.32830", "-DHIP_VERSION_MAJOR=6", "-DHIP_VERSION_MINOR=0", "-DHIP_VERSION_PATCH=32830",
|
||||
"-D__HIPCC_RTC__", "-std=c++14", "-nogpuinc", "-Wno-gnu-line-marker", "-Wno-missing-prototypes", f"--offload-arch={arch}",
|
||||
"-I/opt/rocm/include", "-Xclang -disable-llvm-passes", "-Xclang -aux-triple", "-Xclang x86_64-unknown-linux-gnu"]
|
||||
check(set_options(action_info, ' '.join(options).encode()))
|
||||
status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
|
||||
if status != 0:
|
||||
print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
|
||||
|
||||
@@ -22,10 +22,12 @@ def jitlink_check(status, ctx=None):
|
||||
|
||||
def pretty_ptx(s):
|
||||
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
|
||||
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
|
||||
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])',
|
||||
lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
|
||||
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
|
||||
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
|
||||
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
|
||||
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])',
|
||||
lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
|
||||
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
|
||||
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
|
||||
return s
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, operator, itertools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
|
||||
from tinygrad.uop.symbolic import sym, symbolic
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
@@ -112,8 +112,8 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
||||
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
|
||||
case Ops.PAD:
|
||||
# TODO: why are multiple graph_rewrites faster than one here?
|
||||
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), sym, name="pad")
|
||||
for r,sh,(s,e) in zip(rngs, in_shape, arg))
|
||||
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()),
|
||||
symbolic+pm_simplify_valid, name="pad") for r,sh,(s,e) in zip(rngs, in_shape, arg))
|
||||
case Ops.RESHAPE:
|
||||
acc = 1
|
||||
axes_in:list[UOp] = []
|
||||
@@ -126,7 +126,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
||||
axes_out.append(combined_axes % s)
|
||||
combined_axes //= s
|
||||
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
|
||||
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src
|
||||
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape").src
|
||||
case _: raise RuntimeError(f"{op} is not a MovementOp")
|
||||
return rngs
|
||||
|
||||
|
||||
@@ -207,7 +207,7 @@ pm_cleanups = pm_mops+PatternMatcher([
|
||||
])
|
||||
|
||||
def late_buffer_view(t:UOp, b:UOp):
|
||||
if isinstance(b.device, str) and b.device.startswith("DISK"):
|
||||
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
|
||||
rngs = b.src[1:]
|
||||
size = prod(shape := [int(r.vmax+1) for r in rngs])
|
||||
|
||||
|
||||
@@ -2484,17 +2484,20 @@ class Tensor(MathTrait):
|
||||
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype)
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
padding_ = self._resolve_pool_pads(padding, len(HW))
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\
|
||||
f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
|
||||
|
||||
# conv2d is a pooling op (with padding)
|
||||
x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
||||
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
|
||||
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
|
||||
# normal conv
|
||||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
|
||||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW)\
|
||||
.permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
|
||||
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx) # noqa: E501
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))\
|
||||
.sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
||||
@@ -2505,7 +2508,8 @@ class Tensor(MathTrait):
|
||||
# TODO: stride == dilation
|
||||
# use padding to round up to 4x4 output tiles
|
||||
# (bs, cin_, tyx, HWI)
|
||||
d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501
|
||||
pads = [[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])]
|
||||
d = self.pad(sum(pads, []))._pool(HWI, HWO)
|
||||
# move HW to the front: # (HWI, bs, cin_, tyx)
|
||||
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
|
||||
tyx = d.shape[-len(HWI):] # dim of tiling
|
||||
|
||||
@@ -177,8 +177,8 @@ def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
x,const = x.pop_const()
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
|
||||
if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
|
||||
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore
|
||||
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore
|
||||
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c)
|
||||
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c)
|
||||
return (y2-y1)*(v-v.vmin) + y1
|
||||
return None
|
||||
|
||||
@@ -437,7 +437,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
||||
|
||||
# try all the valids together (but only the whole expressions)
|
||||
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
|
||||
uop = s_uop.simplify(tracked=True).substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
|
||||
uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
|
||||
# put the loads back in
|
||||
uop = uop.substitute({v:k for k,v in load_subs.items()})
|
||||
return uop
|
||||
@@ -470,13 +470,16 @@ def reduce_mul_chain(r:UOp):
|
||||
if len(outside) == 0: return None
|
||||
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)
|
||||
|
||||
# this is symbolic 2.0
|
||||
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
|
||||
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
pm_simplify_valid = PatternMatcher([
|
||||
# simplify valid
|
||||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
(UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)),
|
||||
])
|
||||
|
||||
# this is symbolic 2.0
|
||||
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
|
||||
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
|
||||
sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
|
||||
# LOAD/STORE -> NOOP
|
||||
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||
|
||||
@@ -257,8 +257,8 @@ async function renderProfiler() {
|
||||
x += 1; y += nbytes; valueMap.set(ts, y);
|
||||
} else {
|
||||
const free = buf_shapes.get(key);
|
||||
timestamps.push(ts);
|
||||
x += 1; y -= free.nbytes; valueMap.set(ts, y);
|
||||
timestamps.push(ts); valueMap.set(ts, y);
|
||||
x += 1; y -= free.nbytes;
|
||||
free.x.push(x);
|
||||
free.y.push(free.y.at(-1));
|
||||
temp.delete(key);
|
||||
@@ -401,6 +401,7 @@ async function renderProfiler() {
|
||||
}
|
||||
}
|
||||
// draw markers
|
||||
ctx.textBaseline = "top";
|
||||
for (const m of markers) {
|
||||
const x = xscale(m.ts);
|
||||
drawLine(ctx, [x, x], [0, canvas.clientHeight], { color:m.color });
|
||||
|
||||
Reference in New Issue
Block a user