From 08ca871d77ae52ca46c588357c93e4240a8c8992 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 20 Jan 2025 18:05:22 +0300 Subject: [PATCH 01/18] am: remove pm block (#8688) * am: remove pm block * hm * oops --- test/external/external_test_am.py | 3 ++ tinygrad/runtime/support/am/amdev.py | 38 ++++++-------- tinygrad/runtime/support/am/ip.py | 76 ++++++++++++++-------------- 3 files changed, 58 insertions(+), 59 deletions(-) diff --git a/test/external/external_test_am.py b/test/external/external_test_am.py index ad5159d4ef..985554623e 100644 --- a/test/external/external_test_am.py +++ b/test/external/external_test_am.py @@ -1,5 +1,6 @@ import unittest from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext +from tinygrad.helpers import mv_address class FakeGMC: def __init__(self): self.vm_base = 0x0 @@ -19,6 +20,8 @@ class FakeAM: self.gmc = FakeGMC() self.mm = AMMemoryManager(self, vram_size=4 << 30) self.is_booting = False + def paddr2cpu(self, paddr:int) -> int: return paddr + mv_address(self.vram) + def paddr2mc(self, paddr:int) -> int: return paddr # * PTE format: # * 63:59 reserved diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index b87ff02666..2397ae200f 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -98,20 +98,14 @@ class AMFirmware: def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size]) -class AMPhysicalMemoryBlock: - def __init__(self, adev:AMDev, paddr:int, size:int): self.adev, self.paddr, self.size = adev, paddr, size - def mc_addr(self): return self.adev.gmc.mc_base + self.paddr - def cpu_addr(self): return mv_address(self.adev.vram) + self.paddr - def cpu_view(self): return to_mv(self.cpu_addr(), self.size) - @dataclasses.dataclass(frozen=True) class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702 class AMPageTableEntry: - def __init__(self, pm, lv): self.pm, self.view, self.lv = pm, pm.cpu_view().cast('Q'), lv + def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv def set_table(self, entry_id, pte:AMPageTableEntry, valid=True): - self.view[entry_id] = (pte.pm.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0) + self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0) def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True): f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \ @@ -133,11 +127,11 @@ class AMPageTableTraverseContext: def level_down(self): pt, pte_idx, _ = self.pt_stack[-1] if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID: - assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.pm.paddr:#x}, {pte_idx=} {entry=:#x}" - child_page_table = AMPageTableEntry(AMPhysicalMemoryBlock(pt.pm.adev, entry & 0x0000FFFFFFFFF000, 0x1000), lv=pt.lv+1) + assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}" + child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1) else: assert self.create_pts, "Not allowed to create new page table" - pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1)) + pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1)) self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table))) return self.pt_stack[-1] @@ -145,7 +139,7 @@ class AMPageTableTraverseContext: def _try_free_pt(self) -> bool: pt, _, _ = self.pt_stack[-1] if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)): - self.adev.mm.pfree(AMPhysicalMemoryBlock(self.adev, pt.pm.paddr, 0x1000)) + self.adev.mm.pfree(pt.paddr) parent_pt, parent_pte_idx, _ = self.pt_stack[-2] parent_pt.set_page(parent_pte_idx, 0x0, valid=False) return True @@ -179,7 +173,7 @@ class AMMemoryManager: self.adev, self.vram_size = adev, vram_size self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device - self.root_page_table = AMPageTableEntry(self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1) + self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1) def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping: assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}" @@ -213,12 +207,12 @@ class AMMemoryManager: # Alloc physical memory and map it to the virtual address va = self.alloc_vaddr(size, align) - if contigous: paddrs = [(self.palloc(size, zero=True).paddr, size)] + if contigous: paddrs = [(self.palloc(size, zero=True), size)] else: paddrs = [] try: ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True) - for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False).paddr, seg_size) for _ in range(seg_cnt)] + for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)] except MemoryError: for paddr, _ in paddrs: self.pa_allocator.free(paddr) raise @@ -230,13 +224,13 @@ class AMMemoryManager: self.va_allocator.free(vm.va_addr) for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr) - def palloc(self, size, align=0x1000, zero=True, boot=False) -> AMPhysicalMemoryBlock: + def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int: assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated" - pm = AMPhysicalMemoryBlock(self.adev, (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align), size) - if zero: ctypes.memset(pm.cpu_addr(), 0, pm.size) - return pm + paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align) + if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size) + return paddr - def pfree(self, pm:AMPhysicalMemoryBlock): self.pa_allocator.free(pm.paddr) + def pfree(self, paddr:int): self.pa_allocator.free(paddr) class AMDev: def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): @@ -309,6 +303,7 @@ class AMDev: for ip in [self.sdma, self.gfx]: ip.fini() def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr + def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg] @@ -348,9 +343,8 @@ class AMDev: # The table is located at the end of VRAM - 64KB and is 10KB in size. mmRCC_CONFIG_MEMSIZE = 0xde3 self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20 - self.discovery_pm = AMPhysicalMemoryBlock(self, self.vram_size - (64 << 10), 10 << 10) - bhdr = am.struct_binary_header.from_address(self.discovery_pm.cpu_addr()) + bhdr = am.struct_binary_header.from_address(self.paddr2cpu(self.vram_size - (64 << 10))) ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset) assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X}" diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 3c831511b5..79c4de0a24 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -25,8 +25,8 @@ class AM_GMC(AM_IP): self.vm_base = self.adev.mm.va_allocator.base self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1 - self.memscratch_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) - self.dummy_page_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) + self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) + self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) self.hub_initted = {"MM": False, "GC": False} def init(self): self.init_hub("MM") @@ -55,7 +55,7 @@ class AM_GMC(AM_IP): def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid): self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12) self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12) - self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.pm.paddr | 1) + self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1) self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv)) def init_hub(self, ip:Literal["MM", "GC"]): @@ -66,8 +66,8 @@ class AM_GMC(AM_IP): self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18) self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18) - self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_pm.paddr >> 12) - self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_pm.paddr >> 12) + self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12) + self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12) self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1) @@ -232,27 +232,28 @@ class AM_GFX(AM_IP): class AM_IH(AM_IP): def __init__(self, adev): super().__init__(adev) - self.rings = [(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0), - (self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)] + self.ring_size = 512 << 10 + self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0), + (self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)] def interrupt_handler(self): - ring_vm, rwptr_vm, suf, _ = self.rings[0] - wptr = to_mv(rwptr_vm.cpu_addr(), 8).cast('Q')[0] + _, rwptr_vm, suf, _ = self.rings[0] + wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0] if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1): self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0) self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1) self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0) - self.adev.regIH_RB_RPTR.write(wptr % ring_vm.size) + self.adev.regIH_RB_RPTR.write(wptr % self.ring_size) def init(self): for ring_vm, rwptr_vm, suf, ring_id in self.rings: - self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", ring_vm.mc_addr() >> 8) + self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", self.adev.paddr2mc(ring_vm) >> 8) - self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(ring_vm.size//4).bit_length(), + self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(self.ring_size//4).bit_length(), mc_snoop=1, mc_ro=0, mc_vmid=0, **({'wptr_overflow_enable': 1, 'rptr_rearm': 1} if ring_id == 0 else {'rb_full_drain_enable': 1})) - if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", rwptr_vm.mc_addr()) + if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", self.adev.paddr2mc(rwptr_vm)) self.adev.reg(f"regIH_RB_WPTR{suf}").write(0) self.adev.reg(f"regIH_RB_RPTR{suf}").write(0) @@ -303,10 +304,12 @@ class AM_PSP(AM_IP): def __init__(self, adev): super().__init__(adev) - self.msg1_pm = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True) - self.cmd_pm = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) - self.fence_pm = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) - self.ring_pm = self.adev.mm.palloc(0x10000, zero=not self.adev.partial_boot, boot=True) + self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True) + self.cmd_paddr = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) + self.fence_paddr = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) + + self.ring_size = 0x10000 + self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True) def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0 def init(self): @@ -332,8 +335,8 @@ class AM_PSP(AM_IP): def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000) def _prep_msg1(self, data): - ctypes.memset(self.msg1_pm.cpu_addr(), 0, self.msg1_pm.size) - self.msg1_pm.cpu_view()[:len(data)] = data + ctypes.memset(cpu_addr:=self.adev.paddr2cpu(self.msg1_paddr), 0, am.PSP_1_MEG) + to_mv(cpu_addr, len(data))[:] = data self.adev.gmc.flush_hdp() def _bootloader_load_component(self, fw, compid): @@ -342,7 +345,7 @@ class AM_PSP(AM_IP): self._wait_for_bootloader() self._prep_msg1(self.adev.fw.sos_fw[fw]) - self.adev.regMP0_SMN_C2PMSG_36.write(self.msg1_pm.mc_addr() >> 20) + self.adev.regMP0_SMN_C2PMSG_36.write(self.adev.paddr2mc(self.msg1_paddr) >> 20) self.adev.regMP0_SMN_C2PMSG_35.write(compid) return self._wait_for_bootloader() @@ -350,16 +353,15 @@ class AM_PSP(AM_IP): def _tmr_init(self): # Load TOC and calculate TMR size self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC]) - resp = self._load_toc_cmd(len(fwm)) - - self.tmr_pm = self.adev.mm.palloc(resp.resp.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True) + self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size + self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True) def _ring_create(self): # Wait until the sOS is ready self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000) - self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.ring_pm.mc_addr()) - self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_pm.size) + self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.adev.paddr2mc(self.ring_paddr)) + self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_size) self.adev.regMP0_SMN_C2PMSG_64.write(am.PSP_RING_TYPE__KM << 16) # There might be handshake issue with hardware which needs delay @@ -369,28 +371,28 @@ class AM_PSP(AM_IP): def _ring_submit(self): prev_wptr = self.adev.regMP0_SMN_C2PMSG_67.read() - ring_entry_addr = self.ring_pm.cpu_addr() + prev_wptr * 4 + ring_entry_addr = self.adev.paddr2cpu(self.ring_paddr) + prev_wptr * 4 ctypes.memset(ring_entry_addr, 0, ctypes.sizeof(am.struct_psp_gfx_rb_frame)) write_loc = am.struct_psp_gfx_rb_frame.from_address(ring_entry_addr) - write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.cmd_pm.mc_addr()) - write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.fence_pm.mc_addr()) + write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.adev.paddr2mc(self.cmd_paddr)) + write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.adev.paddr2mc(self.fence_paddr)) write_loc.fence_value = prev_wptr # Move the wptr self.adev.regMP0_SMN_C2PMSG_67.write(prev_wptr + ctypes.sizeof(am.struct_psp_gfx_rb_frame) // 4) - while self.fence_pm.cpu_view().cast('I')[0] != prev_wptr: pass + while to_mv(self.adev.paddr2cpu(self.fence_paddr), 4).cast('I')[0] != prev_wptr: pass time.sleep(0.005) - resp = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr()) + resp = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr)) if resp.resp.status != 0: raise RuntimeError(f"PSP command failed {resp.cmd_id} {resp.resp.status}") return resp def _prep_ring_cmd(self, hdr): - ctypes.memset(self.cmd_pm.cpu_addr(), 0, 0x1000) - cmd = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr()) + ctypes.memset(self.adev.paddr2cpu(self.cmd_paddr), 0, 0x1000) + cmd = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr)) cmd.cmd_id = hdr return cmd @@ -400,22 +402,22 @@ class AM_PSP(AM_IP): self._prep_msg1(fw_bytes) cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW) - cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.msg1_pm.mc_addr()) + cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr)) cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes) cmd.cmd.cmd_load_ip_fw.fw_type = fw_type return self._ring_submit() def _tmr_load_cmd(self): cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR) - cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.tmr_pm.mc_addr()) - cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_pm.paddr) + cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr)) + cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr) cmd.cmd.cmd_setup_tmr.bitfield.virt_phy_addr = 1 - cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_pm.size + cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size return self._ring_submit() def _load_toc_cmd(self, toc_size): cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_TOC) - cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.msg1_pm.mc_addr()) + cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr)) cmd.cmd.cmd_load_toc.toc_size = toc_size return self._ring_submit() From 679b1ad0589a5bdb6e2823113c4146dcc1f0db23 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 20 Jan 2025 12:16:32 -0500 Subject: [PATCH 02/18] move softmax upcast to after subtracting max (#8684) * move softmax upcast to after subtracting max max can always be done in the same dtype without any numerical loss, so this is better when explicitly upcasting in softmax * skipUnless half --- test/test_schedule.py | 20 ++++++++++++++++++++ tinygrad/tensor.py | 4 ++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index f4c9240e04..79fc2516ac 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -972,6 +972,26 @@ class TestSchedule(unittest.TestCase): expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True) np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4) + @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") + def test_softmax_upcast(self): + # input half, softmax in float + Tensor.manual_seed(0) + x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize() + out = x.softmax(dtype=dtypes.float) + sched = out.schedule() + self.assertEqual(len(sched), 3) + self.assertEqual(len(sched[0].outputs), 1) + self.assertEqual(sched[0].outputs[0].dtype, dtypes.half) + + # input float, softmax in float + Tensor.manual_seed(0) + x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize() + out = x.softmax(dtype=dtypes.float) + sched = out.schedule() + self.assertEqual(len(sched), 3) + self.assertEqual(len(sched[0].outputs), 1) + self.assertEqual(sched[0].outputs[0].dtype, dtypes.float) + def test_softmax_backward(self): Tensor.manual_seed(0) x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5fbdbfd018..45a9b23749 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1856,8 +1856,8 @@ class Tensor(SimpleMathTrait): return self.std(axis, keepdim, correction), self.mean(axis, keepdim) def _softmax(self, axis, dtype:Optional[DTypeLike]=None): - x = self.cast(dtype) if dtype is not None else self - m = x - x.max(axis=axis, keepdim=True).detach() + m = self - self.max(axis=axis, keepdim=True).detach() + if dtype is not None: m = m.cast(dtype) e = m.exp() return m, e, e.sum(axis=axis, keepdim=True) From 46a8c5e1e58fe82da08d4b580140d9d9b6e7e4f1 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 20 Jan 2025 09:40:36 -0800 Subject: [PATCH 03/18] delete forced_realize (#8615) * delete forced_realize * put that back * expectedFailures * cleaner create_subbuffer * more comments --------- Co-authored-by: qazal Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com> --- .github/workflows/test.yml | 2 +- test/test_arange.py | 2 +- test/test_const_folding.py | 4 +++- test/test_jit.py | 1 + test/test_linearizer.py | 7 ++++++- test/test_multitensor.py | 2 +- test/test_setitem.py | 3 ++- test/unit/test_gradient.py | 3 ++- tinygrad/engine/schedule.py | 7 +++---- tinygrad/ops.py | 9 +-------- 10 files changed, 21 insertions(+), 19 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 47cffd0090..be82930b62 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -298,7 +298,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model kernel count and gate usage run: | - PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx + PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - if: ${{ matrix.task == 'optimage' }} name: Test openpilot alt model correctness (float32) run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx diff --git a/test/test_arange.py b/test/test_arange.py index a5c8b535bb..07512ae1b6 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -166,7 +166,7 @@ class TestIndexing(unittest.TestCase): GlobalCounters.reset() z = emb(x).realize() self.assertLessEqual(GlobalCounters.global_ops, op_limit) - self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.kernel_count, 3) if getenv("CHECK", 1): import torch with torch.no_grad(): diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 4ca2359912..dfffca8989 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -220,7 +220,9 @@ class TestMultiConstFolding(unittest.TestCase): t = Tensor.arange(16).float().realize().to(ds) # non const folding case creates one ast on each shard - _check_ast_count(4, t + 1) + # NOTE: there's extra contiguous kernels here since it's realizing both the CONTIGUOUS and its parent COPY + # why does multi call contiguous on a COPY? + _check_ast_count(7, t + 1) _check_ast_count(4, 1 + t) _check_ast_count(4, t * 2) _check_ast_count(4, 2 * t) diff --git a/test/test_jit.py b/test/test_jit.py index 382c83a52a..7abb13100f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -318,6 +318,7 @@ class TestJit(unittest.TestCase): assert len(res3) == 10, "All values should be different, rand works in jit." assert res3 != res2, "Jit rand is diff with diff seeds" + @unittest.expectedFailure # requires contiguous folding def test_jit_random_after_unrealized_random(self): @TinyJit def f(): return Tensor.rand() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 58bf3d4e13..da2995d37d 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -63,7 +63,11 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): - a, b = Tensor.randn(4), Tensor.randn(4) + # NOTE: this realize exists because Tensor.numpy calls .contiguous() internally + # without contiguous folding, rand.to("CLANG") and rand.contiguous().to("CLANG") are different UOps. + # this test asserts they are the identical Buffer + # having different buffers is fine for correctness, because the outputs match. + a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize() np_a, np_b = a.numpy(), b.numpy() c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))) lowered = list(lower_schedule(c.schedule())) @@ -1690,6 +1694,7 @@ class TestHandCodedOpts(unittest.TestCase): # should upcast the two Tensor.stacks assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 + @unittest.expectedFailure # requires contiguous folding def test_masked_upcast_wino_full(self): with Context(WINO=1): x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 34a3480c0d..b34baced75 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -734,7 +734,7 @@ class TestMultiTensor(unittest.TestCase): zeros = Tensor.zeros(3).realize() b = a.to(devices_2)*zeros.to(devices_2) sched = b.schedule() - self.assertEqual(len(sched), 6) + self.assertEqual(len(sched), 8) # notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2) # all these kernels are just because multi calls contiguous on every single shard diff --git a/test/test_setitem.py b/test/test_setitem.py index f1bb595ef2..5c7c14fb57 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -69,7 +69,8 @@ class TestSetitem(unittest.TestCase): t[1] ^= 5 np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]]) - @unittest.expectedFailure + #@unittest.expectedFailure + # update: passing after delete_forced_realize def test_setitem_consecutive_inplace_operator(self): t = Tensor.arange(4).reshape(2, 2).contiguous() t[1] += 2 diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index b36b81f243..a9a41eace0 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -104,7 +104,8 @@ class TestRealizeMeansRealize(unittest.TestCase): x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize() self.assertEqual(x.lazydata.op, Ops.VIEW) - @unittest.expectedFailure + #@unittest.expectedFailure + # update: passing after delete_forced_realize def test_uniform_realizes(self): x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize() print(x.lazydata) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2375408dd2..9cbc371158 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -109,7 +109,7 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: # track the underlying tensor uop for this buffer ctx.tensor_uops[buf_uop] = [buf] # (early) bufferize - cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st) + cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret # **** AST graph rewrite @@ -329,7 +329,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: # maybe fuse arange with its children for rbuf in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf} - if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue + if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}} if len(kernel_children) == 0: continue for tr in group: del ctx.realizes[tr] @@ -448,8 +448,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) return x.view(unwrap(view.st)) def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): - if not root.device.startswith("DISK"): return None - if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone + if not b.device.startswith("DISK"): return None buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 29fe063540..639d16819f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -233,7 +233,6 @@ class UOpMetaClass(type): # some uops map to other stuff buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary() -forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet() # NOTE: this should be frozen, but frozen is slower @dataclass(eq=False, slots=True) @@ -409,11 +408,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}") return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) - def contiguous(self): - if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST: - return self.alu(Ops.CONTIGUOUS) - forced_realize.add(self.base) - return self + def contiguous(self): return self.alu(Ops.CONTIGUOUS) # *** from LazyBuffer *** @@ -443,8 +438,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def lbs(self): return [self] @property def metadata(self): return all_metadata.get(self, None) - @property - def forced_realize(self): return self in forced_realize # *** uop movement ops *** From dd82b4c913ac9c89e8c0a2eeb66f8c0e9d92a964 Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Tue, 21 Jan 2025 02:11:05 +0800 Subject: [PATCH 04/18] make onnx runner a class (#8647) * this * clean up * more clean ups and improve debug msg * more correct training toggler * remove manual training toggling * change some variable names * actually just add the training toggle for LIMIT envvar too * more refinement * __call__ and OnnxRunner * fix half pylint, other half is importing from onnx while this file is onnx.py, figure out later * ahhhh found another mistake * remove limit from __call__ --------- Co-authored-by: chenyu --- examples/benchmark_onnx.py | 6 +- examples/compile_tensorflow.py | 4 +- examples/openpilot/compile3.py | 7 +- examples/yolov8-onnx.py | 4 +- extra/onnx.py | 290 +++++++++--------- extra/onnx_ops.py | 7 +- test/external/external_benchmark_openpilot.py | 7 +- test/external/external_model_benchmark.py | 6 +- test/external/external_test_onnx_backend.py | 6 +- test/models/test_onnx.py | 8 +- 10 files changed, 172 insertions(+), 173 deletions(-) diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py index 092a03cd13..498e626aa6 100644 --- a/examples/benchmark_onnx.py +++ b/examples/benchmark_onnx.py @@ -1,14 +1,12 @@ import sys, onnx, time from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch from tinygrad.tensor import _from_np_dtype -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner def load_onnx_model(fn): onnx_file = fetch(fn) onnx_model = onnx.load(onnx_file) - Tensor.no_grad = True - Tensor.training = False - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) # find preinitted tensors and ignore them initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer} diff --git a/examples/compile_tensorflow.py b/examples/compile_tensorflow.py index e3def3bbdb..7733934880 100644 --- a/examples/compile_tensorflow.py +++ b/examples/compile_tensorflow.py @@ -8,7 +8,7 @@ import numpy as np import subprocess import tensorflow as tf import tf2onnx -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.tensor import Tensor from extra.export_model import export_model_clang, compile_net, jit_model @@ -25,7 +25,7 @@ class TinyOnnx: def __init__(self, keras_model): input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')] onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13) - self.run_onnx = get_run_onnx(onnx_model) + self.run_onnx = OnnxRunner(onnx_model) def forward(self, x): return self.run_onnx({"x": x}, debug=False)['predictions'] diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index f168a36b45..3a0f4c8628 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -12,17 +12,14 @@ from tinygrad.engine.realize import CompiledRunner import onnx from onnx.helper import tensor_dtype_to_np_dtype -from extra.onnx import get_run_onnx # TODO: port to main tinygrad +from extra.onnx import OnnxRunner # TODO: port to main tinygrad OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx" OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl" def compile(onnx_file): onnx_model = onnx.load(onnx_file) - Tensor.no_grad = True - Tensor.training = False - - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) print("loaded model") input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} diff --git a/examples/yolov8-onnx.py b/examples/yolov8-onnx.py index f75b5cb333..9c440b5afa 100644 --- a/examples/yolov8-onnx.py +++ b/examples/yolov8-onnx.py @@ -3,7 +3,7 @@ import os from ultralytics import YOLO import onnx from pathlib import Path -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.tensor import Tensor os.chdir("/tmp") @@ -14,5 +14,5 @@ onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb")) # TODO: move get example inputs to onnx input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} print(input_shapes) -run_onnx = get_run_onnx(onnx_model) +run_onnx = OnnxRunner(onnx_model) run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True) diff --git a/extra/onnx.py b/extra/onnx.py index 539151fb6e..fbf8e69904 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,53 +1,42 @@ from typing import Callable, Any, Sequence -import importlib, functools -import numpy as np -from tinygrad import Tensor, dtypes +import importlib, functools, dataclasses +from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, DEBUG, all_same -from tinygrad.dtype import DType, ConstType +from tinygrad.dtype import DType, ConstType, dtypes from tinygrad.device import is_dtype_supported -from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper -from google.protobuf.json_format import MessageToDict -cache_misses = 0 -@functools.lru_cache(None) -def _cached_to_python_const(t:Tensor): - if t.dtype is dtypes.uint8: return t.data().tobytes() - if 0 in t.shape: return [] - return t.tolist() +# ***** protobuf parsing ****** +from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper +import numpy as np -# Tensor -> python value cache for parameters -def to_python_const(t) -> list[ConstType]|ConstType|bytes: - if not isinstance(t, Tensor): return t - global cache_misses - ret = _cached_to_python_const(t) - if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3: - print(f"Cache miss for {t}") - cache_misses = info.misses - return ret - -# TODO: use real float16 -# src: onnx/mapping.py -DTYPE_MAP: dict[int, DType] = { - TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, - TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, - TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, - TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, -} def dtype_parse(onnx_dtype: int) -> DType: - if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported") - return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float + supported: dict[int, DType] = { + TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, + TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, + TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, + TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, + } + unsupported = { + TensorProto.UNDEFINED, TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128, TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E4M3FNUZ, + TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4 + } + if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported") + return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float -# src: onnx/onnx_ml_pb2.pyi -ATTRIBUTE_MAP: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = { - AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i), - AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t), - AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints), - AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings) -} def attribute_parse(onnx_attribute: AttributeProto): - if onnx_attribute.type not in ATTRIBUTE_MAP: + supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = { + AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i), + AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t), + AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints), + AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings) + } + unsupported = { + AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS, + AttributeProto.GRAPHS, AttributeProto.SPARSE_TENSORS, AttributeProto.TYPE_PROTOS + } + if onnx_attribute.type in unsupported: raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported") - return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute) + return supported[onnx_attribute.type](onnx_attribute) def buffer_parse(onnx_tensor: TensorProto) -> Tensor: if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.") @@ -62,116 +51,137 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor: return Tensor(np_buffer, dtype=dtype) return Tensor(None) -onnx_ops = importlib.import_module('extra.onnx_ops') -ONNXLIMIT = getenv("ONNXLIMIT", -1) -def get_run_onnx(onnx_model: ModelProto): - # model initialization data - model_tensors = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer} - model_expected_inputs = {inp.name:inp for inp in onnx_model.graph.input if inp.name not in model_tensors} - model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)} +def type_parse(onnx_type: TypeProto): + elem_type = onnx_type + if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"): + raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented") + if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type + if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type + if elem_type.HasField("tensor_type"): + shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim) + dtype = dtype_parse(elem_type.tensor_type.elem_type) + return OnnxValue(shape, dtype, is_optional, is_sequence) + raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}") - # model descriptions - # TODO: need a better way of controlling training vs non-training - is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node) - onnx_model_version = onnx_model.opset_import[0].version +# ***** onnx spec ***** +@dataclasses.dataclass(frozen=True) +class OnnxValue: + shape: tuple[str|int] + dtype: DType + is_optional: bool + is_sequence: bool - # used to check validity of user_input according to their dimension variables - variable_dims = {} +@dataclasses.dataclass(frozen=True) +class OnnxNode: + num: int + op: str + inputs: tuple[str] + outputs: tuple[str] + opts: dict[str, Any] - # mapping from onnx ops to tensor.py ops - tensor_methods = { - op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan", - "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", - "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod") - } +# ***** python const ***** +required_input_python_consts: dict[str, tuple[int, ...]] = { + "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,), + "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), + "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4), + **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")}, + **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")} +} - # these values are expected to be python consts - required_input_python_consts: dict[str, tuple[int, ...]] = { - "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,), - "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), - "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4), - **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")}, - **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")} - } +cache_misses = 0 +@functools.lru_cache(None) +def _cached_to_python_const(t:Tensor): + if t.dtype is dtypes.uint8: return t.data().tobytes() + if 0 in t.shape: return [] + return t.tolist() - # src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types - # parses and validates inputs based on their shape and dtype specified by model - def prepare_input(user_input:Any, model_input:ValueInfoProto): - type_proto = model_input.type - if type_proto.HasField("optional_type"): - if user_input is None: return None - type_proto = type_proto.optional_type.elem_type - if type_proto.HasField("sequence_type"): - if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type") - dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type) - sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input] - if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous") - # TODO: need true float16 for dtype checking - # if not all(t.dtype is dtype for t in sequence): - # raise RuntimeError(f"{model_input.name} has dtype mismatch for sequence type. Expected {dtype}, received {tensor.dtype}.") +# Tensor -> python value cache for parameters +def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes: + if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t + global cache_misses + ret = _cached_to_python_const(t) + if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3: + print(f"Cache miss for {t}") + cache_misses = info.misses + return ret + +# ***** runner ****** +debug = int(getenv("DEBUGONNX", "0")) +limit = int(getenv("ONNXLIMIT", "-1")) +class OnnxRunner: + def __init__(self, model: ModelProto): + # parse model protobuf + self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node) + self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad + Tensor.training = True if self.is_training else False + Tensor.no_grad = False if self.is_training else True + self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer} + self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values} + self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output} + self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute}) + for num,n in enumerate(model.graph.node)) + self.opset_version = model.opset_import[0].version + self.variable_dims: dict[str, int] = {} + + # TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import + # TODO: clean up opset stuff after moving extra.onnx_ops here + self.onnx_ops_module = importlib.import_module('extra.onnx_ops') + self.onnx_ops = { + **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", + "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", + "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")}, + } + + def _parse_input(self, name: str, value: Any, spec: OnnxValue): + if spec.is_optional and value is None: return None + # TODO: need true float16 for dtype checking + if spec.is_sequence: + if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type") + sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value] + if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous") return sequence - if type_proto.HasField("tensor_type"): - dtype = dtype_parse(type_proto.tensor_type.elem_type) - tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input - # TODO: need true float16 for dtype checking - # if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} has mismatch for dtype. Expected {dtype}, received {tensor.dtype}.") - for dim, onnx_dim in enumerate(type_proto.tensor_type.shape.dim): - dim_param, dim_value = onnx_dim.dim_param, onnx_dim.dim_value - user_dim_input = tensor.shape[dim] - if dim_param: dim_value = variable_dims[dim_param] if dim_param in variable_dims else variable_dims.setdefault(dim_param, user_dim_input) - if user_dim_input != dim_value: - raise RuntimeError(f"{model_input.name} has mismatch for dim={dim_param or dim}. Expected {dim_value}, received {user_dim_input}.") - return tensor - type_field_names = [field.name for field,_ in type_proto.ListFields()] - raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported") + tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value + for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)): + if isinstance(onnx_dim, str): + onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input)) + if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.") + return tensor - def run_onnx(inputs={}, debug=0): - debug = getenv("DEBUGONNX") or debug - if debug >= 3: print("Model initialization data:\n" + "\n".join(f"\t{i.name} - {model_tensors[i.name]}" for i in onnx_model.graph.initializer)) + def _dispatch_op(self, op, inps, opts): + if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts) + if hasattr(self.onnx_ops_module, op): + fxn = getattr(self.onnx_ops_module, op) + if isinstance(fxn, dict): + for k in sorted(fxn.keys()): + if k <= self.opset_version: + real_fxn = fxn[k] + else: real_fxn = fxn + return real_fxn(*inps, **opts) + raise NotImplementedError(f"{op=} not supported") - if debug >= 1: print("Model input:") - for name, value_info in model_expected_inputs.items(): + def __call__(self, inputs:dict[str, Any], debug=debug): + for name, input_spec in self.graph_inputs.items(): if name not in inputs: raise RuntimeError(f"Please provide input data for {name}") - model_tensors[name] = prepare_input(inputs[name], value_info) - if debug >= 1: print(f"\t{name} - {model_tensors[name]}") - if debug >= 2: print(f"\t\t{MessageToDict(value_info.type)}") + self.graph_values[name] = self._parse_input(name, inputs[name], input_spec) - for num,n in enumerate(onnx_model.graph.node): - inp_tensors = [model_tensors.get(x) for x in n.input] - required_consts = required_input_python_consts.get(n.op_type, ()) - inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)] - opt = model_attributes[num] + for node in self.graph_nodes: + inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)] + opts = node.opts - if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}") - if debug >= 3: - print("\tinputs:") - print("\n".join(f"\t\t{x} - {t!r}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp)))) + # provide additional opts + if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs) + if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values - # provide additional arguments - if n.op_type == "Split" and 'num_outputs' not in opt: opt['num_outputs'] = len(n.output) - if n.op_type == "Gradient": opt['intermediate_tensors'] = model_tensors + if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}") + if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps))) + ret = self._dispatch_op(node.op, inps, opts) + ret = ret if isinstance(ret, tuple) else (ret,) + if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret))) - # run op - if n.op_type in tensor_methods: ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt) - elif hasattr(onnx_ops, n.op_type): - fxn = getattr(onnx_ops, n.op_type) - if isinstance(fxn, dict): - for k in sorted(fxn.keys()): - if k <= onnx_model_version: - real_fxn = fxn[k] - else: - real_fxn = fxn - ret = real_fxn(*inp, **opt) - else: - print("UNSUPPORTED", n.op_type, n.input, n.output) - raise NotImplementedError(f"op_type {n.op_type} not supported") + self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True))) - # finalization after running the op - if not isinstance(ret, tuple): ret = (ret, ) - if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}") - for i in range(len(n.output)): model_tensors[n.output[i]] = ret[i] - if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output)))) - - if num == ONNXLIMIT: return {name:model_tensors[name] for name in n.output} - return {x.name:model_tensors[x.name] for x in onnx_model.graph.output} - return run_onnx + if node.num == limit: + Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad + return {name:self.graph_values[name] for name in node.outputs} + Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad + return {name:self.graph_values[name] for name in self.graph_outputs} \ No newline at end of file diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 4f01b3cb03..4f745e680b 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -3,7 +3,7 @@ from typing import cast, Literal from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr from tinygrad.dtype import ImageDType, dtypes from tinygrad.helpers import prod, flatten, make_tuple -from extra.onnx import dtype_parse, to_python_const +from extra.onnx import dtype_parse, _cached_to_python_const import numpy as np # **************** Free Ops **************** @@ -282,7 +282,7 @@ def Gather(x:Tensor, indices:Tensor, axis:int=0): x_sh = list(x.shape) ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] if indices.ndim > 1: indices = indices.flatten() - indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)] # type: ignore + indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot @@ -575,12 +575,9 @@ from tinygrad.nn.optim import SGD def onnx_training(input_group_size): def _decorator(func): def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): - old_training = Tensor.training - Tensor.training = True R = R.detach() groups = len(inputs) // input_group_size ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] - Tensor.training = old_training return tuple(flatten(zip(*ret))) return __wrapper return _decorator diff --git a/test/external/external_benchmark_openpilot.py b/test/external/external_benchmark_openpilot.py index 2811ec891c..780539b1e7 100644 --- a/test/external/external_benchmark_openpilot.py +++ b/test/external/external_benchmark_openpilot.py @@ -2,7 +2,7 @@ import time, sys, hashlib from pathlib import Path import onnx from onnx.helper import tensor_dtype_to_np_dtype -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad import Tensor, dtypes, TinyJit from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange from tinygrad.tensor import _from_np_dtype @@ -11,11 +11,8 @@ import numpy as np OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" if __name__ == "__main__": - Tensor.no_grad = True - Tensor.training = False - onnx_model = onnx.load(onnx_path := fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) Tensor.manual_seed(100) input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py index 3ed7b82746..4c0b720df2 100644 --- a/test/external/external_model_benchmark.py +++ b/test/external/external_model_benchmark.py @@ -6,7 +6,7 @@ import onnx from onnx.helper import tensor_dtype_to_np_dtype import onnxruntime as ort from onnx2torch import convert -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.helpers import OSX, DEBUG, fetch from tinygrad import Tensor, Device from tinygrad.device import CompileError @@ -65,7 +65,7 @@ def benchmark_model(m, devices, validate_outs=False): try: Device.DEFAULT = device inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = get_run_onnx(onnx_model) + tinygrad_model = OnnxRunner(onnx_model) benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()}) from tinygrad.engine.jit import TinyJit @@ -115,7 +115,7 @@ def benchmark_model(m, devices, validate_outs=False): rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models Device.DEFAULT = device inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = get_run_onnx(onnx_model) + tinygrad_model = OnnxRunner(onnx_model) tinygrad_out = tinygrad_model(inputs) ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"]) diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 7e93a3984e..b9a61b40f1 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported # pip3 install tabulate pytest_plugins = 'onnx.backend.test.report', -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner class TinygradModel(BackendRep): def __init__(self, run_onnx, input_names): @@ -20,7 +20,7 @@ class TinygradModel(BackendRep): def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]: real_inputs = dict(zip(self.input_names, inputs)) - ret = self.fxn(real_inputs, debug=True) + ret = self.fxn(real_inputs, debug=2) return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values()) class TinygradBackend(Backend): @@ -30,7 +30,7 @@ class TinygradBackend(Backend): input_initializer = [x.name for x in model.graph.initializer] net_feed_input = [x for x in input_all if x not in input_initializer] print("prepare", cls, device, net_feed_input) - run_onnx = get_run_onnx(model) + run_onnx = OnnxRunner(model) return TinygradModel(run_onnx, net_feed_input) @classmethod diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index 13252040e0..e3d1868aed 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -7,7 +7,7 @@ try: import onnx except ModuleNotFoundError: raise unittest.SkipTest("onnx not installed, skipping onnx test") -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.tensor import Tensor from tinygrad.helpers import CI, fetch, temp @@ -26,7 +26,7 @@ np.random.seed(1337) class TestOnnxModel(unittest.TestCase): def test_benchmark_openpilot_model(self): onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) def get_inputs(): np_inputs = { "input_imgs": np.random.randn(*(1, 12, 128, 256)), @@ -70,7 +70,7 @@ class TestOnnxModel(unittest.TestCase): def test_openpilot_model(self): onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) print("got run_onnx") inputs = { "input_imgs": np.random.randn(*(1, 12, 128, 256)), @@ -124,7 +124,7 @@ class TestOnnxModel(unittest.TestCase): onnx_model = onnx.load(fn) print("onnx loaded") from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) def run(img): inputs = {input_name: preprocess(img, new=input_new)} From 1a15c0e89df6558f57feb709bef985d2ad9463b5 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Mon, 20 Jan 2025 20:56:27 +0100 Subject: [PATCH 05/18] Move define_acc down an unrolled add chain (#8404) * Move define_acc down an unrolled add chain * Prevent possible infinite recursion * Add test * Fix typo in test * Move mulacc_unrolled to devoctorize + load_store_indexing pass * Add test for mulacc_unrolled by itself * undo formatter * import from ops, not rewriter * Add a const version --------- Co-authored-by: chenyu --- test/test_uops.py | 12 +++++++++++ test/unit/test_graph_rewrite.py | 38 ++++++++++++++++++++++++++++++++- tinygrad/codegen/rewriter.py | 8 +++++-- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index 99c75e68a3..b7392e46aa 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -14,6 +14,7 @@ from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_ker from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.rewriter import full_graph_rewrite, sym from tinygrad.device import is_dtype_supported +from tinygrad.codegen.kernel import Kernel, Opt, OptOps def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check) @@ -365,6 +366,17 @@ class TestAssembly(unittest.TestCase): self.assertIn(Ops.SHR, ops) self.assertIn(Ops.IDIV, ops) + def test_mulacc_unrolled(self): + # test that acc = acc + a0*b0 + a1*b1 + a2*b2 + a3*b3 + # is not acc = acc + (a0*b0 + a1*b1 + a2*b2 + a3*b3) + a = Tensor.empty(1024) + b = Tensor.empty(1024) + c = (a*b).sum() + k = Kernel(c.schedule()[-1].ast) + k.apply_opt(Opt(OptOps.UNROLL, 0, 4)) + uops = k.linearize().uops + self.assertEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4) + class TestUOpMethod(unittest.TestCase): @unittest.skip("uops lt no longer ordered") def test_compare_alu_same_src_different_arg(self): diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index 669d2e3319..86364e87ec 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -2,7 +2,7 @@ import unittest, math from tinygrad import dtypes from tinygrad.helpers import all_same from tinygrad.ops import GroupOp, UOp, Ops, exec_alu -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.rewriter import full_graph_rewrite, mulacc_unrolled # Helper function to apply the graph rewrite def apply_rewrite(expr): @@ -274,5 +274,41 @@ class TestSubstitute(unittest.TestCase): ret = substitute(ret, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}) self.assertIs(ret, a.sqrt().sqrt()) +class TestMulaccUnrolledAcc(unittest.TestCase): + def test_unrolled2(self): + acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)) + acc = UOp(Ops.DEFINE_ACC, dtypes.int, (UOp.const(dtypes.int, 0),) + acc_range, (0,)) + a = UOp.variable('a', 0, 10) + b = UOp.variable('b', 0, 10) + expr = acc.assign(acc + (a*2 + b*3)) + expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled) + self.assertIs(expr_with_mulacc, acc.assign(acc + a*2 + b*3)) + + def test_unrolled4_float(self): + acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3)) + acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,)) + + a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)] + b = [UOp.variable(f'b{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)] + + expr = acc.assign(acc + (a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3])) + expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled) + + # Verify it unrolls into individual multiply-accumulate operations + expected = acc.assign(acc + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]) + self.assertIs(expr_with_mulacc, expected) + + def test_unrolled4_float_const(self): + acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3)) + acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,)) + + a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)] + expr = acc.assign(acc + (a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0)) + expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled) + + # Verify it unrolls into individual multiply-accumulate operations + expected = acc.assign(acc + a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0) + self.assertIs(expr_with_mulacc, expected) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 3d45187cc8..b61de63329 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -239,6 +239,9 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld") arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug)) arange_m = ((arange_augrng UOp: # expand sink = graph_rewrite(sink, sym+expander) - # devectorize + load_store_indexing - sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing) + # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse + sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+ + mulacc_unrolled) # final rules for the renderer (without sym) sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher) From b14c9848cce127a7c2b3f88fcf53d6dd1b9a0fb3 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:25:59 -0500 Subject: [PATCH 06/18] small changes to make the tensor_map_simple diff cleaner [pr] (#8691) --- tinygrad/engine/schedule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9cbc371158..9db897446f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -522,10 +522,10 @@ remove_movement_ops = PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: + if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) # if using VIZ, do a graph rewrite to vizualize the Tensor graph if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext()) - if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) - # to_uop is removing (many) of the movement ops + # add BUFFER uops sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={}) # const folding and fusion sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx) @@ -538,8 +538,10 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu # preschedule realize groups prescheduled: list[ScheduleItem] = [] for store_uops in store_groups: - if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue - prescheduled.append(schedule_uop(UOp.sink(*stores), ctx)) + small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops]) + # TODO: this still exists because symbolic folding is happening after bufferization + if not all(x.op is Ops.STORE for x in small_sink.src): continue + prescheduled.append(schedule_uop(small_sink, ctx)) # can only schedule once for buf_uop in store_uops: for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st)) From 02ad450e22311d065715445412c7085c0c999b32 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:50:09 -0500 Subject: [PATCH 07/18] add failing assert for gradient realization [pr] (#8692) --- test/test_image_dtype.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 03ce8dac44..88a4c929c4 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -113,6 +113,7 @@ class TestImageDType(unittest.TestCase): assert it.lazydata.base.realized._buf != b1 # issue caused by: don't realize image to image casts. this is part of a larger problem + @unittest.expectedFailure def test_lil_model(self): with Context(IMAGE=2): x = Tensor.zeros(1, 1) @@ -121,7 +122,10 @@ class TestImageDType(unittest.TestCase): loss = x.image_dot(w1).image_dot(w2).float().max() loss.backward() sched = unwrap(w1.grad).schedule() - self.assertEqual(len(sched), 9) + # NOTE: the w1 grad must realize to a seperate kernel + assert w1.grad.lazydata.is_realized, f"never realized {w1.grad}" + self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32) + self.assertEqual(len(sched), 10) for s,ei in zip(sched, lower_schedule(sched[:])): ei.run() if s.outputs[0].dtype == dtypes.float: From 08eb1f1f56cc71279deed76abcbba9ed239fafbd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 20 Jan 2025 16:42:42 -0500 Subject: [PATCH 08/18] simplify tensors before scheduling [pr] (#8580) * delete forced_realize * put that back * work * remove forced_realize * expectedFailures * contiguous(buffer) * multi * expectedFailures * cleaner create_subbuffer * more comments * remove that * note * realizes * work * one upat and image is back * remove * cleaner * fix test_complex_backward for now --------- Co-authored-by: George Hotz --- test/test_image_dtype.py | 3 +- test/test_schedule.py | 6 +-- tinygrad/engine/schedule.py | 77 ++++++++++++------------------------- 3 files changed, 30 insertions(+), 56 deletions(-) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 88a4c929c4..62fcb4a443 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -113,7 +113,8 @@ class TestImageDType(unittest.TestCase): assert it.lazydata.base.realized._buf != b1 # issue caused by: don't realize image to image casts. this is part of a larger problem - @unittest.expectedFailure + #@unittest.expectedFailure + # update: passing after tensor_map def test_lil_model(self): with Context(IMAGE=2): x = Tensor.zeros(1, 1) diff --git a/test/test_schedule.py b/test/test_schedule.py index 79fc2516ac..b0426630de 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -220,7 +220,7 @@ class TestSchedule(unittest.TestCase): GlobalCounters.reset() expr = (a*b)/b expr.realize() - self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now! self.assertEqual(GlobalCounters.global_ops, 0) np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0)) @@ -229,7 +229,7 @@ class TestSchedule(unittest.TestCase): GlobalCounters.reset() expr = a/a expr.realize() - self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.kernel_count, 0) self.assertEqual(GlobalCounters.global_ops, 0) np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0)) @@ -2204,7 +2204,7 @@ class TestConst(unittest.TestCase): sched = add.schedule() self.assertEqual(len(sched), 0) # b+0 and b share the same underlying device memory - self.assertIs(add.lazydata.realized, b.lazydata.realized) + self.assertIs(add.lazydata.buffer, b.lazydata.buffer) self.assertListEqual(add.tolist(), [2, 2, 2, 2]) def test_src_masked_const_folding(self): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9db897446f..d5acb3cbe3 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,7 +2,7 @@ import sys, atexit, functools, pickle from collections import defaultdict, deque from dataclasses import dataclass, field from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views -from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify +from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar from tinygrad.dtype import DType, ImageDType, dtypes @@ -88,15 +88,15 @@ class ScheduleContext: # wrap tensor uops around a VIEW(BUFFER, ) # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it. -def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: +def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r # SINK is passthrough - if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src)) + if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) # skip creating buffers for CONST/BIND/DEVICE/BUFFER if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf # VIEW is passthrough if buf is not buf.base: - cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(unwrap(buf.st)) + cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st)) return ret # make things that can't be images not images dtype = buf.dtype @@ -105,9 +105,9 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: dtype = buf.dtype.base # ASSIGN already has a target buffer, otherwise we create a new one buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) - op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src)) + op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) # track the underlying tensor uop for this buffer - ctx.tensor_uops[buf_uop] = [buf] + ctx.tensor_uops[buf_uop] = tensor_map[buf] # (early) bufferize cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret @@ -358,10 +358,8 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: case _: return None return reduce.const_like(ret) -def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp): - if contig.src[0].op is Ops.VIEW and len(contig.src[0].src): - old_base = contig.src[0].src[0] - if old_base.op is Ops.VIEW and (sti:=unwrap(contig.src[0].st).invert(old_base.shape)) is not None: ctx.contiguous[old_base] = base.view(sti) +def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp): + if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti) def replace_contiguous(ctx:ScheduleContext, alu:UOp): new_src = list(alu.src) for i,s in enumerate(alu.src): @@ -372,8 +370,6 @@ ops_folding = symbolic_simple+PatternMatcher([ # op with size 0 is zero (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), - # if the uop folded to a CONST we can delete the BUFFER - (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)), # DETACH is a NOOP here (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]), # reduce of size 0 is the identity element @@ -386,13 +382,16 @@ ops_folding = symbolic_simple+PatternMatcher([ # no COPY to same device, except clone (arg is True) (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), + # remove cast to image when it's already a contiguous image + (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)), + lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), # remove contiguous if we can just view the buffer (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None), # double contiguous is one contiguous (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]), # support for using a contiguous permuted view instead of the parent view if one exists - (UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous), + (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), (UPat(GroupOp.ALU, name="alu"), replace_contiguous), # remove CONST/BIND/BUFFER/VIEW from SINK (UPat(Ops.SINK, name="root"), @@ -400,34 +399,6 @@ ops_folding = symbolic_simple+PatternMatcher([ if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), ]) -# ** buffer merging - -def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp: - assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}" - # if b2 is realized also realize b1 - if b2 in ctx.realizes: - ctx.realizes[b1] = b1 - del ctx.realizes[b2] - # ops referring to b2 now ref to b1 - ctx.tensor_uops[b1] += ctx.tensor_uops[b2] - del ctx.tensor_uops[b2] - # merge - return v1 - -def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp): - # early become - for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): ctx.becomes_map[luop] = b1.view(unwrap(luop.st)) - return v1 - -merge_bufs = PatternMatcher([ - # merge base - (UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))), merge), - (UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))), merge_realized), - # collapse view - (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))), lambda mv:mv), - (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))), lambda mv:mv), -]) - # ** this decides which ops get realized def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store @@ -481,7 +452,7 @@ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop())) def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp): - if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m + if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m if b not in ctx.realizes: return x # collapse BUFFER ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x) return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop())) @@ -523,15 +494,13 @@ remove_movement_ops = PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) - # if using VIZ, do a graph rewrite to vizualize the Tensor graph - if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext()) + tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+ops_folding, ctx:=ScheduleContext()) + rev_tensor_map: dict[UOp, list[UOp]] = {} + for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) # add BUFFER uops - sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={}) - # const folding and fusion - sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx) - sink = graph_rewrite(sink, merge_bufs, ctx) - # create the scheduler context - graph_rewrite(sink, create_ctx, ctx) + sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={}) + # add realizes + sink = graph_rewrite(sink, do_realize+create_ctx, ctx) # group realizes into kernels store_groups = group_realizes(ctx) graph_rewrite(sink, break_sched, ctx) @@ -539,13 +508,17 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu prescheduled: list[ScheduleItem] = [] for store_uops in store_groups: small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops]) - # TODO: this still exists because symbolic folding is happening after bufferization - if not all(x.op is Ops.STORE for x in small_sink.src): continue + if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}") prescheduled.append(schedule_uop(small_sink, ctx)) # can only schedule once for buf_uop in store_uops: for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st)) + # tensors can become an existing buffer, no ScheduleItem needed + for k,v in tensor_map.items(): + # NOTE: we only add base tensors to becomes_map + if k is not v and v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + # add kernel children schedule_targets = {out:si for si in prescheduled for out in si.outputs} graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list) From 66ac0087e839052de251bd26604cfc767dcf992c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 20 Jan 2025 18:52:58 -0500 Subject: [PATCH 09/18] more high level contiguous tests + scheduler deletions [pr] (#8695) * delete those * move the upat too * rename ops_folding to just sym * keep that --- test/test_schedule.py | 45 +++++++++++++++++-------------------- tinygrad/engine/schedule.py | 16 ++++++------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index b0426630de..9d29bebb34 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -16,7 +16,7 @@ from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same from tinygrad.codegen.kernel import verify_ast -from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding +from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis @@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs): np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2) @track_rewrites(named=True) -def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext()) +def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext()) class TestSchedule(unittest.TestCase): def test_basic_binop_fusion(self): @@ -1824,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops. # these pattern matchers should move to engine/schedule.py -sym = symbolic_simple+PatternMatcher([ +ops_folding = symbolic_simple+PatternMatcher([ (UPat(Ops.DETACH, name="x"), lambda x:x.src[0]), ]) @@ -1842,8 +1842,8 @@ def run_tensor_ast(r:Tensor): output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype) glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0) sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink() - sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output]) - sink = graph_rewrite(sink, remove_movement_ops+sym+view_right) + sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output]) + sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right) si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ()) run_schedule([si]) return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist() @@ -2336,34 +2336,29 @@ class TestBufferUOp(unittest.TestCase): class TestContiguous(unittest.TestCase): def test_contiguous_buffer(self): - a = Tensor.empty(4).lazydata - b = a.alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - self.assertIs(b, a) + a = Tensor.empty(4) + b = a.contiguous() + check_schedule(b, 0) def test_contiguous_buffer_view(self): - a = Tensor.empty(4).lazydata - b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - self.assertIs(b, a.buf_uop.view(unwrap(b.st))) + a = Tensor.empty(4) + b = a.reshape((2, 2)).contiguous() + check_schedule(b, 0) def test_non_contiguous_buffer_view(self): - a = Tensor.empty(4, 1).lazydata - b = a.expand((4, 4)).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + a = Tensor.empty(4, 1) + b = a.expand((4, 4)).contiguous() + check_schedule(b, 1) def test_size_change_buffer_view(self): - a = Tensor.empty(4).lazydata - b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + a = Tensor.empty(4) + b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous() + check_schedule(b, 1) def test_double_contiguous_realizes_once(self): - a = Tensor.empty(4, 1).lazydata - b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + a = Tensor.empty(4, 1) + b = a.expand((4, 4)).contiguous().contiguous() + check_schedule(b, 1) if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d5acb3cbe3..497cb5970a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -242,7 +242,7 @@ if CAPTURE_PROCESS_REPLAY: def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER def uval(u:UOp) -> UOp: assert is_scheduled(u), f"must be a scheduled op {u}" - return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r + return u.src[1] def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp], reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None: @@ -340,10 +340,6 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: # **** Schedule creation and BFS toposort -class UPatScheduled(UPat): - def __init__(self, *args, **kwargs): - super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs}))) - # ** this is schedule level const folding def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: @@ -366,8 +362,8 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp): if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) -ops_folding = symbolic_simple+PatternMatcher([ - # op with size 0 is zero +sym = symbolic_simple+PatternMatcher([ + # UOp with size 0 is zero (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), # DETACH is a NOOP here @@ -401,6 +397,10 @@ ops_folding = symbolic_simple+PatternMatcher([ # ** this decides which ops get realized +class UPatScheduled(UPat): + def __init__(self, *args, **kwargs): + super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs}))) + def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None: @@ -494,7 +494,7 @@ remove_movement_ops = PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) - tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+ops_folding, ctx:=ScheduleContext()) + tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext()) rev_tensor_map: dict[UOp, list[UOp]] = {} for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) # add BUFFER uops From 2b239db5d28bb1972cddb511904d475b113e1580 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 21 Jan 2025 12:26:43 +0300 Subject: [PATCH 10/18] temp() with usernames (#8697) --- tinygrad/device.py | 4 ++-- tinygrad/helpers.py | 3 ++- tinygrad/ops.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tinygrad/device.py b/tinygrad/device.py index d4782120c1..e1777aa02f 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -310,7 +310,7 @@ if PROFILE: for dev in devs: dev.synchronize() for dev in devs: dev._at_profile_finalize() - with open(temp("profile.pkl"), "wb") as f: pickle.dump(Compiled.profile_events, f) + with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f) from tinygrad.ops import launch_viz - launch_viz("PROFILE", temp("profile.pkl")) + launch_viz("PROFILE", fn) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index d47f62d7d7..39d50fed8a 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -78,7 +78,8 @@ def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)]) @functools.lru_cache(maxsize=None) def getenv(key:str, default=0): return type(default)(os.getenv(key, default)) -def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix() +def temp(x:str, append_user:bool=False) -> str: + return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{os.getenv('USERNAME', os.getlogin())}" if append_user else x)).as_posix() class Context(contextlib.ContextDecorator): def __init__(self, **kwargs): self.kwargs = kwargs diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 639d16819f..74892dbed8 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -815,10 +815,10 @@ if TRACK_MATCH_STATS: @atexit.register def print_match_stats(): if TRACK_MATCH_STATS >= 2: - with open(fn:=temp("rewrites.pkl"), "wb") as f: + with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f) - if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl")) + if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True)) if getenv("PRINT_MATCH_STATS", 1): ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]): From e2008c98c39499143e13b194b4540bcfa94e8c88 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Jan 2025 05:01:25 -0500 Subject: [PATCH 11/18] allow symbolic shape in tensor const parents [pr] (#8699) --- test/test_schedule.py | 6 ++++++ tinygrad/engine/schedule.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 9d29bebb34..0f71dae331 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2273,6 +2273,12 @@ class TestTensorUOpSpec(unittest.TestCase): t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views) create_schedule_with_vars(t) + def test_symbolic_shape_ok(self): + a = Tensor.ones(4) + vi = UOp.variable("i", 1, 10).bind(4) + t = graph_rewrite(a.reshape(vi).sum().lazydata, remove_movement_ops+merge_views) + create_schedule_with_vars(t) + class TestBufferUOp(unittest.TestCase): # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,) def test_buffer_has_buffer(self): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 497cb5970a..1c5931962e 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -31,9 +31,9 @@ tensor_uop_spec = PatternMatcher([ (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True), (UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, arg=ShapeTracker.from_shape(()))), arg=None), lambda: True), - # Tensor const has an unmasked ShapeTracker of stride 0 and a device + # Tensor const has a device and an unmasked ShapeTracker of stride 0 or a ShapeTracker with symbolic shape (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)), - lambda st: len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides) and st.st.views[0].mask is None), + lambda st: st.st.views[0].mask is None and ((len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)) or not all_int(st.shape))), # DETACH and CONTIGUOUS change how we interpret the source UOp # CONTIGUOUS ensures the source UOp realizes From f0d424ecdf764df06e79a2efe348c4beef8487dd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Jan 2025 05:33:19 -0500 Subject: [PATCH 12/18] Tensor UOps can become a buffer or const after scheduling (#8698) * spec * work * update test_viewed_consts_do_not_realize * remove --- test/test_schedule.py | 51 +++++++++++++++++++++ test/unit/test_tensor_uop_representation.py | 4 +- tinygrad/engine/schedule.py | 10 ++-- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 0f71dae331..20ca1ccd2c 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2366,5 +2366,56 @@ class TestContiguous(unittest.TestCase): b = a.expand((4, 4)).contiguous().contiguous() check_schedule(b, 1) + +class TestUOpBecome(unittest.TestCase): + # the simplest case, if we create a new BUFFER for this UOp + def test_new_buffer(self): + a = Tensor.empty(4, 4) + b = Tensor.empty(4, 4) + add = a+b + check_schedule(add, 1) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {}) + + def test_new_buffer_view(self): + a = Tensor.empty(4, 4) + b = Tensor.empty(4, 4) + add = (a+b).reshape(8, 2) + check_schedule(add, 1) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {}) + # VIEW is preserverd after the becomes rewrite. + self.assertEqual(add.lazydata.shape, (8, 2)) + assert add.lazydata is not add.lazydata.base + + def test_become_existing_buffer(self): + a = Tensor.empty(4, 4) + b = a*1 + assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + check_schedule(b, 0) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER) + self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer) + + def test_become_const_in_base(self): + a = Tensor.empty(4) + b = a*0 + assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + check_schedule(b, 0) + assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER) + + def test_become_const_in_view(self): + # if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged. + add = Tensor.empty(2, 2)+Tensor.empty(2, 2) + b = add.shrink(((0, 1), (0, 0))) + check_schedule(b, 0) + assert UPat(Ops.CONST, arg=0).match(b.lazydata, {}) + self.assertEqual(b.shape, (1, 0)) + # the base is untouched. + assert UPat(Ops.ADD).match(add.lazydata, {}) + + def test_become_const_from_const(self): + const_add = Tensor(1)+Tensor(2) + assert UPat(Ops.ADD).match(const_add.lazydata, {}) + check_schedule(const_add, 0) + assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {}) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index dc8d0b64aa..6ea1b011a4 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -71,9 +71,9 @@ class TestTensorUopRepresentation(unittest.TestCase): def test_viewed_consts_do_not_realize(self): a = Tensor.ones(10, 10) print(a.lazydata) - pre_realize = a.lazydata a.realize() - assert a.lazydata is pre_realize + is_pattern(a, const_pattern) + self.assertEqual(a.lazydata.shape, (10, 10)) # currently, CONSTs have a "fake" BUFFER. this should be fixed # current: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 1c5931962e..71cc12492c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -514,10 +514,14 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu for buf_uop in store_uops: for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st)) - # tensors can become an existing buffer, no ScheduleItem needed + # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed for k,v in tensor_map.items(): - # NOTE: we only add base tensors to becomes_map - if k is not v and v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + # NOOP + if k.base is v.base: continue + # NOTE: only the base tensors get a BUFFER UOp + if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + # otherwise if it simplified to a CONST the UOp just becomes that CONST + elif v.op is Ops.CONST: ctx.becomes_map[k] = v # add kernel children schedule_targets = {out:si for si in prescheduled for out in si.outputs} From 6733a3a96b6fa1ec869348e86ad2bd2c0c208596 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 21 Jan 2025 14:35:15 +0300 Subject: [PATCH 13/18] am: fix typo (#8700) --- docs/developer/am.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/developer/am.md b/docs/developer/am.md index 9435bc9537..67699422fb 100644 --- a/docs/developer/am.md +++ b/docs/developer/am.md @@ -27,7 +27,7 @@ AM binds compute queues directly to MEC (bypassing MES). Tinygrad uses only one The GPU being passed can be in one of several states: 1. Not initialized -2. Initialized by AMDGPU +2. Initialized by amdgpu 3. Initialized by AM The first and second states require a full GPU setup since their states are unknown. The second state also requires a mode1 reset to reinitialize all components. @@ -36,4 +36,4 @@ The third state can be set up partially to optimize boot time. In this case, onl ### VM Management -Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512TB of virtual addresses. All AM devices are located in one virtual address space. \ No newline at end of file +Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512GB of virtual addresses. All AM devices are located in one virtual address space. \ No newline at end of file From 3628f899292e311e01dead74e666929c9d0fbfd9 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 21 Jan 2025 16:34:19 +0300 Subject: [PATCH 14/18] fix deallocate for subbuffers (#8701) * fix deallocate for subbuffers * forgot this * rm name * hmm --- test/test_subbuffer.py | 18 ++++++++++++++++++ tinygrad/device.py | 6 +++--- tinygrad/helpers.py | 2 +- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/test/test_subbuffer.py b/test/test_subbuffer.py index 8b6e2043f4..40fb7ad3a3 100644 --- a/test/test_subbuffer.py +++ b/test/test_subbuffer.py @@ -2,6 +2,7 @@ import unittest from tinygrad import Device, dtypes, Tensor from tinygrad.device import Buffer from tinygrad.ops import view_supported_devices +from tinygrad.helpers import Context @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported") class TestSubBuffer(unittest.TestCase): @@ -47,5 +48,22 @@ class TestSubBuffer(unittest.TestCase): out = vt.to(f"{Device.DEFAULT}:1").realize().tolist() assert out == [2, 3, 4] + def test_subbuffer_deallocate(self): + with Context(LRU=0): + vbuf = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated() + self.buf.deallocate() + vbuf.deallocate() + + # Allocate a fake one on the same place + _ = Buffer(Device.DEFAULT, 10, dtypes.uint8).ensure_allocated() + + self.buf.ensure_allocated() + self.buf.copyin(memoryview(bytearray(range(10, 20)))) + + vbuf.ensure_allocated() + + tst = vbuf.as_buffer().tolist() + assert tst == [13, 14] + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/device.py b/tinygrad/device.py index e1777aa02f..2a20992f80 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -4,7 +4,7 @@ from collections import defaultdict from typing import Optional, Any, Iterator, Generator import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE -from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ +from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ cpu_time_execution from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes from tinygrad.renderer import Renderer @@ -129,7 +129,7 @@ class Buffer: if self._base is None and (self.options is None or self.options.external_ptr is None): if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes self.allocator.free(self._buf, self.nbytes, self.options) - del self._buf + del self._buf def __reduce__(self): buf = None if self._base is not None: @@ -202,7 +202,7 @@ class LRUAllocator(Allocator): for opaque in opaques: super().free(opaque, sz, options) opaques.clear() def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None): - if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque) + if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque) else: super().free(opaque, size, options) class _MallocAllocator(LRUAllocator): diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 39d50fed8a..090b9178cc 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -108,7 +108,7 @@ WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1) FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) -PICKLE_BUFFERS, PROFILE = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")) +PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) @dataclass(frozen=True) class Metadata: From d6bf1feaab0d62ffb33d71663faa9088f47d45fa Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Jan 2025 10:09:33 -0500 Subject: [PATCH 15/18] remove the "no copy" line from copy_to_device (#8702) * delete the no copy one * add tests --- test/test_schedule.py | 11 +++++++++++ tinygrad/ops.py | 2 -- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 20ca1ccd2c..ccf423d0a5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2258,6 +2258,17 @@ class TestCopyFolding(unittest.TestCase): add = schedule_graph_rewrite(add) assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}" + def test_copy_to_same_device(self): + a = Tensor.empty(4).lazydata + b = a.copy_to_device(a.device) + check_schedule(b, 0, filter_sink=False) + b = schedule_graph_rewrite(b) + self.assertIs(b, a) + + def test_clone(self): + a = Tensor.empty(4).lazydata + check_schedule(a.clone(), 1, filter_sink=False) + class TestTensorUOpSpec(unittest.TestCase): def test_const_must_be_unmasked(self): a = Tensor.ones((4, 4)).pad((2, 2)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 74892dbed8..cbf0038fd2 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -427,8 +427,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # otherwise it's just a VIEW(BUFFER) return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st) def copy_to_device(self, device:str, clone:bool=False) -> UOp: - # no COPY - if self.device == device and not clone: return self # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st) From 018edd934bc60085c0dab8442144de6ca1492932 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:57:47 -0800 Subject: [PATCH 16/18] don't use view in copy [pr] (#8704) * don't use view in copy [pr] * oh, remove double contig * fix reps --- test/unit/test_disk_tensor.py | 1 + test/unit/test_tensor_uop_representation.py | 6 ++++-- tinygrad/ops.py | 9 ++++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 7078a994f3..a76c194076 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -164,6 +164,7 @@ class TestSafetensors(unittest.TestCase): def test_save_all_dtypes(self): for dtype in dtypes.fields().values(): if dtype in [dtypes.bfloat16]: continue # not supported in numpy + if dtype in [dtypes.double] and Device.DEFAULT == "METAL": continue # not supported on METAL path = temp(f"ones.{dtype}.safetensors") ones = Tensor(np.random.rand(10,10), dtype=dtype) safe_save(get_state_dict(ones), path) diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index 6ea1b011a4..b4d391c2af 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -53,7 +53,8 @@ class TestTensorUopRepresentation(unittest.TestCase): b = Tensor([4.,5,6]).realize() c = a+b print(c.lazydata) - is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,))))) + is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern))) + #is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,))))) def test_const_pattern(self): a = Tensor(1) @@ -111,7 +112,8 @@ class TestTensorUopRepresentation(unittest.TestCase): c = a.to("TEST") # NOTE: this isn't checked print(c.lazydata) # TODO: COPY on a Tensor becomes a VIEW(COPY), this should be done in the scheduler not in ops - is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),))) + is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,))) + #is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),))) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cbf0038fd2..0225603cc4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -430,7 +430,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st) - return UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone).view(unwrap(self.st)) + ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone) + op_arg = [] + mop = self + while mop is not self.base: + op_arg.append((mop.op, mop.arg)) + mop = mop.src[0] + for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg) + return ret def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) @property def lbs(self): return [self] From 1e283c33d375323b4f60a8063475894d8d76b50f Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 21 Jan 2025 14:11:03 -0500 Subject: [PATCH 17/18] remove realize in bert model init [pr] (#8707) --- examples/mlperf/model_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 19fa1a8e07..ef92c98a69 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -687,9 +687,9 @@ def train_bert(): model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None) - for _, x in get_state_dict(model).items(): - x.realize().to_(GPUS) parameters = get_parameters(model) + for p in parameters: + p.to_(GPUS) # ** Log run config ** for key, value in config.items(): print(f'HParam: "{key}": {value}') From c5e46c5eee3fe84e06b26593ffee9ee699f273a6 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 21 Jan 2025 22:22:23 +0300 Subject: [PATCH 18/18] am: recover from any boot interrupt (#8703) * am: recover from any load interrupt * add fuzzer * nu --- test/external/external_fuzz_am_interrupts.py | 39 ++++++++++++++++++++ tinygrad/runtime/support/am/amdev.py | 17 ++++----- tinygrad/runtime/support/am/ip.py | 26 +++++++++---- 3 files changed, 65 insertions(+), 17 deletions(-) create mode 100644 test/external/external_fuzz_am_interrupts.py diff --git a/test/external/external_fuzz_am_interrupts.py b/test/external/external_fuzz_am_interrupts.py new file mode 100644 index 0000000000..2ed5724288 --- /dev/null +++ b/test/external/external_fuzz_am_interrupts.py @@ -0,0 +1,39 @@ +import subprocess +import random +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +def run_test(i, full_run=False): + print(f"\rRunning iteration {i}...", end=" ", flush=True) + + p = subprocess.Popen(['python3', 'test/test_tiny.py', 'TestTiny.test_plus'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + if not full_run: + time.sleep(random.uniform(0, 1200) / 1000) + p.kill() + _, stderr = p.communicate() + else: + _, stderr = p.communicate() + + if full_run: + stderr_text = stderr.decode() + print(stderr_text) + assert "Ran 1 test in" in stderr_text and "OK" in stderr_text + +max_workers = 4 +with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for i in range(1000000): + if i % 100 == 0: + for future in as_completed(futures): + try: future.result() + except Exception as e: + print(f"\nError in iteration: {e}") + futures = [] + + run_test(i, True) + else: + future = executor.submit(run_test, i, False) + futures.append(future) + + if len(futures) > max_workers * 2: futures = [f for f in futures if not f.done()] \ No newline at end of file diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 2397ae200f..37b43d9c7b 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -1,5 +1,5 @@ from __future__ import annotations -import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal +import ctypes, collections, time, dataclasses, pathlib, fcntl, os from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0 from tinygrad.runtime.support.allocator import TLSFAllocator @@ -279,13 +279,10 @@ class AMDev: self.partial_boot = False if not self.partial_boot: - try: # do not interrupt the boot process - signal.signal(signal.SIGINT, signal.SIG_IGN) - if self.psp.is_sos_alive(): self.smu.mode1_reset() - for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]: - ip.init() - if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized") - finally: signal.signal(signal.SIGINT, signal.default_int_handler) + if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset() + for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]: + ip.init() + if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized") # Booting done self.is_booting = False @@ -332,8 +329,8 @@ class AMDev: self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg) self.reg("regBIF_BX_PF0_RSMU_DATA").write(val) - def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff) -> int: - for _ in range(10000): + def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int: + for _ in range(timeout): if ((rval:=reg.read()) & mask) == value: return rval time.sleep(0.001) raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}') diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 79c4de0a24..2b08798825 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -1,4 +1,4 @@ -import ctypes, time +import ctypes, time, contextlib from typing import Literal from tinygrad.runtime.autogen.am import am, smu_v13_0_0 from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG @@ -106,22 +106,26 @@ class AM_SMU(AM_IP): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck, poll=True) self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck, poll=True) + def is_smu_alive(self): + with contextlib.suppress(RuntimeError): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100) + return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0 + def mode1_reset(self): if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset") self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True) time.sleep(0.5) # 500ms - def _smu_cmn_poll_stat(self): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1) + def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout) def _smu_cmn_send_msg(self, msg, param=0): self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg self.adev.mmMP1_SMN_C2PMSG_82.write(param) self.adev.mmMP1_SMN_C2PMSG_66.write(msg) - def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False): - if poll: self._smu_cmn_poll_stat() + def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s + if poll: self._smu_cmn_poll_stat(timeout=timeout) self._smu_cmn_send_msg(msg, param) - self._smu_cmn_poll_stat() + self._smu_cmn_poll_stat(timeout=timeout) return self.adev.rreg(self.adev.mmMP1_SMN_C2PMSG_82) if read_back_arg else None class AM_GFX(AM_IP): @@ -319,8 +323,9 @@ class AM_PSP(AM_IP): (am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV), (am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)] - for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid) - while not self.is_sos_alive(): time.sleep(0.01) + if not self.is_sos_alive(): + for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid) + while not self.is_sos_alive(): time.sleep(0.01) self._ring_create() self._tmr_init() @@ -357,6 +362,13 @@ class AM_PSP(AM_IP): self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True) def _ring_create(self): + # If the ring is already created, destroy it + if self.adev.regMP0_SMN_C2PMSG_71.read() != 0: + self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS) + + # There might be handshake issue with hardware which needs delay + time.sleep(0.02) + # Wait until the sOS is ready self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000)