fix mem_estimate for dtype.itemsize

This commit is contained in:
George Hotz
2023-03-11 09:20:05 -08:00
parent fe8c05b96f
commit fd65edf595
2 changed files with 2 additions and 2 deletions

View File

@@ -341,4 +341,4 @@ class GPUCodegen(ASTKernel):
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs if x is not None))
op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs if x is not None))

View File

@@ -212,4 +212,4 @@ class LLVMCodegen(ASTKernel):
loop_entry[-1].branch(loop_exit[-1]._block)
loop_exit[0].ret_void()
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs))
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs))