mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 23:54:58 -05:00
replace raise Exception with specific errors (#3874)
This commit is contained in:
@@ -384,5 +384,5 @@ class UOpGraph:
|
||||
if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
|
||||
elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
|
||||
elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
|
||||
else: raise Exception("not implemented")
|
||||
else: raise NotImplementedError(f"not implemented wmma {u.arg=}")
|
||||
return flops, mem
|
||||
|
||||
@@ -126,7 +126,7 @@ class CUDAAllocator(LRUAllocator):
|
||||
if options.host:
|
||||
return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0)))
|
||||
else:
|
||||
raise Exception("no options")
|
||||
raise ValueError("no options")
|
||||
def _free(self, opaque): check(cuda.cuMemFree_v2(opaque))
|
||||
def copyin(self, dest, src:memoryview):
|
||||
host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
|
||||
|
||||
@@ -115,7 +115,7 @@ class HSAAllocator(LRUAllocator):
|
||||
check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
|
||||
check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
|
||||
return mem.value
|
||||
else: raise Exception("no options")
|
||||
else: raise ValueError("no options")
|
||||
|
||||
def _free(self, opaque:T):
|
||||
HSADevice.synchronize_system()
|
||||
|
||||
@@ -177,7 +177,7 @@ class PythonProgram:
|
||||
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
|
||||
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
|
||||
else:
|
||||
raise Exception(f"unimplemented tensor core {arg}")
|
||||
raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif uop is UOps.ALU:
|
||||
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
|
||||
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
|
||||
|
||||
Reference in New Issue
Block a user