Merge branch 'master' into rtoposort

This commit is contained in:
George Hotz
2025-10-08 15:14:43 +08:00
committed by GitHub
12 changed files with 1191 additions and 44 deletions

View File

@@ -50,7 +50,7 @@ def ioctls_from_header():
hdr = (pathlib.Path(__file__).parent / "kfd_ioctl.h").read_text().replace("\\\n", "")
pattern = r'#define\s+(AMDKFD_IOC_[A-Z0-9_]+)\s+AMDKFD_IOW?R?\((0x[0-9a-fA-F]+),\s+struct\s([A-Za-z0-9_]+)\)'
matches = re.findall(pattern, hdr, re.MULTILINE)
return {int(nr, 0x10):(name, getattr(kfd_ioctl, "struct_"+sname)) for name, nr, sname in matches}
return {int(nr, 0x10):(name, getattr(kfd_ioctl, "struct_"+sname, None)) for name, nr, sname in matches}
nrs = ioctls_from_header()
@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p)

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,7 @@ def ioctls_from_header():
pattern = r'#define\s+(AMDKFD_IOC_[A-Z0-9_]+)\s+AMDKFD_(IOW?R?)\((0x[0-9a-fA-F]+),\s+struct\s([A-Za-z0-9_]+)\)'
matches = re.findall(pattern, hdr, re.MULTILINE)
return type("KFD_IOCTLS", (object, ), {name: int(nr, 0x10) for name, _, nr, _ in matches}), \
{int(nr, 0x10): getattr(kfd, "struct_"+sname) for name, idir, nr, sname in matches}
{int(nr, 0x10): getattr(kfd, "struct_"+sname, None) for name, idir, nr, sname in matches}
kfd_ioctls, kfd_headers = ioctls_from_header()
class KFDFileDesc(VirtFileDesc):
@@ -115,6 +115,10 @@ class AMDDriver(VirtDriver):
struct = kfd_headers[nr].from_address(argp)
if nr == kfd_ioctls.AMDKFD_IOC_ACQUIRE_VM: pass
elif nr == kfd_ioctls.AMDKFD_IOC_RUNTIME_ENABLE: pass
elif nr == kfd_ioctls.AMDKFD_IOC_GET_VERSION:
struct.major_version = 1
struct.minor_version = 14
elif nr == kfd_ioctls.AMDKFD_IOC_ALLOC_MEMORY_OF_GPU:
if struct.gpu_id not in self.gpus: return -1
struct.handle = self._alloc_handle()

View File

@@ -113,7 +113,6 @@ class TestSchedule(unittest.TestCase):
self.assertListEqual(a.tolist(), [[15]])
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
@expect_rangeify_fails
def test_error_on_device_mismatch(self):
a = Tensor.empty(10)
b = Tensor.empty(10, device="CPU")
@@ -121,7 +120,6 @@ class TestSchedule(unittest.TestCase):
with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1)
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
@expect_rangeify_fails
def test_error_on_device_mismatch_alt(self):
a = Tensor.empty(10)
b = Tensor.empty((1,), device="CPU").expand(10).contiguous()
@@ -399,7 +397,6 @@ class TestSchedule(unittest.TestCase):
# a and b share the same underlying device memory
self.assertIs(a.uop.realized, b.uop.realized)
@expect_rangeify_fails
def test_clone_doesnt_dedup(self):
src = Tensor.ones(4).contiguous().realize()
a = src.clone()
@@ -407,7 +404,7 @@ class TestSchedule(unittest.TestCase):
sched = check_schedule([a, b], 2, filter_sink=False)
run_schedule(sched)
# a and b are assigned to the same device Buffer
self.assertIsNot(a.uop.realized, b.uop.realized)
self.assertIsNot(a.uop.base.realized, b.uop.base.realized)
# EMPTY is assigned to a unique device Buffer
@@ -2468,23 +2465,24 @@ class TestUOpBecome(unittest.TestCase):
self.assertEqual(add.uop.shape, (8, 2))
assert add.uop is not add.uop.base
@expect_rangeify_fails
def test_new_flat_buffer(self):
a = Tensor.empty(4,)
b = Tensor.empty(4,)
add = a+b
check_schedule(add, 1)
# BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER
assert UPat(Ops.BUFFER).match(add.uop, {})
assert UPat(Ops.BUFFER).match(add.uop.base, {})
# sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer
# NOTE: this expand is not reordered because there's before it to fuse
@expect_rangeify_fails
def test_reorder_expand(self):
a = Tensor.empty(4, 1)
b = a.expand(4, 4).reciprocal()
check_schedule(b, 1)
if RANGEIFY:
self.assertEqual(b.uop.base.buffer.size, 4)
self.assertEqual(b.uop.shape, (4, 4))
return
self.assertEqual(b.uop.base.buffer.size, 16)
self.assertEqual(b.uop.st, ShapeTracker.from_shape((4, 4)))
@@ -2501,7 +2499,6 @@ class TestUOpBecome(unittest.TestCase):
b = a*1
assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.uop, {}) # scheduling merges all MovementOps into a single VIEW
self.assertIs(a.uop.base.buffer, b.uop.base.buffer)
def test_become_buf_with_mops(self):
@@ -2523,17 +2520,6 @@ class TestUOpBecome(unittest.TestCase):
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER)
@expect_rangeify_fails
def test_become_const_in_view(self):
# if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged.
add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
b = add.shrink(((0, 1), (0, 0)))
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.uop, {})
self.assertEqual(b.shape, (1, 0))
# the base is untouched.
assert UPat(Ops.ADD).match(add.uop, {})
def test_become_const_from_const(self):
const_add = Tensor(1)+Tensor(2)
assert UPat(Ops.ADD).match(const_add.uop, {})
@@ -2585,14 +2571,17 @@ class TestUOpBecome(unittest.TestCase):
assert b.uop is c.uop
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.uop, {})
@expect_rangeify_fails
def test_setitem_becomes_subbuffer(self):
a = Tensor.full((4,), 2.).contiguous().realize()
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
b.realize()
assert a.uop.is_realized
assert a.uop.buffer._base is None
# b is a subbuffer of a
# b is a subbuffer of a (buffer_view in non rangeify, rangeify just makes a shrink)
if RANGEIFY:
assert b.uop.op_in_backward_slice_with_self(Ops.SHRINK)
assert b.uop.base is a.uop.base
return
assert b.uop.op is Ops.BUFFER_VIEW
assert b.uop.src[0] is a.uop

View File

@@ -544,6 +544,11 @@ class TestUopsObject(unittest.TestCase):
with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
assert len(ret) == 10000
def test_nested(self):
a = UOp.new_buffer(Device.DEFAULT, 1, dtypes.char)
for _ in range(10_000): a = a+a
self.assertEqual(a.device, Device.DEFAULT)
class TestUOpRender(unittest.TestCase):
def test_render_vectorize_same(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))

View File

@@ -379,9 +379,9 @@ class TestVizProfiler(unittest.TestCase):
j = load_profile(prof)
tracks = list(j['layout'])
self.assertEqual(tracks[0], 'NV Graph')
self.assertEqual(tracks[1], 'NV')
self.assertEqual(tracks[2], 'NV:1')
self.assertEqual(tracks[0], 'NV')
self.assertEqual(tracks[1], 'NV:1')
self.assertEqual(tracks[2], 'NV Graph')
nv_events = j['layout']['NV']['events']
self.assertEqual(nv_events[0]['name'], 'E_25_4n2')

View File

@@ -557,7 +557,9 @@ class KFDIface:
for i in FileIOInterface(f'{ip_base}/{hw}').listdir()} for ip,hw in ip_hw }
self.drm_fd = FileIOInterface(f"/dev/dri/renderD{self.props['drm_render_minor']}", os.O_RDWR)
self.kfd_ver = ((ver_st:=kfd.AMDKFD_IOC_GET_VERSION(KFDIface.kfd)).major_version, ver_st.minor_version)
kfd.AMDKFD_IOC_ACQUIRE_VM(KFDIface.kfd, drm_fd=self.drm_fd.fd, gpu_id=self.gpu_id)
if self.kfd_ver >= (1,14): kfd.AMDKFD_IOC_RUNTIME_ENABLE(KFDIface.kfd, mode_mask=0)
# Set these for our device.
if KFDIface.event_page is None:

View File

@@ -450,7 +450,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# this is the ranges replaced
# NOTE: if buf src is a const, we don't replace it
replaces = flatten([(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST])
return UOp(Ops.SUBSTITUTE, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2]))))
return UOp(Ops.SUBSTITUTE, dtype=src.dtype, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2]))))
def pre_bufferize(b:UOp, x:UOp, copy:UOp):
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
@@ -692,6 +692,8 @@ def split_store(ctx:list[UOp], x:UOp):
if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
return x.as_buf().assign(kernel)
split_kernels = PatternMatcher([
@@ -731,7 +733,7 @@ def do_sub_recurse(s:UOp):
if x.op is Ops.SUBSTITUTE:
sub_k = UOp(Ops.SUBSTITUTE, src=(x.src[1],)+s.src[1:])
sub_v = UOp(Ops.SUBSTITUTE, src=(x.src[2],)+s.src[1:])
return UOp(Ops.SUBSTITUTE, src=(x.src[0], sub_k, sub_v))
return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v))
# here we actually do the SUBSTITUTE
if x in keys: return values[keys.index(x)]
# we filter any keys that aren't in the backward slice. this keeps the algorithm O(output graph size)
@@ -741,7 +743,7 @@ def do_sub_recurse(s:UOp):
if len(new_kv) == 0: return x
# then we add SUBSTITUTE to all parents
uop_keys, uop_values = UOp(Ops.NOOP, src=tuple(new_kv.keys())), UOp(Ops.NOOP, src=tuple(new_kv.values()))
return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, src=(y,uop_keys,uop_values)) for y in x.src]))
return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, dtype=y.dtype, src=(y,uop_keys,uop_values)) for y in x.src]))
pm_substitute_recurse = PatternMatcher([(UPat(Ops.SUBSTITUTE, src=(UPat(), UPat(Ops.NOOP), UPat(Ops.NOOP)), name="s"), do_sub_recurse)])
# *** fast rangeify ***

View File

@@ -24,15 +24,15 @@ from tinygrad.schedule.kernelize import get_kernelize_map
all_tensors: dict[weakref.ref[Tensor], None] = {}
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
fixed_tensors = [t for tref in all_tensors if (t:=tref()) is not None and
scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
(t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
# get all Tensors and apply the map
sink = UOp.sink(*[t.uop for t in fixed_tensors])
sink = UOp.sink(*[t.uop for t in scope_tensors])
new_sink = sink.substitute(applied_map, name=name)
# set the relevant uop to the realized UOps
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
if s is ns: continue
t.uop = ns

View File

@@ -488,7 +488,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp.unique(), UOp(Ops.DEVICE, arg=device)), size)
@property
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
@functools.cached_property
@recursive_property
def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.BUFFERIZE: return self.arg.device

View File

@@ -261,6 +261,9 @@ full_spec = PatternMatcher([
# SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False),
# allow any SUBSTITUTE
(UPat(Ops.SUBSTITUTE), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine
@@ -277,7 +280,7 @@ full_spec = PatternMatcher([
# rangeify: buffer view with index or load is okay
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True),
# bufferize (must be on ranges)
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op is Ops.RANGE for y in x.src[1:])),
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])),
# realize with one src is fine
(UPat(Ops.REALIZE, src=(UPat(),)), lambda: True),
# intermediate index

View File

@@ -199,7 +199,8 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
v.sort(key=lambda e:e[0])
layout[k] = timeline_layout(v, start_ts, scache)
layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)
ret = [b"".join([struct.pack("<B", len(k)), k.encode(), v]) for k,v in layout.items() if v is not None]
groups = sorted(layout.items(), key=lambda x: '' if len(ss:=x[0].split(" ")) == 1 else ss[1])
ret = [b"".join([struct.pack("<B", len(k)), k.encode(), v]) for k,v in groups if v is not None]
index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":int(e.ts-start_ts), **e.arg} for e in markers]}).encode()
return struct.pack("<IQII", unwrap(end_ts)-start_ts, max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)