diff --git a/README.md b/README.md
index db6cf38bd1..fecabced02 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ Try a matmul. See how, despite the style, it is fused into one kernel with the p
 ```sh
 DEBUG=3 python3 -c "from tinygrad import Tensor;
 N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N);
-c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).sum(axis=2);
+c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2);
 print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
 ```
diff --git a/extra/multitensor.py b/extra/multitensor.py
new file mode 100644
index 0000000000..dadcf1ef18
--- /dev/null
+++ b/extra/multitensor.py
@@ -0,0 +1,59 @@
+import numpy as np
+from tinygrad import Tensor, Device, GlobalCounters
+from tinygrad.helpers import Timing
+
+d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
+N = 256
+FLOPS = N*N*N*2
+
+# LazyBuffer should make three fields lists: self.st (all must have the same shape), self.realized, and self.device
+
+def explicit_shard_W_axis_1(X, W):
+  Xs = [X.to(d0), X.to(d1)]
+  Ws = [W[:, :N//2].to(d0), W[:, N//2:].to(d1)]  # TODO: these shouldn't make copies on the original device
+  # pad them to form the correct size
+  Ws = [Ws[0].pad((None, (0,N//2))), Ws[1].pad((None, (N//2,0)))]
+  for x in Xs: assert x.shape == X.shape
+  for w in Ws: assert w.shape == W.shape
+
+  # TODO: it shouldn't be faster with these realize
+  for x in Xs+Ws: x.realize()
+  def lm(x:Tensor, w:Tensor):
+    # these are movement ops on the local device
+    x = x.reshape(N, 1, N).expand(N, N, N)
+    w = w.T.reshape(1, N, N).expand(N, N, N)
+    m = x*w
+    assert m.lazydata.st.views[0].mask is not None
+    ret = m.sum(2)
+    return ret
+  #Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])]
+  Os = [Xs[0] @ Ws[0], Xs[1] @ Ws[1]]
+  for x in Os: x.realize()
+  return Os[0].to(Device.DEFAULT) + Os[1].to(Device.DEFAULT)
+
+  #return Tensor.cat(*[x.to(Device.DEFAULT) for x in Os], dim=1)  # TODO: someday we can remove this copy too
+
+def matmul(X, W):
+  return explicit_shard_W_axis_1(X, W)
+  #return X@W
+
+if __name__ == "__main__":
+  with Timing("init devices: "):
+    Device[d0], Device[d1]
+
+  with Timing("create tensors: "):
+    X = Tensor.kaiming_uniform(N, N).realize()
+    W = Tensor.kaiming_uniform(N, N).realize()
+
+  #with Timing("warmup: "):
+  #  O = matmul(X, W).numpy()
+
+  GlobalCounters.reset()
+  print("******** multiply start")
+  with Timing("******** multiply done: ", lambda x: f" {FLOPS/x:.2f} GFLOPS"):
+    O = matmul(X, W).realize()
+    Device[Device.DEFAULT].synchronize()
+
+  with Timing("testing: "):
+    val = X.numpy() @ W.numpy()
+    np.testing.assert_allclose(val, O.numpy(), atol=1e-5)
diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index a24ea38a6a..a80a9e247f 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -166,7 +166,7 @@ class TestOpt(unittest.TestCase):
     with CLCache(allowed=1):
       d = a * b + c
       d.realize()
-      np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5)
+      np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5, atol=1e-7)
 
   def test_fold_reduce_elementwise(self):
     img = Tensor.ones(32)
diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index ac5d006beb..003654cc37 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -17,11 +17,12 @@ from examples.stable_diffusion import UNetModel
 def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
   tms = []
   for _ in range(4):
+    early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
     GlobalCounters.reset()
     GlobalCounters.mem_used = 0
     Device[Device.DEFAULT].synchronize()
     st = time.perf_counter_ns()
-    train(*gen())
+    train(*early_gen)
     Device[Device.DEFAULT].synchronize()
     tms.append(time.perf_counter_ns() - st)
diff --git a/test/test_masked_st.py b/test/test_masked_st.py
index 4630f13e99..c518d5b20e 100644
--- a/test/test_masked_st.py
+++ b/test/test_masked_st.py
@@ -3,19 +3,27 @@ from tinygrad.tensor import Tensor
 class TestMaskedShapeTracker(unittest.TestCase):
   def test_mul_masked(self):
-    a = Tensor([1,1,1,1])
-    b = Tensor([1,1]).pad(((0,2),))
+    a = Tensor([1,1,1,1,1])
+    b = Tensor([1,1]).pad(((0,3),))
     c = a*b
-    # TODO: make this true
+    assert c.shape == a.shape
     #assert c.lazydata.st.views[0].mask is not None
     ret = c.data()
-    assert ret.tolist() == [1.0, 1.0, 0.0, 0.0]
+    assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
+
+  def test_mul_both_masked(self):
+    a = Tensor([1,1]).pad(((0,3),))
+    b = Tensor([1,1]).pad(((0,3),))
+    c = a*b
+    assert c.shape == a.shape
+    #assert c.lazydata.st.views[0].mask is not None
+    ret = c.data()
+    assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
 
   def test_add_masked(self):
     a = Tensor([1,1]).pad(((0,2),))
     b = Tensor([1,1]).pad(((0,2),))
     c = a+b
-    # TODO: make this true
     #assert c.lazydata.st.views[0].mask is not None
     ret = c.data()
     assert ret.tolist() == [2.0, 2.0, 0.0, 0.0]
diff --git a/tinygrad/device.py b/tinygrad/device.py
index b6b40d78c9..1405af0975 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -48,7 +48,7 @@ class JITRunner:
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
+def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str=""):
   if var_vals is None: var_vals = {}
   op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
   GlobalCounters.kernel_count += num_kernels
@@ -56,7 +56,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
   GlobalCounters.global_mem += mem_estimate
   if et is not None: GlobalCounters.time_sum_s += et
   if DEBUG >= 2:
-    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G  mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device:7s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G  mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
           (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
 
 # **************** Buffer / Allocator ****************
 
@@ -89,32 +89,41 @@ class Buffer:
     if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
     return ret
 
+def _internal_buffer_copy(dest, src):
+  if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
+    # fast path, used on HIP between GPUs
+    # NOTE: it's important we use the dest device here to ensure the transfer is ready
+    Device[src.device].synchronize()  # TODO: async this
+    dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
+    return
+  if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
+    # fast path, used on Metal in OS X Sonoma
+    # NOTE: this is *only* faster if the pages from disk are already loaded into memory
+    fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
+    if fb:
+      dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
+      return
+  if hasattr(dest.allocator, 'as_buffer'):
+    # fast(ish) path, uses readinto in diskbuffers
+    src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+  elif hasattr(src.allocator, 'as_buffer'):
+    dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
+  else:
+    # slow path, allocates a CPU buffer
+    dest.copyin(src.toCPU().data)
+
 class _BufferCopy(JITRunner):
   # TODO: make wait work
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
     dest, src = rawbufs
     assert dest.size == src.size and dest.dtype == src.dtype, "buffer copy size/dtype mismatch"
-    if DEBUG >= 2: print(f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}")
-    if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
-      # fast path, used on HIP between GPUs
-      # NOTE: it's important we use the dest device here to ensure the transfer is ready
-      dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
-      return
-    if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
-      # fast path, used on Metal in OS X Sonoma
-      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
-      fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
-      if fb:
-        dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
-        return
-    if hasattr(dest.allocator, 'as_buffer'):
-      # fast(ish) path, uses readinto in diskbuffers
-      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
-    elif hasattr(src.allocator, 'as_buffer'):
-      dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
-    else:
-      # slow path, allocates a CPU buffer
-      dest.copyin(src.toCPU().data)
+    st = time.perf_counter()
+    _internal_buffer_copy(dest, src)
+    et = None
+    if wait or DEBUG >= 2:
+      Device[dest.device].synchronize()
+      et = time.perf_counter() - st
+    update_stats(colored(f"copy {dest.device:7s} <- {src.device:7s}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, lra={"global_size": dest.size}, device=dest.device)
 BufferCopy = _BufferCopy()
 
 # TODO: size, dest, src are the same type. can we enforce this?
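With the change above, buffer copies are timed and routed through `update_stats`, so a cross-device copy shows up as a yellow `copy` line in the DEBUG=2 kernel table, with the new `dev` column naming the destination device. One quick way to observe it (a sketch, reusing the `{Device.DEFAULT}:1` device-naming convention from extra/multitensor.py above):

```sh
DEBUG=2 python3 -c "from tinygrad import Tensor, Device;
a = Tensor.rand(256, 256).realize();
b = a.to(f'{Device.DEFAULT}:1').realize();
print(b.numpy().sum())"
```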
@@ -167,7 +176,7 @@ class InterpretedASTRunner(JITRunner):
     st = time.perf_counter()
     rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals)
     et = time.perf_counter() - st
-    update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
+    update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
     return et
 
 class Interpreted:
@@ -257,7 +266,7 @@ class CompiledASTRunner(JITRunner):
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
     et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
-    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
+    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device)
     return et
 
 class Compiled:
diff --git a/tinygrad/graph.py b/tinygrad/graph.py
index 3963c72402..805c9ec8fe 100644
--- a/tinygrad/graph.py
+++ b/tinygrad/graph.py
@@ -2,6 +2,7 @@ import os, atexit, functools
 from collections import defaultdict
 from typing import Dict, List
 from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
+from tinygrad.device import Device
 from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -21,7 +22,7 @@ if GRAPH:
   G = nx.DiGraph()
   def save_graph_exit():
     for k,v in cnts.items(): print(k, v)
-    print("saving", G)
+    print("saving", G, f"to {GRAPHPATH}.svg")
     nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
     # -Gnslimit=100 can make it finish, but you won't like results
     os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
@@ -68,7 +69,7 @@ def log_schedule_item(si: ScheduleItem):
   cnts[optype] += 1
   if GRAPH:
     assert si.out.base == si.out, "all outputs based"
-    top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
+    top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
 
     # get inputs for shapetrackers
     input_to_st = defaultdict(list)
@@ -88,7 +89,7 @@ def log_schedule_item(si: ScheduleItem):
 
     if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
 
-    G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
+    G.nodes[nm(si.out)]['label'] = '"' + (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps or optype is BufferOps else "")+(f"\n{si.out.device}" if si.out.device != Device.DEFAULT else "") + '"'
     G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
     G.nodes[nm(si.out)]['color'] = 'black'
     G.nodes[nm(si.out)]['style'] = 'filled'
diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py
index 87272dbb1c..685ded4fe0 100644
--- a/tinygrad/runtime/ops_hip.py
+++ b/tinygrad/runtime/ops_hip.py
@@ -55,6 +55,7 @@ class HIPAllocator(LRUAllocator):
     check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
   def transfer(self, dest:T, src:T, sz:int):
     check(hip.hipSetDevice(self.device))
+    # TODO: hipMemcpyAsync, but you have to track the "src" buffer to not free it
     check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice))
 
 class HIPDevice(Compiled):
@@ -65,4 +66,6 @@ class HIPDevice(Compiled):
     from tinygrad.features.graph.hip import HIPGraph
     super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
-  def synchronize(self): hip.hipDeviceSynchronize()
\ No newline at end of file
+  def synchronize(self):
+    check(hip.hipSetDevice(self.device))
+    check(hip.hipDeviceSynchronize())
\ No newline at end of file
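For reference, `explicit_shard_W_axis_1` in extra/multitensor.py leans on a simple identity: split W column-wise, zero-pad each shard back to full width (the two `Ws[i].pad` calls), and the two partial matmuls sum to the full product. A standalone numpy check of that identity (illustrative only, independent of tinygrad):

```python
import numpy as np

N = 8
X, W = np.random.rand(N, N), np.random.rand(N, N)
# column shards of W, zero-padded back to (N, N) like Ws[0].pad/Ws[1].pad above
W0 = np.pad(W[:, :N//2], ((0, 0), (0, N//2)))  # left half, zeros on the right
W1 = np.pad(W[:, N//2:], ((0, 0), (N//2, 0)))  # right half, zeros on the left
# each partial matmul fills a disjoint set of output columns, so the sum is X @ W
np.testing.assert_allclose(X @ W0 + X @ W1, X @ W)
print("sharded matmul matches full matmul")
```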