mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
touchups, print GB/s
This commit is contained in:
@@ -94,7 +94,7 @@ class InterpretedBuffer(DeviceBuffer): # pylint: disable=abstract-method
|
||||
if context is None: context = dict()
|
||||
if ast in context: return context[ast]
|
||||
srcs = [cls.exec_ast(x, context=context) if isinstance(x, LazyOp) else x for x in ast.src]
|
||||
if DEBUG >= 4: print("exec_ast", ast.op, [x.shape for x in srcs], ast.arg)
|
||||
if DEBUG >= 4 or (not isinstance(srcs[0]._buf, GenericShape) and DEBUG >= 3): print("exec_ast", ast.op, [x.shape for x in srcs], ast.arg)
|
||||
if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
|
||||
if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
|
||||
if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
|
||||
@@ -126,8 +126,8 @@ class ASTRunner:
|
||||
def __call__(self, rawbufs:List[RawBuffer]) -> Optional[float]:
|
||||
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 1:
|
||||
print(f"**** {GlobalCounters.kernel_count:4d} {self.name:20s} args {len(rawbufs):5d} kernels {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
|
||||
print(f"*** {GlobalCounters.kernel_count:4d} {self.name:20s} ars {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):6.2f} GB/s)"))
|
||||
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
|
||||
return et
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ class RawMallocBuffer(RawBufferCopyIn):
|
||||
def __init__(self, size):
|
||||
super().__init__(size)
|
||||
self._buf = (ctypes.c_float * (size//4))()
|
||||
def _buffer(self): return self._buf
|
||||
def copyin(self, x:np.ndarray): ctypes.memmove(self._buf, x.ctypes.data, x.size*4)
|
||||
def toCPU(self): return np.ctypeslib.as_array(self._buf)
|
||||
|
||||
@@ -29,7 +30,7 @@ class ClangProgram:
|
||||
if wait: return time.monotonic()-st
|
||||
|
||||
class ClangCodegen(GPUCodegen):
|
||||
lang = GPULanguage(buffer_suffix="restrict")
|
||||
lang = GPULanguage(buffer_suffix=" restrict")
|
||||
|
||||
class ClangBuffer(CompiledBuffer):
|
||||
raw_buffer_type, codegen_type, runtime_type = RawMallocBuffer, ClangCodegen, ClangProgram
|
||||
|
||||
@@ -8,7 +8,7 @@ torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
|
||||
BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
|
||||
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
|
||||
FusedOps.MULACC: einsum_mulacc(torch.einsum, lambda x: x.stride(), lambda x,s: x.expand(s))
|
||||
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s))
|
||||
}}
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
|
||||
@@ -70,6 +70,7 @@ def view_from_shape(shape:Tuple[int, ...]) -> View:
|
||||
assert all(isinstance(x, int) for x in shape) and len(shape) != 0
|
||||
return View(tuple(shape), strides_for_shape(shape))
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
new_strides, new_offset = [], vm2.expr_node(Variable.num(vm1.offset))
|
||||
assert isinstance(new_offset, NumNode), "new_offset wasn't a number?!?"
|
||||
|
||||
Reference in New Issue
Block a user