From a7cb80bfab2d37e10e6fa71124e33d4b42f3558c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 8 Oct 2025 06:15:05 +0300 Subject: [PATCH 1/6] use recursive_property in UOp device (#12477) * simple failing test with RecursionError * switch to @recursive_property * merge 2 * diff --- test/test_uops.py | 5 +++++ tinygrad/uop/ops.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_uops.py b/test/test_uops.py index 6571facc04..14315b040d 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -544,6 +544,11 @@ class TestUopsObject(unittest.TestCase): with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)] assert len(ret) == 10000 + def test_nested(self): + a = UOp.new_buffer(Device.DEFAULT, 1, dtypes.char) + for _ in range(10_000): a = a+a + self.assertEqual(a.device, Device.DEFAULT) + class TestUOpRender(unittest.TestCase): def test_render_vectorize_same(self): u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0))) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 70d1aa3732..194176bba2 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -476,7 +476,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp.unique(), UOp(Ops.DEVICE, arg=device)), size) @property def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device)) - @functools.cached_property + @recursive_property def _device(self) -> str|tuple[str, ...]|None: if self.op is Ops.DEVICE: return self.arg if self.op is Ops.BUFFERIZE: return self.arg.device From d06226b575938c32b9f15cf746acb0cd21f823f4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 8 Oct 2025 11:18:17 +0800 Subject: [PATCH 2/6] fix SPEC and all_tensors iterator (#12496) --- tinygrad/schedule/rangeify.py | 6 +++--- tinygrad/tensor.py | 6 +++--- tinygrad/uop/spec.py | 5 ++++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index a967a0eec1..545b610180 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -450,7 +450,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp): # this is the ranges replaced # NOTE: if buf src is a const, we don't replace it replaces = flatten([(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]) - return UOp(Ops.SUBSTITUTE, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2])))) + return UOp(Ops.SUBSTITUTE, dtype=src.dtype, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2])))) def pre_bufferize(b:UOp, x:UOp, copy:UOp): nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:]) @@ -731,7 +731,7 @@ def do_sub_recurse(s:UOp): if x.op is Ops.SUBSTITUTE: sub_k = UOp(Ops.SUBSTITUTE, src=(x.src[1],)+s.src[1:]) sub_v = UOp(Ops.SUBSTITUTE, src=(x.src[2],)+s.src[1:]) - return UOp(Ops.SUBSTITUTE, src=(x.src[0], sub_k, sub_v)) + return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v)) # here we actually do the SUBSTITUTE if x in keys: return values[keys.index(x)] # we filter any keys that aren't in the backward slice. 
this keeps the algorithm O(output graph size) @@ -741,7 +741,7 @@ def do_sub_recurse(s:UOp): if len(new_kv) == 0: return x # then we add SUBSTITUTE to all parents uop_keys, uop_values = UOp(Ops.NOOP, src=tuple(new_kv.keys())), UOp(Ops.NOOP, src=tuple(new_kv.values())) - return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, src=(y,uop_keys,uop_values)) for y in x.src])) + return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, dtype=y.dtype, src=(y,uop_keys,uop_values)) for y in x.src])) pm_substitute_recurse = PatternMatcher([(UPat(Ops.SUBSTITUTE, src=(UPat(), UPat(Ops.NOOP), UPat(Ops.NOOP)), name="s"), do_sub_recurse)]) @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2d2ada0506..5be941b0e6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -24,15 +24,15 @@ from tinygrad.schedule.kernelize import get_kernelize_map all_tensors: dict[weakref.ref[Tensor], None] = {} def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None: - fixed_tensors = [t for tref in all_tensors if (t:=tref()) is not None and + scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and (t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))] # get all Tensors and apply the map - sink = UOp.sink(*[t.uop for t in fixed_tensors]) + sink = UOp.sink(*[t.uop for t in scope_tensors]) new_sink = sink.substitute(applied_map, name=name) # set the relevant uop to the realized UOps - for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src): + for t,s,ns in zip(scope_tensors, sink.src, new_sink.src): if s is ns: continue t.uop = ns diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 3b13e94d23..3f69dbe4a2 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -261,6 +261,9 @@ full_spec = PatternMatcher([ # SENTINEL should never be in the graph (UPat(Ops.SENTINEL), lambda: False), + # allow any SUBSTITUTE + (UPat(Ops.SUBSTITUTE), lambda: True), + # Invalid must have type Index (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index), # where on index in rhs position is fine @@ -277,7 +280,7 @@ full_spec = PatternMatcher([ # rangeify: buffer view with index or load is okay (UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True), # bufferize (must be on ranges) - (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op is Ops.RANGE for y in x.src[1:])), + (UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])), # realize with one src is fine (UPat(Ops.REALIZE, src=(UPat(),)), lambda: True), # intermediate index From 2e19354c1c2de496c87982a80626ac24aee07844 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 8 Oct 2025 07:10:23 +0300 Subject: [PATCH 3/6] viz: reorder timeline graphs (#12498) * viz: reorder timeline graphs * update test_viz with the new order --- test/unit/test_viz.py | 6 +++--- tinygrad/viz/serve.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 7212b56aff..fbfc37e76f 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -379,9 +379,9 @@ class TestVizProfiler(unittest.TestCase): j = load_profile(prof) tracks = list(j['layout']) - self.assertEqual(tracks[0], 'NV Graph') - self.assertEqual(tracks[1], 'NV') - 
self.assertEqual(tracks[2], 'NV:1') + self.assertEqual(tracks[0], 'NV') + self.assertEqual(tracks[1], 'NV:1') + self.assertEqual(tracks[2], 'NV Graph') nv_events = j['layout']['NV']['events'] self.assertEqual(nv_events[0]['name'], 'E_25_4n2') diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 22941527c4..e5945fbfcb 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -198,7 +198,8 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None: v.sort(key=lambda e:e[0]) layout[k] = timeline_layout(v, start_ts, scache) layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache) - ret = [b"".join([struct.pack(" Date: Wed, 8 Oct 2025 07:19:36 +0300 Subject: [PATCH 4/6] early assert for device mistmatched asts in rangeify (#12499) * early assert for device mistmatched asts in rangeify * alt also passes --- test/test_schedule.py | 2 -- tinygrad/schedule/rangeify.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 3e7055c6a3..01d390bb46 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -113,7 +113,6 @@ class TestSchedule(unittest.TestCase): self.assertListEqual(a.tolist(), [[15]]) @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") - @expect_rangeify_fails def test_error_on_device_mismatch(self): a = Tensor.empty(10) b = Tensor.empty(10, device="CPU") @@ -121,7 +120,6 @@ class TestSchedule(unittest.TestCase): with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1) @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") - @expect_rangeify_fails def test_error_on_device_mismatch_alt(self): a = Tensor.empty(10) b = Tensor.empty((1,), device="CPU").expand(10).contiguous() diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 545b610180..53b8c03cfb 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -692,6 +692,8 @@ def split_store(ctx:list[UOp], x:UOp): if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1] kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]) kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) + if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]): + raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}") return x.as_buf().assign(kernel) split_kernels = PatternMatcher([ From 60b6dca5badd2a3577ded1f54c5d1f02a473a666 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 8 Oct 2025 07:42:31 +0300 Subject: [PATCH 5/6] update some tests instead of expect_rangeify_fails (#12500) * update test_clone_doesnt_dedup to use base * new_flat_buffer passes * fix test_reorder_expand * remove the view stuff * remove that test, we don't want this view const behavior * test_setitem_becomes_subbuffer is good --- test/test_schedule.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 01d390bb46..83e312a691 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -397,7 +397,6 @@ class TestSchedule(unittest.TestCase): # a and b share the same underlying device memory self.assertIs(a.uop.realized, b.uop.realized) - @expect_rangeify_fails def test_clone_doesnt_dedup(self): src = 
Tensor.ones(4).contiguous().realize() a = src.clone() @@ -405,7 +404,7 @@ class TestSchedule(unittest.TestCase): sched = check_schedule([a, b], 2, filter_sink=False) run_schedule(sched) # a and b are assigned to the same device Buffer - self.assertIsNot(a.uop.realized, b.uop.realized) + self.assertIsNot(a.uop.base.realized, b.uop.base.realized) # EMPTY is assigned to a unique device Buffer @@ -2466,23 +2465,24 @@ class TestUOpBecome(unittest.TestCase): self.assertEqual(add.uop.shape, (8, 2)) assert add.uop is not add.uop.base - @expect_rangeify_fails def test_new_flat_buffer(self): a = Tensor.empty(4,) b = Tensor.empty(4,) add = a+b check_schedule(add, 1) # BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER - assert UPat(Ops.BUFFER).match(add.uop, {}) + assert UPat(Ops.BUFFER).match(add.uop.base, {}) # sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer - # NOTE: this expand is not reordered because there's before it to fuse - @expect_rangeify_fails def test_reorder_expand(self): a = Tensor.empty(4, 1) b = a.expand(4, 4).reciprocal() check_schedule(b, 1) + if RANGEIFY: + self.assertEqual(b.uop.base.buffer.size, 4) + self.assertEqual(b.uop.shape, (4, 4)) + return self.assertEqual(b.uop.base.buffer.size, 16) self.assertEqual(b.uop.st, ShapeTracker.from_shape((4, 4))) @@ -2499,7 +2499,6 @@ class TestUOpBecome(unittest.TestCase): b = a*1 assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul check_schedule(b, 0) - assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.uop, {}) # scheduling merges all MovementOps into a single VIEW self.assertIs(a.uop.base.buffer, b.uop.base.buffer) def test_become_buf_with_mops(self): @@ -2521,17 +2520,6 @@ class TestUOpBecome(unittest.TestCase): check_schedule(b, 0) assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER) - @expect_rangeify_fails - def test_become_const_in_view(self): - # if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged. - add = Tensor.empty(2, 2)+Tensor.empty(2, 2) - b = add.shrink(((0, 1), (0, 0))) - check_schedule(b, 0) - assert UPat(Ops.CONST, arg=0).match(b.uop, {}) - self.assertEqual(b.shape, (1, 0)) - # the base is untouched. 
- assert UPat(Ops.ADD).match(add.uop, {}) - def test_become_const_from_const(self): const_add = Tensor(1)+Tensor(2) assert UPat(Ops.ADD).match(const_add.uop, {}) @@ -2583,14 +2571,17 @@ class TestUOpBecome(unittest.TestCase): assert b.uop is c.uop assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.uop, {}) - @expect_rangeify_fails def test_setitem_becomes_subbuffer(self): a = Tensor.full((4,), 2.).contiguous().realize() b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0)) b.realize() assert a.uop.is_realized assert a.uop.buffer._base is None - # b is a subbuffer of a + # b is a subbuffer of a (buffer_view in non rangeify, rangeify just makes a shrink) + if RANGEIFY: + assert b.uop.op_in_backward_slice_with_self(Ops.SHRINK) + assert b.uop.base is a.uop.base + return assert b.uop.op is Ops.BUFFER_VIEW assert b.uop.src[0] is a.uop From 4a756a37d8f6b9c9612c4edbfcc167076e6a5e86 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 8 Oct 2025 14:30:39 +0800 Subject: [PATCH 6/6] amd: support rocm7 (#12502) * amd: support rocm7 * mock --- extra/hip_gpu_driver/hip_ioctl.py | 2 +- extra/hip_gpu_driver/kfd_ioctl.h | 1157 ++++++++++++++++++++++++++++- test/mockgpu/amd/amddriver.py | 6 +- tinygrad/runtime/ops_amd.py | 2 + 4 files changed, 1157 insertions(+), 10 deletions(-) diff --git a/extra/hip_gpu_driver/hip_ioctl.py b/extra/hip_gpu_driver/hip_ioctl.py index 20c4f3248a..fcb3a9f2da 100644 --- a/extra/hip_gpu_driver/hip_ioctl.py +++ b/extra/hip_gpu_driver/hip_ioctl.py @@ -50,7 +50,7 @@ def ioctls_from_header(): hdr = (pathlib.Path(__file__).parent / "kfd_ioctl.h").read_text().replace("\\\n", "") pattern = r'#define\s+(AMDKFD_IOC_[A-Z0-9_]+)\s+AMDKFD_IOW?R?\((0x[0-9a-fA-F]+),\s+struct\s([A-Za-z0-9_]+)\)' matches = re.findall(pattern, hdr, re.MULTILINE) - return {int(nr, 0x10):(name, getattr(kfd_ioctl, "struct_"+sname)) for name, nr, sname in matches} + return {int(nr, 0x10):(name, getattr(kfd_ioctl, "struct_"+sname, None)) for name, nr, sname in matches} nrs = ioctls_from_header() @ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulong, ctypes.c_void_p) diff --git a/extra/hip_gpu_driver/kfd_ioctl.h b/extra/hip_gpu_driver/kfd_ioctl.h index af96af174d..4b2c16ad22 100644 --- a/extra/hip_gpu_driver/kfd_ioctl.h +++ b/extra/hip_gpu_driver/kfd_ioctl.h @@ -32,9 +32,20 @@ * - 1.4 - Indicate new SRAM EDC bit in device properties * - 1.5 - Add SVM API * - 1.6 - Query clear flags in SVM get_attr API + * - 1.7 - Checkpoint Restore (CRIU) API + * - 1.8 - CRIU - Support for SDMA transfers with GTT BOs + * - 1.9 - Add available memory ioctl + * - 1.10 - Add SMI profiler event log + * - 1.11 - Add unified memory for ctx save/restore area + * - 1.12 - Add DMA buf export ioctl + * - 1.13 - Add debugger API + * - 1.14 - Update kfd_event_data + * - 1.15 - Enable managing mappings in compute VMs with GEM_VA ioctl + * - 1.16 - Add contiguous VRAM allocation flag + * - 1.17 - Add SDMA queue creation with target SDMA engine ID */ #define KFD_IOCTL_MAJOR_VERSION 1 -#define KFD_IOCTL_MINOR_VERSION 6 +#define KFD_IOCTL_MINOR_VERSION 17 struct kfd_ioctl_get_version_args { __u32 major_version; /* from KFD */ @@ -46,6 +57,7 @@ struct kfd_ioctl_get_version_args { #define KFD_IOC_QUEUE_TYPE_SDMA 0x1 #define KFD_IOC_QUEUE_TYPE_COMPUTE_AQL 0x2 #define KFD_IOC_QUEUE_TYPE_SDMA_XGMI 0x3 +#define KFD_IOC_QUEUE_TYPE_SDMA_BY_ENG_ID 0x4 #define KFD_MAX_QUEUE_PERCENTAGE 100 #define KFD_MAX_QUEUE_PRIORITY 15 @@ -68,6 +80,8 @@ struct kfd_ioctl_create_queue_args { __u64 ctx_save_restore_address; 
/* to KFD */ __u32 ctx_save_restore_size; /* to KFD */ __u32 ctl_stack_size; /* to KFD */ + __u32 sdma_engine_id; /* to KFD */ + __u32 pad; }; struct kfd_ioctl_destroy_queue_args { @@ -98,6 +112,38 @@ struct kfd_ioctl_get_queue_wave_state_args { __u32 pad; }; +struct kfd_ioctl_get_available_memory_args { + __u64 available; /* from KFD */ + __u32 gpu_id; /* to KFD */ + __u32 pad; +}; + +struct kfd_dbg_device_info_entry { + __u64 exception_status; + __u64 lds_base; + __u64 lds_limit; + __u64 scratch_base; + __u64 scratch_limit; + __u64 gpuvm_base; + __u64 gpuvm_limit; + __u32 gpu_id; + __u32 location_id; + __u32 vendor_id; + __u32 device_id; + __u32 revision_id; + __u32 subsystem_vendor_id; + __u32 subsystem_device_id; + __u32 fw_version; + __u32 gfx_target_version; + __u32 simd_count; + __u32 max_waves_per_simd; + __u32 array_count; + __u32 simd_arrays_per_engine; + __u32 num_xcc; + __u32 capability; + __u32 debug_prop; +}; + /* For kfd_ioctl_set_memory_policy_args.default_policy and alternate_policy */ #define KFD_IOC_CACHE_POLICY_COHERENT 0 #define KFD_IOC_CACHE_POLICY_NONCOHERENT 1 @@ -194,6 +240,19 @@ struct kfd_ioctl_dbg_wave_control_args { __u32 buf_size_in_bytes; /*including gpu_id and buf_size */ }; +#define KFD_INVALID_FD 0xffffffff + +struct kfd_ioctl_dbg_trap_args_deprecated { + __u64 exception_mask; /* to KFD */ + __u64 ptr; /* to KFD -- used for pointer arguments: queue arrays */ + __u32 pid; /* to KFD */ + __u32 op; /* to KFD */ + __u32 data1; /* to KFD */ + __u32 data2; /* to KFD */ + __u32 data3; /* to KFD */ + __u32 data4; /* to KFD */ +}; + /* Matching HSA_EVENTTYPE */ #define KFD_IOC_EVENT_SIGNAL 0 #define KFD_IOC_EVENT_NODECHANGE 1 @@ -279,12 +338,20 @@ struct kfd_hsa_hw_exception_data { __u32 gpu_id; }; +/* hsa signal event data */ +struct kfd_hsa_signal_event_data { + __u64 last_event_age; /* to and from KFD */ +}; + /* Event data */ struct kfd_event_data { union { + /* From KFD */ struct kfd_hsa_memory_exception_data memory_exception_data; struct kfd_hsa_hw_exception_data hw_exception_data; - }; /* From KFD */ + /* To and From KFD */ + struct kfd_hsa_signal_event_data signal_event_data; + }; __u64 kfd_event_data_ext; /* pointer to an extension structure for future exception types */ __u32 event_id; /* to KFD */ @@ -355,6 +422,8 @@ struct kfd_ioctl_acquire_vm_args { #define KFD_IOC_ALLOC_MEM_FLAGS_AQL_QUEUE_MEM (1 << 27) #define KFD_IOC_ALLOC_MEM_FLAGS_COHERENT (1 << 26) #define KFD_IOC_ALLOC_MEM_FLAGS_UNCACHED (1 << 25) +#define KFD_IOC_ALLOC_MEM_FLAGS_EXT_COHERENT (1 << 24) +#define KFD_IOC_ALLOC_MEM_FLAGS_CONTIGUOUS (1 << 23) /* Allocate memory for later SVM (shared virtual memory) mapping. * @@ -450,6 +519,12 @@ struct kfd_ioctl_import_dmabuf_args { __u32 dmabuf_fd; /* to KFD */ }; +struct kfd_ioctl_export_dmabuf_args { + __u64 handle; /* to KFD */ + __u32 flags; /* to KFD */ + __u32 dmabuf_fd; /* from KFD */ +}; + /* * KFD SMI(System Management Interface) events */ @@ -459,15 +534,277 @@ enum kfd_smi_event { KFD_SMI_EVENT_THERMAL_THROTTLE = 2, KFD_SMI_EVENT_GPU_PRE_RESET = 3, KFD_SMI_EVENT_GPU_POST_RESET = 4, + KFD_SMI_EVENT_MIGRATE_START = 5, + KFD_SMI_EVENT_MIGRATE_END = 6, + KFD_SMI_EVENT_PAGE_FAULT_START = 7, + KFD_SMI_EVENT_PAGE_FAULT_END = 8, + KFD_SMI_EVENT_QUEUE_EVICTION = 9, + KFD_SMI_EVENT_QUEUE_RESTORE = 10, + KFD_SMI_EVENT_UNMAP_FROM_GPU = 11, + + /* + * max event number, as a flag bit to get events from all processes, + * this requires super user permission, otherwise will not be able to + * receive event from any process. 
Without this flag to receive events + * from same process. + */ + KFD_SMI_EVENT_ALL_PROCESS = 64 +}; + +/* The reason of the page migration event */ +enum KFD_MIGRATE_TRIGGERS { + KFD_MIGRATE_TRIGGER_PREFETCH, /* Prefetch to GPU VRAM or system memory */ + KFD_MIGRATE_TRIGGER_PAGEFAULT_GPU, /* GPU page fault recover */ + KFD_MIGRATE_TRIGGER_PAGEFAULT_CPU, /* CPU page fault recover */ + KFD_MIGRATE_TRIGGER_TTM_EVICTION /* TTM eviction */ +}; + +/* The reason of user queue evition event */ +enum KFD_QUEUE_EVICTION_TRIGGERS { + KFD_QUEUE_EVICTION_TRIGGER_SVM, /* SVM buffer migration */ + KFD_QUEUE_EVICTION_TRIGGER_USERPTR, /* userptr movement */ + KFD_QUEUE_EVICTION_TRIGGER_TTM, /* TTM move buffer */ + KFD_QUEUE_EVICTION_TRIGGER_SUSPEND, /* GPU suspend */ + KFD_QUEUE_EVICTION_CRIU_CHECKPOINT, /* CRIU checkpoint */ + KFD_QUEUE_EVICTION_CRIU_RESTORE /* CRIU restore */ +}; + +/* The reason of unmap buffer from GPU event */ +enum KFD_SVM_UNMAP_TRIGGERS { + KFD_SVM_UNMAP_TRIGGER_MMU_NOTIFY, /* MMU notifier CPU buffer movement */ + KFD_SVM_UNMAP_TRIGGER_MMU_NOTIFY_MIGRATE,/* MMU notifier page migration */ + KFD_SVM_UNMAP_TRIGGER_UNMAP_FROM_CPU /* Unmap to free the buffer */ }; #define KFD_SMI_EVENT_MASK_FROM_INDEX(i) (1ULL << ((i) - 1)) +#define KFD_SMI_EVENT_MSG_SIZE 96 struct kfd_ioctl_smi_events_args { __u32 gpuid; /* to KFD */ __u32 anon_fd; /* from KFD */ }; +/** + * kfd_ioctl_spm_op - SPM ioctl operations + * + * @KFD_IOCTL_SPM_OP_ACQUIRE: acquire exclusive access to SPM + * @KFD_IOCTL_SPM_OP_RELEASE: release exclusive access to SPM + * @KFD_IOCTL_SPM_OP_SET_DEST_BUF: set or unset destination buffer for SPM streaming + */ +enum kfd_ioctl_spm_op { + KFD_IOCTL_SPM_OP_ACQUIRE, + KFD_IOCTL_SPM_OP_RELEASE, + KFD_IOCTL_SPM_OP_SET_DEST_BUF +}; + +/** + * kfd_ioctl_spm_args - Arguments for SPM ioctl + * + * @op[in]: specifies the operation to perform + * @gpu_id[in]: GPU ID of the GPU to profile + * @dst_buf[in]: used for the address of the destination buffer + * in @KFD_IOCTL_SPM_SET_DEST_BUFFER + * @buf_size[in]: size of the destination buffer + * @timeout[in/out]: [in]: timeout in milliseconds, [out]: amount of time left + * `in the timeout window + * @bytes_copied[out]: amount of data that was copied to the previous dest_buf + * @has_data_loss: boolean indicating whether data was lost + * (e.g. due to a ring-buffer overflow) + * + * This ioctl performs different functions depending on the @op parameter. + * + * KFD_IOCTL_SPM_OP_ACQUIRE + * ------------------------ + * + * Acquires exclusive access of SPM on the specified @gpu_id for the calling process. + * This must be called before using KFD_IOCTL_SPM_OP_SET_DEST_BUF. + * + * KFD_IOCTL_SPM_OP_RELEASE + * ------------------------ + * + * Releases exclusive access of SPM on the specified @gpu_id for the calling process, + * which allows another process to acquire it in the future. + * + * KFD_IOCTL_SPM_OP_SET_DEST_BUF + * ----------------------------- + * + * If @dst_buf is NULL, the destination buffer address is unset and copying of counters + * is stopped. + * + * If @dst_buf is not NULL, it specifies the pointer to a new destination buffer. + * @buf_size specifies the size of the buffer. + * + * If @timeout is non-0, the call will wait for up to @timeout ms for the previous + * buffer to be filled. If previous buffer to be filled before timeout, the @timeout + * will be updated value with the time remaining. If the timeout is exceeded, the function + * copies any partial data available into the previous user buffer and returns success. 
+ * The amount of valid data in the previous user buffer is indicated by @bytes_copied. + * + * If @timeout is 0, the function immediately replaces the previous destination buffer + * without waiting for the previous buffer to be filled. That means the previous buffer + * may only be partially filled, and @bytes_copied will indicate how much data has been + * copied to it. + * + * If data was lost, e.g. due to a ring buffer overflow, @has_data_loss will be non-0. + * + * Returns negative error code on failure, 0 on success. + */ +struct kfd_ioctl_spm_args { + __u64 dest_buf; + __u32 buf_size; + __u32 op; + __u32 timeout; + __u32 gpu_id; + __u32 bytes_copied; + __u32 has_data_loss; +}; + +/* + * SVM event tracing via SMI system management interface + * + * Open event file descriptor + * use ioctl AMDKFD_IOC_SMI_EVENTS, pass in gpuid and return a anonymous file + * descriptor to receive SMI events. + * If calling with sudo permission, then file descriptor can be used to receive + * SVM events from all processes, otherwise, to only receive SVM events of same + * process. + * + * To enable the SVM event + * Write event file descriptor with KFD_SMI_EVENT_MASK_FROM_INDEX(event) bitmap + * mask to start record the event to the kfifo, use bitmap mask combination + * for multiple events. New event mask will overwrite the previous event mask. + * KFD_SMI_EVENT_MASK_FROM_INDEX(KFD_SMI_EVENT_ALL_PROCESS) bit requires sudo + * permisson to receive SVM events from all process. + * + * To receive the event + * Application can poll file descriptor to wait for the events, then read event + * from the file into a buffer. Each event is one line string message, starting + * with the event id, then the event specific information. + * + * To decode event information + * The following event format string macro can be used with sscanf to decode + * the specific event information. + * event triggers: the reason to generate the event, defined as enum for unmap, + * eviction and migrate events. + * node, from, to, prefetch_loc, preferred_loc: GPU ID, or 0 for system memory. 
+ * addr: user mode address, in pages + * size: in pages + * pid: the process ID to generate the event + * ns: timestamp in nanosecond-resolution, starts at system boot time but + * stops during suspend + * migrate_update: GPU page fault is recovered by 'M' for migrate, 'U' for update + * rw: 'W' for write page fault, 'R' for read page fault + * rescheduled: 'R' if the queue restore failed and rescheduled to try again + */ +#define KFD_EVENT_FMT_UPDATE_GPU_RESET(reset_seq_num, reset_cause)\ + "%x %s\n", (reset_seq_num), (reset_cause) + +#define KFD_EVENT_FMT_THERMAL_THROTTLING(bitmask, counter)\ + "%llx:%llx\n", (bitmask), (counter) + +#define KFD_EVENT_FMT_VMFAULT(pid, task_name)\ + "%x:%s\n", (pid), (task_name) + +#define KFD_EVENT_FMT_PAGEFAULT_START(ns, pid, addr, node, rw)\ + "%lld -%d @%lx(%x) %c\n", (ns), (pid), (addr), (node), (rw) + +#define KFD_EVENT_FMT_PAGEFAULT_END(ns, pid, addr, node, migrate_update)\ + "%lld -%d @%lx(%x) %c\n", (ns), (pid), (addr), (node), (migrate_update) + +#define KFD_EVENT_FMT_MIGRATE_START(ns, pid, start, size, from, to, prefetch_loc,\ + preferred_loc, migrate_trigger)\ + "%lld -%d @%lx(%lx) %x->%x %x:%x %d\n", (ns), (pid), (start), (size),\ + (from), (to), (prefetch_loc), (preferred_loc), (migrate_trigger) + +#define KFD_EVENT_FMT_MIGRATE_END(ns, pid, start, size, from, to, migrate_trigger)\ + "%lld -%d @%lx(%lx) %x->%x %d\n", (ns), (pid), (start), (size),\ + (from), (to), (migrate_trigger) + +#define KFD_EVENT_FMT_QUEUE_EVICTION(ns, pid, node, evict_trigger)\ + "%lld -%d %x %d\n", (ns), (pid), (node), (evict_trigger) + +#define KFD_EVENT_FMT_QUEUE_RESTORE(ns, pid, node, rescheduled)\ + "%lld -%d %x %c\n", (ns), (pid), (node), (rescheduled) + +#define KFD_EVENT_FMT_UNMAP_FROM_GPU(ns, pid, addr, size, node, unmap_trigger)\ + "%lld -%d @%lx(%lx) %x %d\n", (ns), (pid), (addr), (size),\ + (node), (unmap_trigger) + +/************************************************************************************************** + * CRIU IOCTLs (Checkpoint Restore In Userspace) + * + * When checkpointing a process, the userspace application will perform: + * 1. PROCESS_INFO op to determine current process information. This pauses execution and evicts + * all the queues. + * 2. CHECKPOINT op to checkpoint process contents (BOs, queues, events, svm-ranges) + * 3. UNPAUSE op to un-evict all the queues + * + * When restoring a process, the CRIU userspace application will perform: + * + * 1. RESTORE op to restore process contents + * 2. RESUME op to start the process + * + * Note: Queues are forced into an evicted state after a successful PROCESS_INFO. User + * application needs to perform an UNPAUSE operation after calling PROCESS_INFO. + */ + +enum kfd_criu_op { + KFD_CRIU_OP_PROCESS_INFO, + KFD_CRIU_OP_CHECKPOINT, + KFD_CRIU_OP_UNPAUSE, + KFD_CRIU_OP_RESTORE, + KFD_CRIU_OP_RESUME, +}; + +/** + * kfd_ioctl_criu_args - Arguments perform CRIU operation + * @devices: [in/out] User pointer to memory location for devices information. + * This is an array of type kfd_criu_device_bucket. + * @bos: [in/out] User pointer to memory location for BOs information + * This is an array of type kfd_criu_bo_bucket. + * @priv_data: [in/out] User pointer to memory location for private data + * @priv_data_size: [in/out] Size of priv_data in bytes + * @num_devices: [in/out] Number of GPUs used by process. Size of @devices array. + * @num_bos [in/out] Number of BOs used by process. Size of @bos array. + * @num_objects: [in/out] Number of objects used by process. 
Objects are opaque to + * user application. + * @pid: [in/out] PID of the process being checkpointed + * @op [in] Type of operation (kfd_criu_op) + * + * Return: 0 on success, -errno on failure + */ +struct kfd_ioctl_criu_args { + __u64 devices; /* Used during ops: CHECKPOINT, RESTORE */ + __u64 bos; /* Used during ops: CHECKPOINT, RESTORE */ + __u64 priv_data; /* Used during ops: CHECKPOINT, RESTORE */ + __u64 priv_data_size; /* Used during ops: PROCESS_INFO, RESTORE */ + __u32 num_devices; /* Used during ops: PROCESS_INFO, RESTORE */ + __u32 num_bos; /* Used during ops: PROCESS_INFO, RESTORE */ + __u32 num_objects; /* Used during ops: PROCESS_INFO, RESTORE */ + __u32 pid; /* Used during ops: PROCESS_INFO, RESUME */ + __u32 op; +}; + +struct kfd_criu_device_bucket { + __u32 user_gpu_id; + __u32 actual_gpu_id; + __u32 drm_fd; + __u32 pad; +}; + +struct kfd_criu_bo_bucket { + __u64 addr; + __u64 size; + __u64 offset; + __u64 restored_offset; /* During restore, updated offset for BO */ + __u32 gpu_id; /* This is the user_gpu_id */ + __u32 alloc_flags; + __u32 dmabuf_fd; + __u32 pad; +}; + +/* CRIU IOCTLs - END */ +/**************************************************************************************************/ /* Register offset inside the remapped mmio page */ enum kfd_mmio_remap { @@ -475,6 +812,39 @@ enum kfd_mmio_remap { KFD_MMIO_REMAP_HDP_REG_FLUSH_CNTL = 4, }; +struct kfd_ioctl_ipc_export_handle_args { + __u64 handle; /* to KFD */ + __u32 share_handle[4]; /* from KFD */ + __u32 gpu_id; /* to KFD */ + __u32 flags; /* to KFD */ +}; + +struct kfd_ioctl_ipc_import_handle_args { + __u64 handle; /* from KFD */ + __u64 va_addr; /* to KFD */ + __u64 mmap_offset; /* from KFD */ + __u32 share_handle[4]; /* to KFD */ + __u32 gpu_id; /* to KFD */ + __u32 flags; /* from KFD */ +}; + +struct kfd_ioctl_cross_memory_copy_deprecated_args { + /* to KFD: Process ID of the remote process */ + __u32 pid; + /* to KFD: See above definition */ + __u32 flags; + /* to KFD: Source GPU VM range */ + __u64 src_mem_range_array; + /* to KFD: Size of above array */ + __u64 src_mem_array_size; + /* to KFD: Destination GPU VM range */ + __u64 dst_mem_range_array; + /* to KFD: Size of above array */ + __u64 dst_mem_array_size; + /* from KFD: Total amount of bytes copied */ + __u64 bytes_copied; +}; + /* Guarantee host access to memory */ #define KFD_IOCTL_SVM_FLAG_HOST_ACCESS 0x00000001 /* Fine grained coherency between all devices with access */ @@ -487,6 +857,10 @@ enum kfd_mmio_remap { #define KFD_IOCTL_SVM_FLAG_GPU_EXEC 0x00000010 /* GPUs mostly read, may allow similar optimizations as RO, but writes fault */ #define KFD_IOCTL_SVM_FLAG_GPU_READ_MOSTLY 0x00000020 +/* Keep GPU memory mapping always valid as if XNACK is disable */ +#define KFD_IOCTL_SVM_FLAG_GPU_ALWAYS_MAPPED 0x00000040 +/* Fine grained coherency between all devices using device-scope atomics */ +#define KFD_IOCTL_SVM_FLAG_EXT_COHERENT 0x00000080 /** * kfd_ioctl_svm_op - SVM ioctl operations @@ -596,7 +970,7 @@ struct kfd_ioctl_svm_args { __u32 op; __u32 nattr; /* Variable length array of attributes */ - struct kfd_ioctl_svm_attribute attrs[0]; + struct kfd_ioctl_svm_attribute attrs[]; }; /** @@ -637,6 +1011,733 @@ struct kfd_ioctl_set_xnack_mode_args { __s32 xnack_enabled; }; +/* Wave launch override modes */ +enum kfd_dbg_trap_override_mode { + KFD_DBG_TRAP_OVERRIDE_OR = 0, + KFD_DBG_TRAP_OVERRIDE_REPLACE = 1 +}; + +/* Wave launch overrides */ +enum kfd_dbg_trap_mask { + KFD_DBG_TRAP_MASK_FP_INVALID = 1, + KFD_DBG_TRAP_MASK_FP_INPUT_DENORMAL = 2, 
+ KFD_DBG_TRAP_MASK_FP_DIVIDE_BY_ZERO = 4, + KFD_DBG_TRAP_MASK_FP_OVERFLOW = 8, + KFD_DBG_TRAP_MASK_FP_UNDERFLOW = 16, + KFD_DBG_TRAP_MASK_FP_INEXACT = 32, + KFD_DBG_TRAP_MASK_INT_DIVIDE_BY_ZERO = 64, + KFD_DBG_TRAP_MASK_DBG_ADDRESS_WATCH = 128, + KFD_DBG_TRAP_MASK_DBG_MEMORY_VIOLATION = 256, + KFD_DBG_TRAP_MASK_TRAP_ON_WAVE_START = (1 << 30), + KFD_DBG_TRAP_MASK_TRAP_ON_WAVE_END = (1 << 31) +}; + +/* Wave launch modes */ +enum kfd_dbg_trap_wave_launch_mode { + KFD_DBG_TRAP_WAVE_LAUNCH_MODE_NORMAL = 0, + KFD_DBG_TRAP_WAVE_LAUNCH_MODE_HALT = 1, + KFD_DBG_TRAP_WAVE_LAUNCH_MODE_DEBUG = 3 +}; + +/* Address watch modes */ +enum kfd_dbg_trap_address_watch_mode { + KFD_DBG_TRAP_ADDRESS_WATCH_MODE_READ = 0, + KFD_DBG_TRAP_ADDRESS_WATCH_MODE_NONREAD = 1, + KFD_DBG_TRAP_ADDRESS_WATCH_MODE_ATOMIC = 2, + KFD_DBG_TRAP_ADDRESS_WATCH_MODE_ALL = 3 +}; + +/* Additional wave settings */ +enum kfd_dbg_trap_flags { + KFD_DBG_TRAP_FLAG_SINGLE_MEM_OP = 1, + KFD_DBG_TRAP_FLAG_SINGLE_ALU_OP = 2, +}; + +/* Trap exceptions */ +enum kfd_dbg_trap_exception_code { + EC_NONE = 0, + /* per queue */ + EC_QUEUE_WAVE_ABORT = 1, + EC_QUEUE_WAVE_TRAP = 2, + EC_QUEUE_WAVE_MATH_ERROR = 3, + EC_QUEUE_WAVE_ILLEGAL_INSTRUCTION = 4, + EC_QUEUE_WAVE_MEMORY_VIOLATION = 5, + EC_QUEUE_WAVE_APERTURE_VIOLATION = 6, + EC_QUEUE_PACKET_DISPATCH_DIM_INVALID = 16, + EC_QUEUE_PACKET_DISPATCH_GROUP_SEGMENT_SIZE_INVALID = 17, + EC_QUEUE_PACKET_DISPATCH_CODE_INVALID = 18, + EC_QUEUE_PACKET_RESERVED = 19, + EC_QUEUE_PACKET_UNSUPPORTED = 20, + EC_QUEUE_PACKET_DISPATCH_WORK_GROUP_SIZE_INVALID = 21, + EC_QUEUE_PACKET_DISPATCH_REGISTER_INVALID = 22, + EC_QUEUE_PACKET_VENDOR_UNSUPPORTED = 23, + EC_QUEUE_PREEMPTION_ERROR = 30, + EC_QUEUE_NEW = 31, + /* per device */ + EC_DEVICE_QUEUE_DELETE = 32, + EC_DEVICE_MEMORY_VIOLATION = 33, + EC_DEVICE_RAS_ERROR = 34, + EC_DEVICE_FATAL_HALT = 35, + EC_DEVICE_NEW = 36, + /* per process */ + EC_PROCESS_RUNTIME = 48, + EC_PROCESS_DEVICE_REMOVE = 49, + EC_MAX +}; + +/* Mask generated by ecode in kfd_dbg_trap_exception_code */ +#define KFD_EC_MASK(ecode) (1ULL << (ecode - 1)) + +/* Masks for exception code type checks below */ +#define KFD_EC_MASK_QUEUE (KFD_EC_MASK(EC_QUEUE_WAVE_ABORT) | \ + KFD_EC_MASK(EC_QUEUE_WAVE_TRAP) | \ + KFD_EC_MASK(EC_QUEUE_WAVE_MATH_ERROR) | \ + KFD_EC_MASK(EC_QUEUE_WAVE_ILLEGAL_INSTRUCTION) | \ + KFD_EC_MASK(EC_QUEUE_WAVE_MEMORY_VIOLATION) | \ + KFD_EC_MASK(EC_QUEUE_WAVE_APERTURE_VIOLATION) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_DIM_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_GROUP_SEGMENT_SIZE_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_CODE_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_RESERVED) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_UNSUPPORTED) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_WORK_GROUP_SIZE_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_REGISTER_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_VENDOR_UNSUPPORTED) | \ + KFD_EC_MASK(EC_QUEUE_PREEMPTION_ERROR) | \ + KFD_EC_MASK(EC_QUEUE_NEW)) +#define KFD_EC_MASK_DEVICE (KFD_EC_MASK(EC_DEVICE_QUEUE_DELETE) | \ + KFD_EC_MASK(EC_DEVICE_RAS_ERROR) | \ + KFD_EC_MASK(EC_DEVICE_FATAL_HALT) | \ + KFD_EC_MASK(EC_DEVICE_MEMORY_VIOLATION) | \ + KFD_EC_MASK(EC_DEVICE_NEW)) +#define KFD_EC_MASK_PROCESS (KFD_EC_MASK(EC_PROCESS_RUNTIME) | \ + KFD_EC_MASK(EC_PROCESS_DEVICE_REMOVE)) +#define KFD_EC_MASK_PACKET (KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_DIM_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_GROUP_SEGMENT_SIZE_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_CODE_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_RESERVED) | 
\ + KFD_EC_MASK(EC_QUEUE_PACKET_UNSUPPORTED) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_WORK_GROUP_SIZE_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_DISPATCH_REGISTER_INVALID) | \ + KFD_EC_MASK(EC_QUEUE_PACKET_VENDOR_UNSUPPORTED)) + +/* Checks for exception code types for KFD search */ +#define KFD_DBG_EC_IS_VALID(ecode) (ecode > EC_NONE && ecode < EC_MAX) +#define KFD_DBG_EC_TYPE_IS_QUEUE(ecode) \ + (KFD_DBG_EC_IS_VALID(ecode) && !!(KFD_EC_MASK(ecode) & KFD_EC_MASK_QUEUE)) +#define KFD_DBG_EC_TYPE_IS_DEVICE(ecode) \ + (KFD_DBG_EC_IS_VALID(ecode) && !!(KFD_EC_MASK(ecode) & KFD_EC_MASK_DEVICE)) +#define KFD_DBG_EC_TYPE_IS_PROCESS(ecode) \ + (KFD_DBG_EC_IS_VALID(ecode) && !!(KFD_EC_MASK(ecode) & KFD_EC_MASK_PROCESS)) +#define KFD_DBG_EC_TYPE_IS_PACKET(ecode) \ + (KFD_DBG_EC_IS_VALID(ecode) && !!(KFD_EC_MASK(ecode) & KFD_EC_MASK_PACKET)) + + +/* Runtime enable states */ +enum kfd_dbg_runtime_state { + DEBUG_RUNTIME_STATE_DISABLED = 0, + DEBUG_RUNTIME_STATE_ENABLED = 1, + DEBUG_RUNTIME_STATE_ENABLED_BUSY = 2, + DEBUG_RUNTIME_STATE_ENABLED_ERROR = 3 +}; + +/* Runtime enable status */ +struct kfd_runtime_info { + __u64 r_debug; + __u32 runtime_state; + __u32 ttmp_setup; +}; + +/* Enable modes for runtime enable */ +#define KFD_RUNTIME_ENABLE_MODE_ENABLE_MASK 1 +#define KFD_RUNTIME_ENABLE_MODE_TTMP_SAVE_MASK 2 + +/** + * kfd_ioctl_runtime_enable_args - Arguments for runtime enable + * + * Coordinates debug exception signalling and debug device enablement with runtime. + * + * @r_debug - pointer to user struct for sharing information between ROCr and the debuggger + * @mode_mask - mask to set mode + * KFD_RUNTIME_ENABLE_MODE_ENABLE_MASK - enable runtime for debugging, otherwise disable + * KFD_RUNTIME_ENABLE_MODE_TTMP_SAVE_MASK - enable trap temporary setup (ignore on disable) + * @capabilities_mask - mask to notify runtime on what KFD supports + * + * Return - 0 on SUCCESS. + * - EBUSY if runtime enable call already pending. + * - EEXIST if user queues already active prior to call. + * If process is debug enabled, runtime enable will enable debug devices and + * wait for debugger process to send runtime exception EC_PROCESS_RUNTIME + * to unblock - see kfd_ioctl_dbg_trap_args. + * + */ +struct kfd_ioctl_runtime_enable_args { + __u64 r_debug; + __u32 mode_mask; + __u32 capabilities_mask; +}; + +/* Queue information */ +struct kfd_queue_snapshot_entry { + __u64 exception_status; + __u64 ring_base_address; + __u64 write_pointer_address; + __u64 read_pointer_address; + __u64 ctx_save_restore_address; + __u32 queue_id; + __u32 gpu_id; + __u32 ring_size; + __u32 queue_type; + __u32 ctx_save_restore_area_size; + __u32 reserved; +}; + +/* Queue status return for suspend/resume */ +#define KFD_DBG_QUEUE_ERROR_BIT 30 +#define KFD_DBG_QUEUE_INVALID_BIT 31 +#define KFD_DBG_QUEUE_ERROR_MASK (1 << KFD_DBG_QUEUE_ERROR_BIT) +#define KFD_DBG_QUEUE_INVALID_MASK (1 << KFD_DBG_QUEUE_INVALID_BIT) + +/* Context save area header information */ +struct kfd_context_save_area_header { + struct { + __u32 control_stack_offset; + __u32 control_stack_size; + __u32 wave_state_offset; + __u32 wave_state_size; + } wave_state; + __u32 debug_offset; + __u32 debug_size; + __u64 err_payload_addr; + __u32 err_event_id; + __u32 reserved1; +}; + +/* + * Debug operations + * + * For specifics on usage and return values, see documentation per operation + * below. Otherwise, generic error returns apply: + * - ESRCH if the process to debug does not exist. 
+ * + * - EINVAL (with KFD_IOC_DBG_TRAP_ENABLE exempt) if operation + * KFD_IOC_DBG_TRAP_ENABLE has not succeeded prior. + * Also returns this error if GPU hardware scheduling is not supported. + * + * - EPERM (with KFD_IOC_DBG_TRAP_DISABLE exempt) if target process is not + * PTRACE_ATTACHED. KFD_IOC_DBG_TRAP_DISABLE is exempt to allow + * clean up of debug mode as long as process is debug enabled. + * + * - EACCES if any DBG_HW_OP (debug hardware operation) is requested when + * AMDKFD_IOC_RUNTIME_ENABLE has not succeeded prior. + * + * - ENODEV if any GPU does not support debugging on a DBG_HW_OP call. + * + * - Other errors may be returned when a DBG_HW_OP occurs while the GPU + * is in a fatal state. + * + */ +enum kfd_dbg_trap_operations { + KFD_IOC_DBG_TRAP_ENABLE = 0, + KFD_IOC_DBG_TRAP_DISABLE = 1, + KFD_IOC_DBG_TRAP_SEND_RUNTIME_EVENT = 2, + KFD_IOC_DBG_TRAP_SET_EXCEPTIONS_ENABLED = 3, + KFD_IOC_DBG_TRAP_SET_WAVE_LAUNCH_OVERRIDE = 4, /* DBG_HW_OP */ + KFD_IOC_DBG_TRAP_SET_WAVE_LAUNCH_MODE = 5, /* DBG_HW_OP */ + KFD_IOC_DBG_TRAP_SUSPEND_QUEUES = 6, /* DBG_HW_OP */ + KFD_IOC_DBG_TRAP_RESUME_QUEUES = 7, /* DBG_HW_OP */ + KFD_IOC_DBG_TRAP_SET_NODE_ADDRESS_WATCH = 8, /* DBG_HW_OP */ + KFD_IOC_DBG_TRAP_CLEAR_NODE_ADDRESS_WATCH = 9, /* DBG_HW_OP */ + KFD_IOC_DBG_TRAP_SET_FLAGS = 10, + KFD_IOC_DBG_TRAP_QUERY_DEBUG_EVENT = 11, + KFD_IOC_DBG_TRAP_QUERY_EXCEPTION_INFO = 12, + KFD_IOC_DBG_TRAP_GET_QUEUE_SNAPSHOT = 13, + KFD_IOC_DBG_TRAP_GET_DEVICE_SNAPSHOT = 14 +}; + +/** + * kfd_ioctl_dbg_trap_enable_args + * + * Arguments for KFD_IOC_DBG_TRAP_ENABLE. + * + * Enables debug session for target process. Call @op KFD_IOC_DBG_TRAP_DISABLE in + * kfd_ioctl_dbg_trap_args to disable debug session. + * + * @exception_mask (IN) - exceptions to raise to the debugger + * @rinfo_ptr (IN) - pointer to runtime info buffer (see kfd_runtime_info) + * @rinfo_size (IN/OUT) - size of runtime info buffer in bytes + * @dbg_fd (IN) - fd the KFD will nofify the debugger with of raised + * exceptions set in exception_mask. + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * Copies KFD saved kfd_runtime_info to @rinfo_ptr on enable. + * Size of kfd_runtime saved by the KFD returned to @rinfo_size. + * - EBADF if KFD cannot get a reference to dbg_fd. + * - EFAULT if KFD cannot copy runtime info to rinfo_ptr. + * - EINVAL if target process is already debug enabled. + * + */ +struct kfd_ioctl_dbg_trap_enable_args { + __u64 exception_mask; + __u64 rinfo_ptr; + __u32 rinfo_size; + __u32 dbg_fd; +}; + +/** + * kfd_ioctl_dbg_trap_send_runtime_event_args + * + * + * Arguments for KFD_IOC_DBG_TRAP_SEND_RUNTIME_EVENT. + * Raises exceptions to runtime. + * + * @exception_mask (IN) - exceptions to raise to runtime + * @gpu_id (IN) - target device id + * @queue_id (IN) - target queue id + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * - ENODEV if gpu_id not found. + * If exception_mask contains EC_PROCESS_RUNTIME, unblocks pending + * AMDKFD_IOC_RUNTIME_ENABLE call - see kfd_ioctl_runtime_enable_args. + * All other exceptions are raised to runtime through err_payload_addr. + * See kfd_context_save_area_header. + */ +struct kfd_ioctl_dbg_trap_send_runtime_event_args { + __u64 exception_mask; + __u32 gpu_id; + __u32 queue_id; +}; + +/** + * kfd_ioctl_dbg_trap_set_exceptions_enabled_args + * + * Arguments for KFD_IOC_SET_EXCEPTIONS_ENABLED + * Set new exceptions to be raised to the debugger. 
+ * + * @exception_mask (IN) - new exceptions to raise the debugger + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + */ +struct kfd_ioctl_dbg_trap_set_exceptions_enabled_args { + __u64 exception_mask; +}; + +/** + * kfd_ioctl_dbg_trap_set_wave_launch_override_args + * + * Arguments for KFD_IOC_DBG_TRAP_SET_WAVE_LAUNCH_OVERRIDE + * Enable HW exceptions to raise trap. + * + * @override_mode (IN) - see kfd_dbg_trap_override_mode + * @enable_mask (IN/OUT) - reference kfd_dbg_trap_mask. + * IN is the override modes requested to be enabled. + * OUT is referenced in Return below. + * @support_request_mask (IN/OUT) - reference kfd_dbg_trap_mask. + * IN is the override modes requested for support check. + * OUT is referenced in Return below. + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * Previous enablement is returned in @enable_mask. + * Actual override support is returned in @support_request_mask. + * - EINVAL if override mode is not supported. + * - EACCES if trap support requested is not actually supported. + * i.e. enable_mask (IN) is not a subset of support_request_mask (OUT). + * Otherwise it is considered a generic error (see kfd_dbg_trap_operations). + */ +struct kfd_ioctl_dbg_trap_set_wave_launch_override_args { + __u32 override_mode; + __u32 enable_mask; + __u32 support_request_mask; + __u32 pad; +}; + +/** + * kfd_ioctl_dbg_trap_set_wave_launch_mode_args + * + * Arguments for KFD_IOC_DBG_TRAP_SET_WAVE_LAUNCH_MODE + * Set wave launch mode. + * + * @mode (IN) - see kfd_dbg_trap_wave_launch_mode + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + */ +struct kfd_ioctl_dbg_trap_set_wave_launch_mode_args { + __u32 launch_mode; + __u32 pad; +}; + +/** + * kfd_ioctl_dbg_trap_suspend_queues_ags + * + * Arguments for KFD_IOC_DBG_TRAP_SUSPEND_QUEUES + * Suspend queues. + * + * @exception_mask (IN) - raised exceptions to clear + * @queue_array_ptr (IN) - pointer to array of queue ids (u32 per queue id) + * to suspend + * @num_queues (IN) - number of queues to suspend in @queue_array_ptr + * @grace_period (IN) - wave time allowance before preemption + * per 1K GPU clock cycle unit + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Destruction of a suspended queue is blocked until the queue is + * resumed. This allows the debugger to access queue information and + * the its context save area without running into a race condition on + * queue destruction. + * Automatically copies per queue context save area header information + * into the save area base + * (see kfd_queue_snapshot_entry and kfd_context_save_area_header). + * + * Return - Number of queues suspended on SUCCESS. + * . KFD_DBG_QUEUE_ERROR_MASK and KFD_DBG_QUEUE_INVALID_MASK masked + * for each queue id in @queue_array_ptr array reports unsuccessful + * suspend reason. + * KFD_DBG_QUEUE_ERROR_MASK = HW failure. + * KFD_DBG_QUEUE_INVALID_MASK = queue does not exist, is new or + * is being destroyed. + */ +struct kfd_ioctl_dbg_trap_suspend_queues_args { + __u64 exception_mask; + __u64 queue_array_ptr; + __u32 num_queues; + __u32 grace_period; +}; + +/** + * kfd_ioctl_dbg_trap_resume_queues_args + * + * Arguments for KFD_IOC_DBG_TRAP_RESUME_QUEUES + * Resume queues. + * + * @queue_array_ptr (IN) - pointer to array of queue ids (u32 per queue id) + * to resume + * @num_queues (IN) - number of queues to resume in @queue_array_ptr + * + * Generic errors apply (see kfd_dbg_trap_operations). 
+ * Return - Number of queues resumed on SUCCESS. + * KFD_DBG_QUEUE_ERROR_MASK and KFD_DBG_QUEUE_INVALID_MASK mask + * for each queue id in @queue_array_ptr array reports unsuccessful + * resume reason. + * KFD_DBG_QUEUE_ERROR_MASK = HW failure. + * KFD_DBG_QUEUE_INVALID_MASK = queue does not exist. + */ +struct kfd_ioctl_dbg_trap_resume_queues_args { + __u64 queue_array_ptr; + __u32 num_queues; + __u32 pad; +}; + +/** + * kfd_ioctl_dbg_trap_set_node_address_watch_args + * + * Arguments for KFD_IOC_DBG_TRAP_SET_NODE_ADDRESS_WATCH + * Sets address watch for device. + * + * @address (IN) - watch address to set + * @mode (IN) - see kfd_dbg_trap_address_watch_mode + * @mask (IN) - watch address mask + * @gpu_id (IN) - target gpu to set watch point + * @id (OUT) - watch id allocated + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * Allocated watch ID returned to @id. + * - ENODEV if gpu_id not found. + * - ENOMEM if watch IDs can be allocated + */ +struct kfd_ioctl_dbg_trap_set_node_address_watch_args { + __u64 address; + __u32 mode; + __u32 mask; + __u32 gpu_id; + __u32 id; +}; + +/** + * kfd_ioctl_dbg_trap_clear_node_address_watch_args + * + * Arguments for KFD_IOC_DBG_TRAP_CLEAR_NODE_ADDRESS_WATCH + * Clear address watch for device. + * + * @gpu_id (IN) - target device to clear watch point + * @id (IN) - allocated watch id to clear + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * - ENODEV if gpu_id not found. + * - EINVAL if watch ID has not been allocated. + */ +struct kfd_ioctl_dbg_trap_clear_node_address_watch_args { + __u32 gpu_id; + __u32 id; +}; + +/** + * kfd_ioctl_dbg_trap_set_flags_args + * + * Arguments for KFD_IOC_DBG_TRAP_SET_FLAGS + * Sets flags for wave behaviour. + * + * @flags (IN/OUT) - IN = flags to enable, OUT = flags previously enabled + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * - EACCESS if any debug device does not allow flag options. + */ +struct kfd_ioctl_dbg_trap_set_flags_args { + __u32 flags; + __u32 pad; +}; + +/** + * kfd_ioctl_dbg_trap_query_debug_event_args + * + * Arguments for KFD_IOC_DBG_TRAP_QUERY_DEBUG_EVENT + * + * Find one or more raised exceptions. This function can return multiple + * exceptions from a single queue or a single device with one call. To find + * all raised exceptions, this function must be called repeatedly until it + * returns -EAGAIN. Returned exceptions can optionally be cleared by + * setting the corresponding bit in the @exception_mask input parameter. + * However, clearing an exception prevents retrieving further information + * about it with KFD_IOC_DBG_TRAP_QUERY_EXCEPTION_INFO. + * + * @exception_mask (IN/OUT) - exception to clear (IN) and raised (OUT) + * @gpu_id (OUT) - gpu id of exceptions raised + * @queue_id (OUT) - queue id of exceptions raised + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on raised exception found + * Raised exceptions found are returned in @exception mask + * with reported source id returned in @gpu_id or @queue_id. + * - EAGAIN if no raised exception has been found + */ +struct kfd_ioctl_dbg_trap_query_debug_event_args { + __u64 exception_mask; + __u32 gpu_id; + __u32 queue_id; +}; + +/** + * kfd_ioctl_dbg_trap_query_exception_info_args + * + * Arguments KFD_IOC_DBG_TRAP_QUERY_EXCEPTION_INFO + * Get additional info on raised exception. 
+ * + * @info_ptr (IN) - pointer to exception info buffer to copy to + * @info_size (IN/OUT) - exception info buffer size (bytes) + * @source_id (IN) - target gpu or queue id + * @exception_code (IN) - target exception + * @clear_exception (IN) - clear raised @exception_code exception + * (0 = false, 1 = true) + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * If @exception_code is EC_DEVICE_MEMORY_VIOLATION, copy @info_size(OUT) + * bytes of memory exception data to @info_ptr. + * If @exception_code is EC_PROCESS_RUNTIME, copy saved + * kfd_runtime_info to @info_ptr. + * Actual required @info_ptr size (bytes) is returned in @info_size. + */ +struct kfd_ioctl_dbg_trap_query_exception_info_args { + __u64 info_ptr; + __u32 info_size; + __u32 source_id; + __u32 exception_code; + __u32 clear_exception; +}; + +/** + * kfd_ioctl_dbg_trap_get_queue_snapshot_args + * + * Arguments KFD_IOC_DBG_TRAP_GET_QUEUE_SNAPSHOT + * Get queue information. + * + * @exception_mask (IN) - exceptions raised to clear + * @snapshot_buf_ptr (IN) - queue snapshot entry buffer (see kfd_queue_snapshot_entry) + * @num_queues (IN/OUT) - number of queue snapshot entries + * The debugger specifies the size of the array allocated in @num_queues. + * KFD returns the number of queues that actually existed. If this is + * larger than the size specified by the debugger, KFD will not overflow + * the array allocated by the debugger. + * + * @entry_size (IN/OUT) - size per entry in bytes + * The debugger specifies sizeof(struct kfd_queue_snapshot_entry) in + * @entry_size. KFD returns the number of bytes actually populated per + * entry. The debugger should use the KFD_IOCTL_MINOR_VERSION to determine, + * which fields in struct kfd_queue_snapshot_entry are valid. This allows + * growing the ABI in a backwards compatible manner. + * Note that entry_size(IN) should still be used to stride the snapshot buffer in the + * event that it's larger than actual kfd_queue_snapshot_entry. + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * Copies @num_queues(IN) queue snapshot entries of size @entry_size(IN) + * into @snapshot_buf_ptr if @num_queues(IN) > 0. + * Otherwise return @num_queues(OUT) queue snapshot entries that exist. + */ +struct kfd_ioctl_dbg_trap_queue_snapshot_args { + __u64 exception_mask; + __u64 snapshot_buf_ptr; + __u32 num_queues; + __u32 entry_size; +}; + +/** + * kfd_ioctl_dbg_trap_get_device_snapshot_args + * + * Arguments for KFD_IOC_DBG_TRAP_GET_DEVICE_SNAPSHOT + * Get device information. + * + * @exception_mask (IN) - exceptions raised to clear + * @snapshot_buf_ptr (IN) - pointer to snapshot buffer (see kfd_dbg_device_info_entry) + * @num_devices (IN/OUT) - number of debug devices to snapshot + * The debugger specifies the size of the array allocated in @num_devices. + * KFD returns the number of devices that actually existed. If this is + * larger than the size specified by the debugger, KFD will not overflow + * the array allocated by the debugger. + * + * @entry_size (IN/OUT) - size per entry in bytes + * The debugger specifies sizeof(struct kfd_dbg_device_info_entry) in + * @entry_size. KFD returns the number of bytes actually populated. The + * debugger should use KFD_IOCTL_MINOR_VERSION to determine, which fields + * in struct kfd_dbg_device_info_entry are valid. This allows growing the + * ABI in a backwards compatible manner. 
+ * Note that entry_size(IN) should still be used to stride the snapshot buffer in the + * event that it's larger than actual kfd_dbg_device_info_entry. + * + * Generic errors apply (see kfd_dbg_trap_operations). + * Return - 0 on SUCCESS. + * Copies @num_devices(IN) device snapshot entries of size @entry_size(IN) + * into @snapshot_buf_ptr if @num_devices(IN) > 0. + * Otherwise return @num_devices(OUT) queue snapshot entries that exist. + */ +struct kfd_ioctl_dbg_trap_device_snapshot_args { + __u64 exception_mask; + __u64 snapshot_buf_ptr; + __u32 num_devices; + __u32 entry_size; +}; + +/** + * kfd_ioctl_dbg_trap_args + * + * Arguments to debug target process. + * + * @pid - target process to debug + * @op - debug operation (see kfd_dbg_trap_operations) + * + * @op determines which union struct args to use. + * Refer to kern docs for each kfd_ioctl_dbg_trap_*_args struct. + */ +struct kfd_ioctl_dbg_trap_args { + __u32 pid; + __u32 op; + + union { + struct kfd_ioctl_dbg_trap_enable_args enable; + struct kfd_ioctl_dbg_trap_send_runtime_event_args send_runtime_event; + struct kfd_ioctl_dbg_trap_set_exceptions_enabled_args set_exceptions_enabled; + struct kfd_ioctl_dbg_trap_set_wave_launch_override_args launch_override; + struct kfd_ioctl_dbg_trap_set_wave_launch_mode_args launch_mode; + struct kfd_ioctl_dbg_trap_suspend_queues_args suspend_queues; + struct kfd_ioctl_dbg_trap_resume_queues_args resume_queues; + struct kfd_ioctl_dbg_trap_set_node_address_watch_args set_node_address_watch; + struct kfd_ioctl_dbg_trap_clear_node_address_watch_args clear_node_address_watch; + struct kfd_ioctl_dbg_trap_set_flags_args set_flags; + struct kfd_ioctl_dbg_trap_query_debug_event_args query_debug_event; + struct kfd_ioctl_dbg_trap_query_exception_info_args query_exception_info; + struct kfd_ioctl_dbg_trap_queue_snapshot_args queue_snapshot; + struct kfd_ioctl_dbg_trap_device_snapshot_args device_snapshot; + }; +}; + +/** + * kfd_ioctl_pc_sample_op - PC Sampling ioctl operations + * + * @KFD_IOCTL_PCS_OP_QUERY_CAPABILITIES: Query device PC Sampling capabilities + * @KFD_IOCTL_PCS_OP_CREATE: Register this process with a per-device PC sampler instance + * @KFD_IOCTL_PCS_OP_DESTROY: Unregister from a previously registered PC sampler instance + * @KFD_IOCTL_PCS_OP_START: Process begins taking samples from a previously registered PC sampler instance + * @KFD_IOCTL_PCS_OP_STOP: Process stops taking samples from a previously registered PC sampler instance + */ +enum kfd_ioctl_pc_sample_op { + KFD_IOCTL_PCS_OP_QUERY_CAPABILITIES, + KFD_IOCTL_PCS_OP_CREATE, + KFD_IOCTL_PCS_OP_DESTROY, + KFD_IOCTL_PCS_OP_START, + KFD_IOCTL_PCS_OP_STOP, +}; + +/* Values have to be a power of 2*/ +#define KFD_IOCTL_PCS_FLAG_POWER_OF_2 0x00000001 + +enum kfd_ioctl_pc_sample_method { + KFD_IOCTL_PCS_METHOD_HOSTTRAP = 1, + KFD_IOCTL_PCS_METHOD_STOCHASTIC, +}; + +enum kfd_ioctl_pc_sample_type { + KFD_IOCTL_PCS_TYPE_TIME_US, + KFD_IOCTL_PCS_TYPE_CLOCK_CYCLES, + KFD_IOCTL_PCS_TYPE_INSTRUCTIONS +}; + +struct kfd_pc_sample_info { + __u64 interval; /* [IN] if PCS_TYPE_INTERVAL_US: sample interval in us + * if PCS_TYPE_CLOCK_CYCLES: sample interval in graphics core clk cycles + * if PCS_TYPE_INSTRUCTIONS: sample interval in instructions issued by + * graphics compute units + */ + __u64 interval_min; /* [OUT] */ + __u64 interval_max; /* [OUT] */ + __u64 flags; /* [OUT] indicate potential restrictions e.g FLAG_POWER_OF_2 */ + __u32 method; /* [IN/OUT] kfd_ioctl_pc_sample_method */ + __u32 type; /* [IN/OUT] kfd_ioctl_pc_sample_type */ +}; + 
+#define KFD_IOCTL_PCS_QUERY_TYPE_FULL (1 << 0) /* If not set, return current */ + +struct kfd_ioctl_pc_sample_args { + __u64 sample_info_ptr; /* array of kfd_pc_sample_info */ + __u32 num_sample_info; + __u32 op; /* kfd_ioctl_pc_sample_op */ + __u32 gpu_id; + __u32 trace_id; + __u32 flags; /* kfd_ioctl_pcs_query flags */ + __u32 version; +}; + +#define KFD_IOC_PROFILER_VERSION_NUM 1 +enum kfd_profiler_ops { + KFD_IOC_PROFILER_PMC = 0, + KFD_IOC_PROFILER_PC_SAMPLE = 1, + KFD_IOC_PROFILER_VERSION = 2, +}; + +/** + * Enables/Disables GPU Specific profiler settings + */ +struct kfd_ioctl_pmc_settings { + __u32 gpu_id; /* This is the user_gpu_id */ + __u32 lock; /* Lock GPU for Profiling */ + __u32 perfcount_enable; /* Force Perfcount Enable for queues on GPU */ +}; + +struct kfd_ioctl_profiler_args { + __u32 op; /* kfd_profiler_op */ + union { + struct kfd_ioctl_pc_sample_args pc_sample; + struct kfd_ioctl_pmc_settings pmc; + __u32 version; /* KFD_IOC_PROFILER_VERSION_NUM */ + }; +}; + #define AMDKFD_IOCTL_BASE 'K' #define AMDKFD_IO(nr) _IO(AMDKFD_IOCTL_BASE, nr) #define AMDKFD_IOR(nr, type) _IOR(AMDKFD_IOCTL_BASE, nr, type) @@ -679,16 +1780,16 @@ struct kfd_ioctl_set_xnack_mode_args { #define AMDKFD_IOC_WAIT_EVENTS \ AMDKFD_IOWR(0x0C, struct kfd_ioctl_wait_events_args) -#define AMDKFD_IOC_DBG_REGISTER \ +#define AMDKFD_IOC_DBG_REGISTER_DEPRECATED \ AMDKFD_IOW(0x0D, struct kfd_ioctl_dbg_register_args) -#define AMDKFD_IOC_DBG_UNREGISTER \ +#define AMDKFD_IOC_DBG_UNREGISTER_DEPRECATED \ AMDKFD_IOW(0x0E, struct kfd_ioctl_dbg_unregister_args) -#define AMDKFD_IOC_DBG_ADDRESS_WATCH \ +#define AMDKFD_IOC_DBG_ADDRESS_WATCH_DEPRECATED \ AMDKFD_IOW(0x0F, struct kfd_ioctl_dbg_address_watch_args) -#define AMDKFD_IOC_DBG_WAVE_CONTROL \ +#define AMDKFD_IOC_DBG_WAVE_CONTROL_DEPRECATED \ AMDKFD_IOW(0x10, struct kfd_ioctl_dbg_wave_control_args) #define AMDKFD_IOC_SET_SCRATCH_BACKING_VA \ @@ -742,7 +1843,47 @@ struct kfd_ioctl_set_xnack_mode_args { #define AMDKFD_IOC_SET_XNACK_MODE \ AMDKFD_IOWR(0x21, struct kfd_ioctl_set_xnack_mode_args) +#define AMDKFD_IOC_CRIU_OP \ + AMDKFD_IOWR(0x22, struct kfd_ioctl_criu_args) + +#define AMDKFD_IOC_AVAILABLE_MEMORY \ + AMDKFD_IOWR(0x23, struct kfd_ioctl_get_available_memory_args) + +#define AMDKFD_IOC_EXPORT_DMABUF \ + AMDKFD_IOWR(0x24, struct kfd_ioctl_export_dmabuf_args) + +#define AMDKFD_IOC_RUNTIME_ENABLE \ + AMDKFD_IOWR(0x25, struct kfd_ioctl_runtime_enable_args) + +#define AMDKFD_IOC_DBG_TRAP \ + AMDKFD_IOWR(0x26, struct kfd_ioctl_dbg_trap_args) + #define AMDKFD_COMMAND_START 0x01 -#define AMDKFD_COMMAND_END 0x22 +#define AMDKFD_COMMAND_END 0x27 + +/* non-upstream ioctls */ +#define AMDKFD_IOC_IPC_IMPORT_HANDLE \ + AMDKFD_IOWR(0x80, struct kfd_ioctl_ipc_import_handle_args) + +#define AMDKFD_IOC_IPC_EXPORT_HANDLE \ + AMDKFD_IOWR(0x81, struct kfd_ioctl_ipc_export_handle_args) + +#define AMDKFD_IOC_DBG_TRAP_DEPRECATED \ + AMDKFD_IOWR(0x82, struct kfd_ioctl_dbg_trap_args_deprecated) + +#define AMDKFD_IOC_CROSS_MEMORY_COPY_DEPRECATED \ + AMDKFD_IOWR(0x83, struct kfd_ioctl_cross_memory_copy_deprecated_args) + +#define AMDKFD_IOC_RLC_SPM \ + AMDKFD_IOWR(0x84, struct kfd_ioctl_spm_args) + +#define AMDKFD_IOC_PC_SAMPLE \ + AMDKFD_IOWR(0x85, struct kfd_ioctl_pc_sample_args) + +#define AMDKFD_IOC_PROFILER \ + AMDKFD_IOWR(0x86, struct kfd_ioctl_profiler_args) + +#define AMDKFD_COMMAND_START_2 0x80 +#define AMDKFD_COMMAND_END_2 0x87 #endif diff --git a/test/mockgpu/amd/amddriver.py b/test/mockgpu/amd/amddriver.py index 9933fbba24..69dcd01bd1 100644 --- 
a/test/mockgpu/amd/amddriver.py +++ b/test/mockgpu/amd/amddriver.py @@ -17,7 +17,7 @@ def ioctls_from_header(): pattern = r'#define\s+(AMDKFD_IOC_[A-Z0-9_]+)\s+AMDKFD_(IOW?R?)\((0x[0-9a-fA-F]+),\s+struct\s([A-Za-z0-9_]+)\)' matches = re.findall(pattern, hdr, re.MULTILINE) return type("KFD_IOCTLS", (object, ), {name: int(nr, 0x10) for name, _, nr, _ in matches}), \ - {int(nr, 0x10): getattr(kfd, "struct_"+sname) for name, idir, nr, sname in matches} + {int(nr, 0x10): getattr(kfd, "struct_"+sname, None) for name, idir, nr, sname in matches} kfd_ioctls, kfd_headers = ioctls_from_header() class KFDFileDesc(VirtFileDesc): @@ -115,6 +115,10 @@ class AMDDriver(VirtDriver): struct = kfd_headers[nr].from_address(argp) if nr == kfd_ioctls.AMDKFD_IOC_ACQUIRE_VM: pass + elif nr == kfd_ioctls.AMDKFD_IOC_RUNTIME_ENABLE: pass + elif nr == kfd_ioctls.AMDKFD_IOC_GET_VERSION: + struct.major_version = 1 + struct.minor_version = 14 elif nr == kfd_ioctls.AMDKFD_IOC_ALLOC_MEMORY_OF_GPU: if struct.gpu_id not in self.gpus: return -1 struct.handle = self._alloc_handle() diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index b6b1776730..41756849c2 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -557,7 +557,9 @@ class KFDIface: for i in FileIOInterface(f'{ip_base}/{hw}').listdir()} for ip,hw in ip_hw } self.drm_fd = FileIOInterface(f"/dev/dri/renderD{self.props['drm_render_minor']}", os.O_RDWR) + self.kfd_ver = ((ver_st:=kfd.AMDKFD_IOC_GET_VERSION(KFDIface.kfd)).major_version, ver_st.minor_version) kfd.AMDKFD_IOC_ACQUIRE_VM(KFDIface.kfd, drm_fd=self.drm_fd.fd, gpu_id=self.gpu_id) + if self.kfd_ver >= (1,14): kfd.AMDKFD_IOC_RUNTIME_ENABLE(KFDIface.kfd, mode_mask=0) # Set these for our device. if KFDIface.event_page is None:
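
Note on patch 1 (use recursive_property in UOp device): a functools.cached_property that recurses into self.src consumes one Python stack frame per graph level, so the new test_nested case (a = a+a repeated 10,000 times) overflows the default recursion limit; recursive_property computes the value without deep recursion. The sketch below shows the general idea only, using a hypothetical Node class and an explicit work stack; it is not tinygrad's actual recursive_property implementation.

import functools

class Node:
  def __init__(self, arg=None, srcs=()): self.arg, self.srcs = arg, srcs

  @functools.cached_property
  def device_recursive(self):
    # naive version: one stack frame per level, RecursionError on a 10k-deep chain
    if self.arg is not None: return self.arg
    for s in self.srcs:
      if (d := s.device_recursive) is not None: return d
    return None

  @property
  def device_iterative(self):
    # same result computed with an explicit stack, so depth is bounded by the heap, not the call stack
    cache, stack = {}, [self]
    while stack:
      n = stack.pop()
      if id(n) in cache: continue
      if missing := [s for s in n.srcs if id(s) not in cache]:
        stack.append(n); stack.extend(missing)
      else:
        cache[id(n)] = n.arg if n.arg is not None else next((cache[id(s)] for s in n.srcs if cache[id(s)] is not None), None)
    return cache[id(self)]

a = Node(arg="GPU")
for _ in range(10_000): a = Node(srcs=(a, a))
print(a.device_iterative)   # GPU
# a.device_recursive would raise RecursionError here at the default recursion limit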
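
Note on patch 6: both ioctl-table helpers (extra/hip_gpu_driver/hip_ioctl.py and test/mockgpu/amd/amddriver.py) change getattr(kfd, "struct_"+sname) to getattr(kfd, "struct_"+sname, None) because the rocm7 kfd_ioctl.h defines ioctls whose argument structs may not exist in the generated ctypes bindings. A self-contained sketch of that parsing pattern follows; SAMPLE_HDR and fake_kfd are made-up stand-ins for the real header and autogenerated module.

import re
from types import SimpleNamespace

# hypothetical stand-ins for kfd_ioctl.h and the autogenerated ctypes bindings
SAMPLE_HDR = """
#define AMDKFD_IOC_GET_VERSION    AMDKFD_IOR(0x01, struct kfd_ioctl_get_version_args)
#define AMDKFD_IOC_RUNTIME_ENABLE AMDKFD_IOWR(0x25, struct kfd_ioctl_runtime_enable_args)
"""
fake_kfd = SimpleNamespace(struct_kfd_ioctl_get_version_args=object)  # runtime_enable struct not generated

def ioctls_from_header(hdr, mod):
  pattern = r'#define\s+(AMDKFD_IOC_[A-Z0-9_]+)\s+AMDKFD_IOW?R?\((0x[0-9a-fA-F]+),\s+struct\s([A-Za-z0-9_]+)\)'
  # getattr(..., None) keeps the table usable when a struct type is absent from the bindings
  return {int(nr, 16): (name, getattr(mod, "struct_"+sname, None)) for name, nr, sname in re.findall(pattern, hdr, re.MULTILINE)}

table = ioctls_from_header(SAMPLE_HDR, fake_kfd)
print(table[0x01])  # ('AMDKFD_IOC_GET_VERSION', <class 'object'>)
print(table[0x25])  # ('AMDKFD_IOC_RUNTIME_ENABLE', None)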