replace raise Exception with specific errors (#3874)

This commit is contained in:
chenyu
2024-03-22 12:32:21 -04:00
committed by GitHub
parent 8ef5490ec8
commit 1c51d586ea
4 changed files with 4 additions and 4 deletions

View File

@@ -384,5 +384,5 @@ class UOpGraph:
if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
else: raise Exception("not implemented")
else: raise NotImplementedError(f"not implemented wmma {u.arg=}")
return flops, mem

View File

@@ -126,7 +126,7 @@ class CUDAAllocator(LRUAllocator):
if options.host:
return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0)))
else:
raise Exception("no options")
raise ValueError("no options")
def _free(self, opaque): check(cuda.cuMemFree_v2(opaque))
def copyin(self, dest, src:memoryview):
host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))

View File

@@ -115,7 +115,7 @@ class HSAAllocator(LRUAllocator):
check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
return mem.value
else: raise Exception("no options")
else: raise ValueError("no options")
def _free(self, opaque:T):
HSADevice.synchronize_system()

View File

@@ -177,7 +177,7 @@ class PythonProgram:
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
else:
raise Exception(f"unimplemented tensor core {arg}")
raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop is UOps.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"