diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py index 7de2823a04..d1d355d8d8 100644 --- a/tinygrad/codegen/gpu.py +++ b/tinygrad/codegen/gpu.py @@ -341,4 +341,4 @@ class GPUCodegen(ASTKernel): return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete, list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None, - op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs if x is not None)) + op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs if x is not None)) diff --git a/tinygrad/codegen/llvm.py b/tinygrad/codegen/llvm.py index 420c473d19..db79b008af 100644 --- a/tinygrad/codegen/llvm.py +++ b/tinygrad/codegen/llvm.py @@ -212,4 +212,4 @@ class LLVMCodegen(ASTKernel): loop_entry[-1].branch(loop_exit[-1]._block) loop_exit[0].ret_void() - return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs)) + return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs))