mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
fix mem_estimate for dtype.itemsize
This commit is contained in:
@@ -341,4 +341,4 @@ class GPUCodegen(ASTKernel):
|
||||
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
|
||||
list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
|
||||
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
|
||||
op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs if x is not None))
|
||||
op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs if x is not None))
|
||||
|
||||
@@ -212,4 +212,4 @@ class LLVMCodegen(ASTKernel):
|
||||
loop_entry[-1].branch(loop_exit[-1]._block)
|
||||
loop_exit[0].ret_void()
|
||||
|
||||
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs))
|
||||
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs))
|
||||
|
||||
Reference in New Issue
Block a user