mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
nv: generic mmu impl (#11179)
This commit is contained in:
@@ -36,23 +36,23 @@ class NVPageTableEntry:
|
||||
|
||||
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
|
||||
if not table:
|
||||
x = self.nvdev.NV_MMU_VER2_PTE.encode(valid=valid, address_sys=paddr >> 12, aperture=2 if system else 0, vol=uncached, kind=6)
|
||||
elif self._is_dual_pde(): # Dual PDE
|
||||
x = self.nvdev.NV_MMU_VER2_DUAL_PDE.encode(is_pte=False, address_small_sys=paddr >> 12, aperture_small=1 if valid else 0, vol_small=0, no_ats=1)
|
||||
x = self.nvdev.pte_t.encode(valid=valid, address_sys=paddr >> 12, aperture=2 if system else 0, kind=6,
|
||||
**({'pcf': int(uncached)} if self.nvdev.mmu_ver == 3 else {'vol': uncached}))
|
||||
else:
|
||||
x = self.nvdev.NV_MMU_VER2_PDE.encode(is_pte=False, address_sys=paddr >> 12, aperture=1 if valid else 0, vol=0, no_ats=1)
|
||||
pde = self.nvdev.dual_pde_t if self._is_dual_pde() else self.nvdev.pde_t
|
||||
small, sys = ("_small" if self._is_dual_pde() else ""), "" if self.nvdev.mmu_ver == 3 else "_sys"
|
||||
x = pde.encode(is_pte=False, **{f'aperture{small}': 1 if valid else 0, f'address{small}{sys}': paddr >> 12},
|
||||
**({f'pcf{small}': 0b10} if self.nvdev.mmu_ver == 3 else {'no_ats': 1}))
|
||||
|
||||
if self._is_dual_pde():
|
||||
self.entries[2*entry_id] = x & 0xffffffffffffffff
|
||||
self.entries[2*entry_id+1] = x >> 64
|
||||
if self._is_dual_pde(): self.entries[2*entry_id], self.entries[2*entry_id+1] = x & 0xffffffffffffffff, x >> 64
|
||||
else: self.entries[entry_id] = x
|
||||
|
||||
def entry(self, entry_id:int) -> int:
|
||||
return (self.entries[2*entry_id+1]<<64) | self.entries[2*entry_id] if self._is_dual_pde() else self.entries[entry_id]
|
||||
|
||||
def read_fields(self, entry_id:int) -> dict:
|
||||
if self.is_huge_page(entry_id): return self.nvdev.NV_MMU_VER2_PTE.decode(self.entry(entry_id))
|
||||
return (self.nvdev.NV_MMU_VER2_DUAL_PDE if self._is_dual_pde() else self.nvdev.NV_MMU_VER2_PDE).decode(self.entry(entry_id))
|
||||
if self.is_huge_page(entry_id): return self.nvdev.pte_t.decode(self.entry(entry_id))
|
||||
return (self.nvdev.dual_pde_t if self._is_dual_pde() else self.nvdev.pde_t).decode(self.entry(entry_id))
|
||||
|
||||
def is_huge_page(self, entry_id) -> bool: return (self.entry(entry_id) & 1 == 1) if self.lv < self.nvdev.mm.level_cnt - 1 else True
|
||||
def supports_huge_page(self, paddr:int): return self.lv >= self.nvdev.mm.level_cnt - 3 and paddr % self.nvdev.mm.pte_covers[self.lv] == 0
|
||||
@@ -61,7 +61,9 @@ class NVPageTableEntry:
|
||||
if self.is_huge_page(entry_id): return self.read_fields(entry_id)['valid']
|
||||
return self.read_fields(entry_id)['aperture_small' if self._is_dual_pde() else 'aperture'] != 0
|
||||
|
||||
def address(self, entry_id:int) -> int: return self.read_fields(entry_id)['address_small_sys' if self._is_dual_pde() else 'address_sys'] << 12
|
||||
def address(self, entry_id:int) -> int:
|
||||
small, sys = ("_small" if self._is_dual_pde() else ""), "_sys" if self.nvdev.mmu_ver == 2 or self.lv == self.nvdev.mm.level_cnt - 1 else ""
|
||||
return self.read_fields(entry_id)[f'address{small}{sys}'] << 12
|
||||
|
||||
class NVMemoryManager(MemoryManager):
|
||||
va_allocator = TLSFAllocator((1 << 44), base=1 << 30) # global for all devices.
|
||||
@@ -108,7 +110,8 @@ class NVDev(PCIDevImplBase):
|
||||
self.include("src/common/inc/swref/published/nv_ref.h")
|
||||
self.chip_id = self.reg("NV_PMC_BOOT_0").read()
|
||||
self.chip_details = self.reg("NV_PMC_BOOT_42").read_bitfields()
|
||||
self.chip_name = {0x17: "GA", 0x19: "AD"}[self.chip_details['architecture']] + str(100+self.chip_details['implementation'])
|
||||
self.chip_name = {0x17: "GA1", 0x19: "AD1", 0x1b: "GB2"}[self.chip_details['architecture']] + f"{self.chip_details['implementation']:02d}"
|
||||
self.mmu_ver = 3 if self.chip_details['architecture'] >= 0x1a else 2
|
||||
|
||||
self.include("src/common/inc/swref/published/turing/tu102/dev_fb.h")
|
||||
if self.reg("NV_PFB_PRI_MMU_WPR2_ADDR_HI").read() != 0:
|
||||
@@ -121,9 +124,10 @@ class NVDev(PCIDevImplBase):
|
||||
self.include("src/common/inc/swref/published/ampere/ga102/dev_gc6_island_addendum.h")
|
||||
|
||||
# MMU Init
|
||||
self.reg_names.update(['NV_MMU_VER2_PTE', 'NV_MMU_VER2_PDE', 'NV_MMU_VER2_DUAL_PDE'])
|
||||
for name in ['NV_MMU_VER2_PTE', 'NV_MMU_VER2_PDE', 'NV_MMU_VER2_DUAL_PDE']: self.__dict__[name] = NVReg(self, None, None, fields={})
|
||||
self.include("kernel-open/nvidia-uvm/hwref/turing/tu102/dev_mmu.h")
|
||||
self.reg_names.update(mmu_pd_names:=[f'NV_MMU_VER{self.mmu_ver}_PTE', f'NV_MMU_VER{self.mmu_ver}_PDE', f'NV_MMU_VER{self.mmu_ver}_DUAL_PDE'])
|
||||
for name in mmu_pd_names: self.__dict__[name] = NVReg(self, None, None, fields={})
|
||||
self.include(f"kernel-open/nvidia-uvm/hwref/{'hopper/gh100' if self.mmu_ver == 3 else 'turing/tu102'}/dev_mmu.h")
|
||||
self.pte_t, self.pde_t, self.dual_pde_t = tuple([self.__dict__[name] for name in mmu_pd_names])
|
||||
|
||||
self.vram_size = self.reg("NV_PGC6_AON_SECURE_SCRATCH_GROUP_42").read() << 20
|
||||
|
||||
|
||||
Reference in New Issue
Block a user