From 32e9949052a51c904dad5bbbd966e0fd65835718 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 8 Jun 2025 08:42:22 -0700 Subject: [PATCH] rename lazydata to uop (#10698) --- docs/abstractions3.py | 2 +- docs/developer/kernelize.md | 4 +- docs/ramp.py | 30 +-- examples/llama.py | 3 +- examples/llama3.py | 2 +- examples/llm.c/export.py | 6 +- examples/mlperf/dataloader.py | 16 +- examples/openpilot/compile4.py | 6 +- examples/qwq.py | 2 +- examples/tinychat/tinychat-browser/compile.py | 6 +- examples/webgpu/stable_diffusion/compile.py | 2 +- extra/export_model.py | 6 +- extra/hip_gpu_driver/test_pm4.py | 6 +- extra/multitensor.py | 2 +- extra/remu/test/hwtest.py | 2 +- extra/torch_backend/backend.py | 12 +- extra/torch_backend/wrapped_tensor.cpp | 2 +- extra/torch_hook/hook_torch.py | 2 +- .../external_benchmark_bert_matmuls.py | 2 +- test/external/external_multi_gpu.py | 4 +- test/external/external_test_amd.py | 10 +- test/external/external_test_hcq.py | 52 ++--- test/external/external_test_hip_compile.py | 2 +- test/external/external_test_nv.py | 10 +- test/external/fuzz_graph.py | 2 +- test/external/fuzz_linearizer.py | 2 +- test/imported/test_indexing.py | 16 +- test/test_arange.py | 2 +- test/test_assign.py | 22 +- test/test_const_folding.py | 2 +- test/test_dtype.py | 2 +- test/test_gc.py | 10 +- test/test_graph.py | 2 +- test/test_hcq.py | 26 +-- test/test_image_dtype.py | 32 +-- test/test_linearizer.py | 96 ++++---- test/test_masked_st.py | 6 +- test/test_multitensor.py | 46 ++-- test/test_nn.py | 12 +- test/test_pickle.py | 10 +- test/test_profiler.py | 6 +- test/test_renderer_failures.py | 2 +- test/test_rewrite_tracked_childen.py | 8 +- test/test_schedule.py | 212 +++++++++--------- test/test_search.py | 2 +- test/test_setitem.py | 2 +- test/test_symbolic_shapetracker.py | 20 +- test/test_tensor.py | 36 +-- test/test_tensor_uop.py | 6 +- test/test_uops.py | 26 +-- test/test_zero_copy.py | 2 +- test/unit/test_gradient.py | 10 +- test/unit/test_shapetracker.py | 28 +-- test/unit/test_tensor_uop_representation.py | 54 ++--- tinygrad/engine/jit.py | 6 +- tinygrad/nn/state.py | 2 +- tinygrad/tensor.py | 70 +++--- 57 files changed, 485 insertions(+), 486 deletions(-) diff --git a/docs/abstractions3.py b/docs/abstractions3.py index b69905f490..c34a399bba 100644 --- a/docs/abstractions3.py +++ b/docs/abstractions3.py @@ -36,7 +36,7 @@ optim.schedule_step() # this will step the optimizer without running realize # 3. Create a schedule. # The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point -# l1.lazydata and l2.lazydata define a computation graph +# l1.uop and l2.uop define a computation graph from tinygrad.engine.schedule import ScheduleItem schedule: List[ScheduleItem] = Tensor.schedule(l1, l2) diff --git a/docs/developer/kernelize.md b/docs/developer/kernelize.md index 9731464504..b38db222b4 100644 --- a/docs/developer/kernelize.md +++ b/docs/developer/kernelize.md @@ -34,7 +34,7 @@ print(out) # , None)> on METAL with The multiply Tensor stays the same because it is fused. 
The output Tensor's UOp becomes a new ASSIGN UOp: ```py -print(out.lazydata) +print(out.uop) ``` The first source is the output BUFFER: @@ -72,7 +72,7 @@ Once a Tensor is kernelized, all children will LOAD its BUFFER, instead of fusin ```py child = out+2 child.kernelize() -print(child.lazydata.src[1].arg.ast) +print(child.uop.src[1].arg.ast) ``` ``` diff --git a/docs/ramp.py b/docs/ramp.py index db0afdf20e..ab4323cb1d 100644 --- a/docs/ramp.py +++ b/docs/ramp.py @@ -39,8 +39,8 @@ assert t.shape == (4,) print(t) # , None)> on METAL with grad None> -# the ".lazydata" property on Tensor contains the specification of how to compute it -print(t.lazydata) +# the ".uop" property on Tensor contains the specification of how to compute it +print(t.uop) """ UOp(Ops.COPY, dtypes.int, arg=None, src=( UOp(Ops.BUFFER, dtypes.int, arg=4, src=( @@ -57,21 +57,21 @@ UOp(Ops.COPY, dtypes.int, arg=None, src=( t.realize() # if we want to "realize" a tensor, we can with the "realize" method -# now when we look at the lazydata, it's changed -print(t.lazydata) +# now when we look at the uop, it's changed +print(t.uop) """ UOp(Ops.BUFFER, dtypes.int, arg=4, src=( UOp(Ops.UNIQUE, dtypes.void, arg=1, src=()), UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),)) """ -# the copy was actually run, and now the "lazydata" of the Tensor is just a BUFFER +# the copy was actually run, and now the "uop" of the Tensor is just a BUFFER # if you run this script with DEBUG=2 in the environment, you can see the copy happen # *** METAL 1 copy 16, METAL <- PYTHON ... # now let's do some compute -# we look at the lazydata to see the specification of the compute +# we look at the uop to see the specification of the compute t_times_2 = t * 2 -print(t_times_2.lazydata) +print(t_times_2.uop) """ UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.BUFFER, dtypes.int, arg=4, src=( @@ -90,24 +90,24 @@ UOp(Ops.MUL, dtypes.int, arg=None, src=( assert t_times_2.tolist() == [2, 4, 6, 8] # UOps are both immutable and globally unique -# if i multiply the Tensor by 4 twice, these result Tensors will have the same lazydata specification +# if i multiply the Tensor by 4 twice, these result Tensors will have the same uop specification t_times_4_try_1 = t * 4 t_times_4_try_2 = t * 4 -assert t_times_4_try_1.lazydata is t_times_4_try_2.lazydata +assert t_times_4_try_1.uop is t_times_4_try_2.uop # the specification isn't just the same, it's the exact same Python object assert t_times_4_try_1 is not t_times_4_try_2 # the Tensor is a different Python object # if we realize `t_times_4_try_1` ... t_times_4_try_1.realize() -print(t_times_4_try_2.lazydata) +print(t_times_4_try_2.uop) """ UOp(Ops.BUFFER, dtypes.int, arg=4, src=( UOp(Ops.UNIQUE, dtypes.void, arg=4, src=()), UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),)) """ # ... `t_times_4_try_2` also becomes the same BUFFER -assert t_times_4_try_1.lazydata is t_times_4_try_2.lazydata +assert t_times_4_try_1.uop is t_times_4_try_2.uop # so this print doesn't require any computation, just a copy back to the CPU so we can print it print("** only the copy start") print(t_times_4_try_2.tolist()) # [4, 8, 12, 16] @@ -120,7 +120,7 @@ t_float = Tensor([3.0]) t_log = t_float.log() t_log_grad, = t_log.sum().gradient(t_float) # due to how log is implemented, this gradient contains a lot of UOps -print(t_log_grad.lazydata) +print(t_log_grad.uop) # ...not shown here... 
# but if you run with DEBUG=4 (CPU=1 used here for simpler code), you can see the generated code """ @@ -144,7 +144,7 @@ t = Tensor([1,2,3,4]) # NOTE: the APIs here are subject to change t_plus_3_plus_4 = t + 3 + 4 -print(t_plus_3_plus_4.lazydata) +print(t_plus_3_plus_4.uop) """ UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.ADD, dtypes.int, arg=None, src=( @@ -166,7 +166,7 @@ UOp(Ops.ADD, dtypes.int, arg=None, src=( # but by the time we are actually running the code, it's adding 7 # `kernelize` will simplify and group the operations in the graph into kernels t_plus_3_plus_4.kernelize() -print(t_plus_3_plus_4.lazydata) +print(t_plus_3_plus_4.uop) """ UOp(Ops.ASSIGN, dtypes.int, arg=None, src=( x0:=UOp(Ops.BUFFER, dtypes.int, arg=4, src=( @@ -181,7 +181,7 @@ UOp(Ops.ASSIGN, dtypes.int, arg=None, src=( # ASSIGN has two srcs, src[0] is the BUFFER that's assigned to, and src[1] is the thing to assign # src[1] is the GPU Kernel that's going to be run # we can get the ast of the Kernel as follows -kernel_ast = t_plus_3_plus_4.lazydata.src[1].arg.ast +kernel_ast = t_plus_3_plus_4.uop.src[1].arg.ast # almost everything in tinygrad functions as a rewrite of the UOps # the codegen rewrites the ast to a simplified form ready for "rendering" diff --git a/examples/llama.py b/examples/llama.py index c79e0e3060..42f9b6e57b 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -240,7 +240,6 @@ class LLaMa: #elif k.endswith('.weight'): v.shard_(device, axis=-1) #elif 'norm.' in k: v.shard_(device, axis=-1) else: v.shard_(device, axis=None) - #print(k, v.shape, v.lazydata.axis) # replace weights in model load_state_dict(model, weights, strict=False, consume=True) @@ -446,7 +445,7 @@ After you are done speaking, output [EOS]. You are not Chad. print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model") device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize, device=device) - param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(llama.model)) + param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(llama.model)) outputted = pre_prompt if chatbot else args.prompt start_pos, toks = 0, [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted) diff --git a/examples/llama3.py b/examples/llama3.py index 9240077dcf..b02fb9d997 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -284,7 +284,7 @@ if __name__ == "__main__": device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device) - param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model)) + param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(model)) if not args.no_api and not args.benchmark: from bottle import Bottle, request, response, HTTPResponse, abort, static_file diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py index bc13a09fbc..9612f7e96f 100755 --- a/examples/llm.c/export.py +++ b/examples/llm.c/export.py @@ -16,7 +16,7 @@ if __name__ == "__main__": #model.load_pretrained() for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained - #early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)]) + #early_sched = create_schedule([x.uop for x in 
nn.state.get_parameters(model)]) #print(f"built model {len(early_sched)}") #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64) @@ -56,7 +56,7 @@ if __name__ == "__main__": state_dict.update({'X': X, 'Y': Y, 'loss': loss}) grad_state_dict = {} for k,v in state_dict.items(): - if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}") + if v.uop.base.buffer not in used_buffers: print(f"UNUSED: {k}") if v.grad is not None: grad_state_dict['grad_'+k] = v.grad state_dict.update(grad_state_dict) state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr}) @@ -65,7 +65,7 @@ if __name__ == "__main__": nm = inverse_state_dict[p] state_dict["adam_m_"+nm] = m state_dict["adam_v_"+nm] = v - named_buffers = {v.lazydata.base.buffer:k.replace(".", "_") for k,v in state_dict.items()} + named_buffers = {v.uop.base.buffer:k.replace(".", "_") for k,v in state_dict.items()} c_code = ["#include ", "#include ", "#include "] if TIMING: c_code += ["#include ", "#include "] diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index 0942e83615..c01ab48a56 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -71,7 +71,7 @@ def loader_process(q_in, q_out, X:Tensor, seed): #storage_tensor._copyin(img_tensor.numpy()) # faster - X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() + X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() # ideal #X[idx].assign(img.tobytes()) # NOTE: this is slow! @@ -262,8 +262,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens x = random_brightness_augmentation(x) x = gaussian_noise(x) - X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes() - Y[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes() + X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes() + Y[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes() queue_out.put(idx) queue_out.put(None) @@ -377,12 +377,12 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue clipped_match_idxs = np.clip(match_idxs, 0, None) clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs] - boxes[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes() - labels[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes() - matches[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes() - anchors[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes() + boxes[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes() + labels[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes() + matches[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes() + anchors[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes() - imgs[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() + 
imgs[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() queue_out.put(idx) queue_out.put(None) diff --git a/examples/openpilot/compile4.py b/examples/openpilot/compile4.py index a2bac699f7..4d7e1eab18 100644 --- a/examples/openpilot/compile4.py +++ b/examples/openpilot/compile4.py @@ -19,8 +19,8 @@ if __name__ == "__main__": inputs = run_onnx.get_empty_input_data("npy") out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu') - root = out.lazydata - targets = [x.lazydata for x in inputs.values()] + root = out.uop + targets = [x.uop for x in inputs.values()] print(targets) # TODO: abstract this from gradient? @@ -42,7 +42,7 @@ if __name__ == "__main__": print("**** real ****") GlobalCounters.reset() - out.lazydata = root.substitute(kernelized).substitute(becomes_map) + out.uop = root.substitute(kernelized).substitute(becomes_map) out.kernelize() # realize diff --git a/examples/qwq.py b/examples/qwq.py index f668e81ed0..fad87695bd 100644 --- a/examples/qwq.py +++ b/examples/qwq.py @@ -66,7 +66,7 @@ if __name__ == "__main__": model_path = Path(args.weights) if args.weights else download_weights(model_info["total_num_weights"]) transformer = load_model(model_path, model_info["model_params"]) tokenizer = AutoTokenizer.from_pretrained(model_info["tokenizer"]) - param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(transformer)) + param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(transformer)) outputted = args.prompt start_pos, toks = 0, tokenizer(outputted)["input_ids"] diff --git a/examples/tinychat/tinychat-browser/compile.py b/examples/tinychat/tinychat-browser/compile.py index 90c215dbc6..d1a1e64c35 100644 --- a/examples/tinychat/tinychat-browser/compile.py +++ b/examples/tinychat/tinychat-browser/compile.py @@ -13,8 +13,8 @@ def prepare_browser_chunks(model): chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints metadata = {} # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be - t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k] - empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k] + t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k] + empty_t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k] split_t_infos = [] for size, name, dtype in t_infos: @@ -48,7 +48,7 @@ def prepare_browser_chunks(model): weight_metadata = metadata.get(name, default) weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size} metadata[name] = weight_metadata - data = bytes(state_dict[name].lazydata.base.realized.as_buffer()) + data = bytes(state_dict[name].uop.base.realized.as_buffer()) data = data if not offsets else data[offsets[0]:offsets[1]] writer.write(data) cursor += size diff --git a/examples/webgpu/stable_diffusion/compile.py b/examples/webgpu/stable_diffusion/compile.py index e13f34f1c9..6f47a5b3c6 100644 --- a/examples/webgpu/stable_diffusion/compile.py +++ b/examples/webgpu/stable_diffusion/compile.py @@ -114,7 +114,7 @@ if __name__ == "__main__": run, special_names = jit_model(step, *step.input) functions, statements, bufs, _ = compile_net(run, special_names) state = get_state_dict(model) - weights = 
{id(x.lazydata.base.realized): name for name, x in state.items()} + weights = {id(x.uop.base.realized): name for name, x in state.items()} kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()]) kernel_names = ', '.join([name for (name, _, _, _) in statements]) input_names = [name for _,name in special_names.items() if "input" in name] diff --git a/extra/export_model.py b/extra/export_model.py index 63b4098270..2b2aa1dc62 100644 --- a/extra/export_model.py +++ b/extra/export_model.py @@ -48,13 +48,13 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]: # hack to put the inputs back for (j,i),idx in run.input_replace.items(): - realized_input = args[idx].lazydata.base.realized + realized_input = args[idx].uop.base.realized run.jit_cache[j].bufs[i] = realized_input special_names[id(realized_input)] = f'input{idx}' # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret) for i, output in enumerate(the_output): - special_names[id(output.lazydata.base.realized)] = f'output{i}' + special_names[id(output.uop.base.realized)] = f'output{i}' return run, special_names def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], @@ -242,7 +242,7 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model" with Context(JIT=2): run,special_names = jit_model(model, *inputs) functions, statements, bufs, bufs_to_save = compile_net(run, special_names) state = get_state_dict(model) - weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()} + weight_names = {id(x.uop.base.realized): name for name, x in state.items()} input_names = [name for _,name in special_names.items() if "input" in name] output_names = [name for _,name in special_names.items() if "output" in name] diff --git a/extra/hip_gpu_driver/test_pm4.py b/extra/hip_gpu_driver/test_pm4.py index b5a4541217..d0cf92002d 100644 --- a/extra/hip_gpu_driver/test_pm4.py +++ b/extra/hip_gpu_driver/test_pm4.py @@ -47,7 +47,7 @@ if __name__ == "__main__": a = Tensor([0.,1.,2.], device="KFD").realize() b = a + 7 - b.lazydata.buffer.allocate() + b.uop.buffer.allocate() si = b.schedule()[-1] runner = dev.get_runner(*si.ast) prg: AMDProgram = runner.clprg @@ -69,8 +69,8 @@ if __name__ == "__main__": #scratch = dev._gpu_alloc(0x10000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) ka = to_mv(dev.kernargs_ptr, 0x10).cast("Q") - ka[0] = b.lazydata.buffer._buf.va_addr - ka[1] = a.lazydata.buffer._buf.va_addr + ka[0] = b.uop.buffer._buf.va_addr + ka[1] = a.uop.buffer._buf.va_addr compute_read_pointer = to_mv(compute_queue.read_pointer_address, 8).cast("Q") compute_write_pointer = to_mv(compute_queue.write_pointer_address, 8).cast("Q") diff --git a/extra/multitensor.py b/extra/multitensor.py index dadcf1ef18..af7ddc30a7 100644 --- a/extra/multitensor.py +++ b/extra/multitensor.py @@ -23,7 +23,7 @@ def explicit_shard_W_axis_1(X, W): x = x.reshape(N, 1, N).expand(N, N, N) w = w.T.reshape(1, N, N).expand(N, N, N) m = x*w - assert m.lazydata.st.views[0].mask is not None + assert m.uop.st.views[0].mask is not None ret = m.sum(2) return ret #Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])] diff --git a/extra/remu/test/hwtest.py b/extra/remu/test/hwtest.py index 0cc2ba0fc2..76bd2f6e69 100644 --- a/extra/remu/test/hwtest.py +++ b/extra/remu/test/hwtest.py @@ -89,7 +89,7 @@ def get_output(s:str, n_threads:int=1): "s_waitcnt 0", "global_store_b32 v0, v1, s[0:1]", "s_nop 0", 
"s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm"]) - test = Tensor.zeros((n_threads,), dtype=dtypes.uint32).contiguous().realize().lazydata.buffer + test = Tensor.zeros((n_threads,), dtype=dtypes.uint32).contiguous().realize().uop.buffer prg = get_prg(code, 32, 32) prg(test._buf, global_size=(1, 1, 1), local_size=(n_threads, 1, 1), wait=True) return test.numpy() diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index c7f082cd2c..17c8083f83 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -65,12 +65,12 @@ for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "p # in place operations with views def realize_with_views(self: Tensor, views: Tensor): - if not self.lazydata.st.contiguous: self.replace(self.contiguous()) + if not self.uop.st.contiguous: self.replace(self.contiguous()) self.replace(self.clone().realize()) for v in views: - if v.lazydata.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view + if v.uop.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view ret = self - st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right? + st = ShapeTracker(self.uop.st.views + v.uop.st.views) # TODO: is this right? for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo) v.replace(ret) def maybe_realize_storage(self: Tensor) -> bool: @@ -178,7 +178,7 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None): # multiple as_strided do not compound base = canonical_base(tensor) # TODO: this is heavyweight - st = ShapeTracker(base.lazydata.st.views + (View.create(tuple(size), tuple(stride), storage_offset),)) + st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),)) ret = base if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st) if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size) @@ -315,7 +315,7 @@ def _copy_from(src: torch.Tensor, dest, non_blocking=False): to_device = _from_torch_device(dest.device) src,dest = unwrap(src),unwrap(dest) # TODO we need to properly match dest shape and strides, not blindly assign - if dest.lazydata.st.contiguous or dest.lazydata.is_realized: src = src.contiguous() # this only solves some cases + if dest.uop.st.contiguous or dest.uop.is_realized: src = src.contiguous() # this only solves some cases dest.assign(src.cast(cast_dtype).to(to_device)) if realize: Tensor.realize(dest) elif src.is_tiny and dest.is_cpu: @@ -493,7 +493,7 @@ def wrap_out(f): assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}" assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}" assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}" - if out.lazydata.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics + if out.uop.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics return out.assign(assigned) return _wrap_out diff --git a/extra/torch_backend/wrapped_tensor.cpp b/extra/torch_backend/wrapped_tensor.cpp index 0df27d02da..3e5f19fcbf 100644 --- a/extra/torch_backend/wrapped_tensor.cpp +++ b/extra/torch_backend/wrapped_tensor.cpp @@ -116,7 +116,7 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceInd // TODO: we have to get the dtype and the shape from the tinygrad Tensor std::vector sizes = 
py_obj.attr("shape").cast>(); - py::list views = py_obj.attr("lazydata").attr("st").attr("views"); + py::list views = py_obj.attr("uop").attr("st").attr("views"); std::vector strides = views[views.size() - 1].attr("strides").cast>(); int64_t storage_offset = 0; for (auto& v: views) { diff --git a/extra/torch_hook/hook_torch.py b/extra/torch_hook/hook_torch.py index f3b8113b4a..8f718076a9 100644 --- a/extra/torch_hook/hook_torch.py +++ b/extra/torch_hook/hook_torch.py @@ -113,7 +113,7 @@ class DispatchLog(TorchDispatchMode): _ = tiny_x.cpu().numpy() if torch.is_tensor(tiny_x) and tiny_x.device.type == "tiny": tt = tiny_torch.unwrap(tiny_x) - try: out_addr = tt.lazydata.buffer._buf.value + try: out_addr = tt.uop.buffer._buf.value except Exception: pass tiny_events = hook_cuda.collect_events(clear=True) print_events(tiny_events, colored("tiny", "magenta"), out_addr) diff --git a/test/external/external_benchmark_bert_matmuls.py b/test/external/external_benchmark_bert_matmuls.py index b83e6d4809..4f64629b54 100644 --- a/test/external/external_benchmark_bert_matmuls.py +++ b/test/external/external_benchmark_bert_matmuls.py @@ -13,6 +13,6 @@ if __name__ == "__main__": (Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v ] for t0, t1 in tensors: - print(f"{t0.shape=}, {t0.lazydata.st.real_strides()=}, {t1.shape=}, {t1.lazydata.st.real_strides()=}") + print(f"{t0.shape=}, {t0.uop.st.real_strides()=}, {t1.shape=}, {t1.uop.st.real_strides()=}") for _ in range(5): t0.dot(t1, dtype=acc_dtype).realize() diff --git a/test/external/external_multi_gpu.py b/test/external/external_multi_gpu.py index e5dd836c88..32d107df7d 100644 --- a/test/external/external_multi_gpu.py +++ b/test/external/external_multi_gpu.py @@ -21,8 +21,8 @@ if __name__ == "__main__": with Timing("CPU creation: ", on_exit=lambda x: f", {(sz*4*2)/x:.2f} GB/sec"): c0 = (Tensor.ones(sz, device="CPU")/2).realize() c1 = (Tensor.ones(sz, device="CPU")/4).realize() - print(c0.lazydata.base.realized) - print(c1.lazydata.base.realized) + print(c0.uop.base.realized) + print(c1.uop.base.realized) with Timing("CPU -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): a0 = c0.to(d0).realize() diff --git a/test/external/external_test_amd.py b/test/external/external_test_amd.py index aabff98968..dd13353a0f 100644 --- a/test/external/external_test_amd.py +++ b/test/external/external_test_amd.py @@ -9,18 +9,18 @@ class TestAMD(unittest.TestCase): TestAMD.d0: AMDDevice = Device["AMD"] TestAMD.a = Tensor([0.,1.], device="AMD").realize() TestAMD.b = self.a + 1 - si = create_schedule([self.b.lazydata])[-1] + si = create_schedule([self.b.uop])[-1] TestAMD.d0_runner = TestAMD.d0.get_runner(*si.ast) - TestAMD.b.lazydata.buffer.allocate() + TestAMD.b.uop.buffer.allocate() def test_amd_ring_64bit_doorbell(self): TestAMD.d0.pm4_write_pointer[0] = TestAMD.d0.pm4_write_pointer[0] + (2 << 32) - TestAMD.d0.pm4_ring.size // 4 for _ in range(2000): - TestAMD.d0_runner.clprg(TestAMD.b.lazydata.buffer._buf, TestAMD.a.lazydata.buffer._buf, + TestAMD.d0_runner.clprg(TestAMD.b.uop.buffer._buf, TestAMD.a.uop.buffer._buf, global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size) - TestAMD.d0_runner.clprg(TestAMD.a.lazydata.buffer._buf, TestAMD.b.lazydata.buffer._buf, + TestAMD.d0_runner.clprg(TestAMD.a.uop.buffer._buf, TestAMD.b.uop.buffer._buf, global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size) - val = TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0] + val = 
TestAMD.a.uop.buffer.as_buffer().cast("f")[0] assert val == 4000.0, f"got val {val}" if __name__ == "__main__": diff --git a/test/external/external_test_hcq.py b/test/external/external_test_hcq.py index 0303948bd1..2ae0371fe1 100644 --- a/test/external/external_test_hcq.py +++ b/test/external/external_test_hcq.py @@ -22,10 +22,10 @@ class TestHCQ(unittest.TestCase): TestHCQ.b = self.a + 1 si = self.b.schedule()[-1] TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast) - TestHCQ.b.lazydata.buffer.allocate() + TestHCQ.b.uop.buffer.allocate() # wow that's a lot of abstraction layers - TestHCQ.addr = struct.pack("QQ", TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr) - TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr) + TestHCQ.addr = struct.pack("QQ", TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr) + TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr) TestHCQ.kernargs_off = TestHCQ.runner._prg.kernargs_offset TestHCQ.kernargs_size = TestHCQ.runner._prg.kernargs_alloc_size ctypes.memmove(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_off, TestHCQ.addr, len(TestHCQ.addr)) @@ -45,8 +45,8 @@ class TestHCQ(unittest.TestCase): def setUp(self): TestHCQ.d0.synchronize() - TestHCQ.a.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) - TestHCQ.b.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0)))) + TestHCQ.a.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) + TestHCQ.b.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0)))) TestHCQ.d0.synchronize() # wait for copyins to complete def test_run_1000_times_one_submit(self): @@ -65,7 +65,7 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0] assert val == 2000.0, f"got val {val}" def test_run_1000_times(self): @@ -81,7 +81,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0] assert val == 2000.0, f"got val {val}" def test_run_to_3(self): @@ -95,7 +95,7 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 3.0, f"got val {val}" def test_update_exec(self): @@ -106,9 +106,9 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 0.0, f"got val {val}, should not be 
updated" @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind") @@ -126,7 +126,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0] assert val == 2000.0, f"got val {val}" @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind") @@ -141,9 +141,9 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 0.0, f"got val {val}, should not be updated" @unittest.skipIf(CI, "Can't handle async update on CPU") @@ -174,7 +174,7 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" def test_submit_empty_queues(self): @@ -206,13 +206,13 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" def test_copy_1000_times(self): q = TestHCQ.copy_queue() - q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8) - q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8) + q.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8) + q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8) for _ in range(1000): q.submit(TestHCQ.d0) TestHCQ.copy_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) @@ -221,24 +221,24 @@ class TestHCQ(unittest.TestCase): # confirm the signal didn't exceed the put value with self.assertRaises(RuntimeError): TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50) - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 0.0, f"got val {val}" def test_copy(self): q = TestHCQ.copy_queue() - q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8) + q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8) q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 1.0, f"got val {val}" @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind") def test_bind_copy(self): q = TestHCQ.copy_queue() - 
q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8) - q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8) + q.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8) + q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8) q.bind(TestHCQ.d0) for _ in range(1000): q.submit(TestHCQ.d0) @@ -248,7 +248,7 @@ class TestHCQ(unittest.TestCase): # confirm the signal didn't exceed the put value with self.assertRaises(RuntimeError): TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50) - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 0.0, f"got val {val}" def test_copy_bandwidth(self): @@ -281,14 +281,14 @@ class TestHCQ(unittest.TestCase): q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) # b = [1, 2] q.signal(sig:=TestHCQ.d0._alloc_signal(value=0), value=1) qc.wait(sig, value=1) - qc.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8) + qc.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8) qc.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) qc.submit(TestHCQ.d0) time.sleep(0.02) # give it time for the wait to fail q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" def test_cross_device_signal(self): @@ -319,7 +319,7 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" if __name__ == "__main__": diff --git a/test/external/external_test_hip_compile.py b/test/external/external_test_hip_compile.py index 8f8292d4ce..0edef8730b 100644 --- a/test/external/external_test_hip_compile.py +++ b/test/external/external_test_hip_compile.py @@ -10,7 +10,7 @@ class TestHIPCompileSpeed(unittest.TestCase): def test_hip_compile(self): a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5]) out = a + b - lin = Kernel(create_schedule([out.lazydata])[-1].ast[0]) + lin = Kernel(create_schedule([out.uop])[-1].ast[0]) lin.linearize() reference = """ diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index b488e86725..f061975e44 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -21,8 +21,8 @@ class TestNV(unittest.TestCase): TestNV.b = self.a + 1 si = self.b.schedule()[-1] TestNV.d0_runner = get_runner(TestNV.d0.device, si.ast) - TestNV.b.lazydata.buffer.allocate() - TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr) + TestNV.b.uop.buffer.allocate() + TestNV.addr = struct.pack("QQ", TestNV.b.uop.buffer._buf.va_addr, TestNV.a.uop.buffer._buf.va_addr) def test_oor_kernels(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, 
dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 @@ -44,8 +44,8 @@ class TestNV(unittest.TestCase): TestNV.along = Tensor([105615], device="NV").realize() ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501 temp_runner = get_runner(TestNV.d0.device, (ast,)) - temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={}) - val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0] + temp_runner([TestNV.b.uop.buffer, TestNV.along.uop.buffer], var_vals={}) + val = TestNV.b.uop.buffer.as_buffer().cast("f")[0] assert abs(val - 0.80647) < 0.001, f"got val {val}" def test_kernargs_no_oob_access(self): @@ -59,7 +59,7 @@ class TestNV(unittest.TestCase): q.signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value).submit(TestNV.d0) TestNV.d0._wait_signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value) TestNV.d0.timeline_value += 1 - val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestNV.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" if __name__ == "__main__": diff --git a/test/external/fuzz_graph.py b/test/external/fuzz_graph.py index 7992e21778..4e1d492eae 100644 --- a/test/external/fuzz_graph.py +++ b/test/external/fuzz_graph.py @@ -28,7 +28,7 @@ def alloc_rawbuffer(device, fill=False): if fill: with Context(DEBUG=0): data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype)) - rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer()) + rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer()) return rawbuf def gen_kernel_ji(device, deps): diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 8a03176306..1bf269ce56 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -73,7 +73,7 @@ def get_fuzz_rawbufs(lin): data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype)) else: data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype)) - rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().lazydata.base.realized.as_buffer()) + rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().uop.base.realized.as_buffer()) return rawbufs def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, 
force_device=None): diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py index 271aeb586b..da5d61944c 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -21,18 +21,18 @@ def consec(shape, start=1): # creates strided tensor with base set to reference tensor's base, equivalent to torch.set_() def set_(reference: Tensor, shape, strides, offset): - raise NotImplementedError("need to implement without calling lazydata.view") - if reference.lazydata.base.realized is None: reference.realize() - assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base" - strided = Tensor(reference.lazydata.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)))) - assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided" + raise NotImplementedError("need to implement without calling uop.view") + if reference.uop.base.realized is None: reference.realize() + assert reference.uop.base.realized, "base has to be realized before setting it to strided's base" + strided = Tensor(reference.uop.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),)))) + assert strided.uop.st.real_strides() == strides, "real_strides should equal strides for strided" return strided def clone(original:Tensor): return original.clone() def copy_(src:Tensor, other:Tensor) -> Tensor: return src.clone() # this is fine for tested usecases since as geohotstan understands, # data_ptr is used to compare if operations needed between tensors is the same -def data_ptr(tensor:Tensor): return tensor.lazydata +def data_ptr(tensor:Tensor): return tensor.uop # https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html def index_put_(tensor:Tensor, indices, values, accumulate) -> Tensor: @@ -971,9 +971,9 @@ class TestIndexing(unittest.TestCase): numpy_testing_assert_equal_helper((2, 0, 4), z.shape) # this isn't technically necessary, but matches NumPy stride calculations. 
# NOTE: this is empty and shouldn't have strides - #numpy_testing_assert_equal_helper((60, 20, 5), z.lazydata.st.real_strides()) + #numpy_testing_assert_equal_helper((60, 20, 5), z.uop.st.real_strides()) # NOTE tinygrad's int slicing implementation makes this not contiguous - # self.assertTrue(z.lazydata.st.contiguous) + # self.assertTrue(z.uop.st.contiguous) @unittest.skip("bool indexing not supported") def test_index_getitem_copy_bools_slices(self): diff --git a/test/test_arange.py b/test/test_arange.py index 085a54d828..e9bc9983b2 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -20,7 +20,7 @@ class TestArange(unittest.TestCase): p = k.to_program() print(p.name) #print(p.src) - ExecItem(CompiledRunner(p), [tt.lazydata.buffer]).run() + ExecItem(CompiledRunner(p), [tt.uop.buffer]).run() np.testing.assert_equal(tt.numpy(), np.arange(N)) return p.estimates.ops diff --git a/test/test_assign.py b/test/test_assign.py index 6f8a7b8039..837e4141e1 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -13,11 +13,11 @@ class TestAssign(unittest.TestCase): b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) a.realize() b.realize() - ba1 = a.lazydata.base.realized - bb1 = b.lazydata.base.realized + ba1 = a.uop.base.realized + bb1 = b.uop.base.realized a += b a.realize() - ba2 = a.lazydata.base.realized + ba2 = a.uop.base.realized assert ba1 == ba2 and ba1 != bb1 np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N))) @@ -259,13 +259,13 @@ class TestAssign(unittest.TestCase): b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) a.realize() b.realize() - ba1 = a.lazydata.base.realized - bb1 = b.lazydata.base.realized + ba1 = a.uop.base.realized + bb1 = b.uop.base.realized with self.assertRaises((RuntimeError, AssertionError)): a = a.permute(1,0) a += b a.realize() - ba2 = a.lazydata.base.realized + ba2 = a.uop.base.realized assert ba1 != ba2 and ba1 != bb1 np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0)) @@ -275,12 +275,12 @@ class TestAssign(unittest.TestCase): a.realize() b.realize() #GlobalCounters.cache = [] - ba1 = a.lazydata.base.realized # noqa: F841 - bb1 = b.lazydata.base.realized # noqa: F841 + ba1 = a.uop.base.realized # noqa: F841 + bb1 = b.uop.base.realized # noqa: F841 with self.assertRaisesRegex(RuntimeError, "contiguous"): a.assign(a.permute(1,0) + b) # this should not work! 
a.realize() - ba2 = a.lazydata.base.realized # noqa: F841 + ba2 = a.uop.base.realized # noqa: F841 # NOTE: don't test that it's assigned #assert ba1 == ba2 and ba1 != bb1 np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0)) @@ -383,10 +383,10 @@ class TestAssign(unittest.TestCase): def test_cast_assignment(self): a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) a.realize() - oba1 = a.lazydata.base.output_buffer + oba1 = a.uop.base.output_buffer a.assign(a.cast(dtypes.int32).realize()) a.realize() - oba2 = a.lazydata.base.output_buffer + oba2 = a.uop.base.output_buffer assert oba1 is None and oba2 is None np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N))) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 86d33abd0a..187a4b5142 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -221,7 +221,7 @@ class TestReduceOpsConstFolding(unittest.TestCase): # contiguous folded const can still schedule a = Tensor.empty(1, 0).sum().contiguous() _check_ast_count(2, a+2) - self.assertIs(a.lazydata.base.op, Ops.BUFFER) + self.assertIs(a.uop.base.op, Ops.BUFFER) np.testing.assert_equal((Tensor.empty(1, 0).sum().contiguous()+2).numpy(), 2) # otherwise we just fuse it _check_ast_count(1, (Tensor.empty(1, 0).sum()+2).contiguous()) diff --git a/test/test_dtype.py b/test/test_dtype.py index b41a9663a3..0c5f9b41ec 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -32,7 +32,7 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]: def _test_to_np(a:Tensor, np_dtype, target): if DEBUG >= 2: print(a) na = a.numpy() - if DEBUG >= 2: print(na, na.dtype, a.lazydata.base.realized) + if DEBUG >= 2: print(na, na.dtype, a.uop.base.realized) try: assert na.dtype == np_dtype np.testing.assert_allclose(na, target) diff --git a/test/test_gc.py b/test/test_gc.py index e32053c8c8..78827a0fdc 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -73,7 +73,7 @@ class TestGC(unittest.TestCase): x = Tensor.ones(4,4).contiguous().realize()+1 self.assertEqual(bufs_allocated()-init, 1) # try commenting this part out, it's green! 
- x.lazydata.toposort() + x.uop.toposort() del x if bufs_allocated()-init != 0: print(inspect.getclosurevars(UOp.toposort().fget)) @@ -84,11 +84,11 @@ class TestGC(unittest.TestCase): a = Tensor.empty(10) self.assertEqual(bufs_allocated()-init, 0) a.realize() - real_buf = a.lazydata.buffer + real_buf = a.uop.buffer # after the Tensor UOp is deleted there shouldn't be any references on the Buffer self.assertEqual(real_buf.lb_refcount, 1) self.assertEqual(bufs_allocated()-init, 1) - del a.lazydata + del a.uop self.assertEqual(real_buf.lb_refcount, 0) self.assertEqual(bufs_allocated()-init, 1) # keep the buffer alive del real_buf @@ -98,10 +98,10 @@ class TestGC(unittest.TestCase): init = bufs_allocated() a = Tensor.full((4,), 1.).contiguous() a.realize() - real_buf = a.lazydata.buffer + real_buf = a.uop.buffer self.assertEqual(real_buf.lb_refcount, 1) a.assign(Tensor.full((4,), 2.)) - self.assertIs(a.lazydata.src[0].buffer, real_buf) + self.assertIs(a.uop.src[0].buffer, real_buf) # NOTE: this is still 1, we don't count the ASSIGN self.assertEqual(real_buf.lb_refcount, 1) a.realize() diff --git a/test/test_graph.py b/test/test_graph.py index a47756a23f..19dd6f9fe1 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -36,7 +36,7 @@ def helper_alloc_rawbuffer(device, fill=False): if fill: with Context(DEBUG=0): data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype)) - rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer()) + rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer()) return rawbuf def helper_create_offset_rawbuffer(base, offset=0): diff --git a/test/test_hcq.py b/test/test_hcq.py index b2ae098a6a..884c5fc21a 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -20,15 +20,15 @@ class TestHCQ(unittest.TestCase): si = self.b.schedule()[-1] TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast) - TestHCQ.b.lazydata.buffer.allocate() + TestHCQ.b.uop.buffer.allocate() - TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.lazydata.buffer._buf, TestHCQ.a.lazydata.buffer._buf]) - TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.lazydata.buffer._buf, TestHCQ.b.lazydata.buffer._buf]) + TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf]) + TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf]) def setUp(self): TestHCQ.d0.synchronize() - TestHCQ.a.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) - TestHCQ.b.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0)))) + TestHCQ.a.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) + TestHCQ.b.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0)))) TestHCQ.d0.synchronize() # wait for copyins to complete # Test signals @@ -117,7 +117,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" def test_exec_2_kernels_100_times(self): @@ -133,7 +133,7 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value}) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0] assert val == 200.0, f"got val {val}" def 
test_exec_update(self): @@ -148,9 +148,9 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0] assert val == 1.0, f"got val {val}" - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 0.0, f"got val {val}, should not be updated" def test_exec_update_fuzz(self): @@ -192,13 +192,13 @@ class TestHCQ(unittest.TestCase): if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue") TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \ - .copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8) \ + .copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8) \ .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 1.0, f"got val {val}" def test_copy_long(self): @@ -252,12 +252,12 @@ class TestHCQ(unittest.TestCase): .copy(virt_dest_addr, virt_src_addr, 8) \ .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) - q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr}) + q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.uop.buffer._buf.va_addr}) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1] assert val == 1.0, f"got val {val}" def test_update_copy_long(self): diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 41b32de81f..08d2c04c32 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -13,7 +13,7 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU") class TestImageCopy(unittest.TestCase): def test_image_copyout_1x1(self, img_type=dtypes.imagef): it = Tensor.arange(4).cast(img_type((1,1,4))).realize() - buf = it.lazydata.buffer + buf = it.uop.buffer out = buf.as_buffer() np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4)) @@ -27,18 +27,18 @@ class TestImageCopy(unittest.TestCase): def test_image_copyout_2x3(self): it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize() - buf = it.lazydata.buffer + buf = it.uop.buffer out = buf.as_buffer() np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4)) def test_image_roundtrip(self): sz = (4,2,4) it = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize() - buf = it.lazydata.buffer + buf = it.uop.buffer out = buf.as_buffer() it2 = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize() - buf2 = it2.lazydata.buffer + buf2 = it2.uop.buffer buf2.copyin(out) assert (it == it2).sum().item() == prod(sz) @@ -49,7 +49,7 @@ class TestImageDType(unittest.TestCase): data = Tensor.randn(9*27*4).realize() tst = data.numpy() it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() - assert isinstance(it.lazydata.base.realized.dtype, ImageDType) + assert isinstance(it.uop.base.realized.dtype, ImageDType) np.testing.assert_equal(tst, it.numpy()) @unittest.expectedFailure # this isn't supported anymore, 
CAST to ImageDType stays ImageDType @@ -58,14 +58,14 @@ class TestImageDType(unittest.TestCase): tst = data.numpy() it = data.cast(dtypes.imagef((9,27,4))).realize() # the underlying UOp is identical - self.assertIs(it.lazydata.base.realized, data.lazydata.base.realized) + self.assertIs(it.uop.base.realized, data.uop.base.realized) np.testing.assert_equal(tst, it.numpy()) def test_image_and_back_wrong_shape(self): data = Tensor.randn(9*27*4).realize() tst = data.numpy() it = data.cast(dtypes.imagef((9,12,4))).realize() - assert not isinstance(it.lazydata.base.realized.dtype, ImageDType) + assert not isinstance(it.uop.base.realized.dtype, ImageDType) np.testing.assert_equal(tst, it.numpy()) def test_shrink_load_float(self): @@ -77,7 +77,7 @@ class TestImageDType(unittest.TestCase): # NOTE: contiguous is needed otherwise this folds it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize() out = (it*2).realize() - assert isinstance(out.lazydata.base.realized.dtype, ImageDType) + assert isinstance(out.uop.base.realized.dtype, ImageDType) def test_sum(self): it = Tensor.rand(8).cast(dtypes.imagef((1,2,4))).realize() @@ -98,26 +98,26 @@ class TestImageDType(unittest.TestCase): def test_lru_alloc(self): data = Tensor.randn(9*27*4).realize() it = data.cast(dtypes.imagef((9,27,4))).realize() - b1 = it.lazydata.base.realized._buf + b1 = it.uop.base.realized._buf del it it = data.cast(dtypes.imagef((9,27,4))).realize() - assert it.lazydata.base.realized._buf == b1 + assert it.uop.base.realized._buf == b1 def test_no_lru_alloc(self): data = Tensor.randn(9*27*4).realize() it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() - b1 = it.lazydata.base.realized._buf + b1 = it.uop.base.realized._buf del it it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize() - assert it.lazydata.base.realized._buf != b1 + assert it.uop.base.realized._buf != b1 def test_no_lru_alloc_dtype(self): data = Tensor.randn(9*27*4).realize() it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() - b1 = it.lazydata.base.realized._buf + b1 = it.uop.base.realized._buf del it it = data.cast(dtypes.imageh((9,27,4))).realize() - assert it.lazydata.base.realized._buf != b1 + assert it.uop.base.realized._buf != b1 # issue caused by: don't realize image to image casts. 
this is part of a larger problem #@unittest.expectedFailure @@ -137,8 +137,8 @@ class TestImageDType(unittest.TestCase): print(lst) assert not np.any(np.isnan(lst)) # NOTE: the w1 grad must realize to a separate kernel - assert w1.grad.lazydata.is_realized, f"never realized {w1.grad}" - self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32) + assert w1.grad.uop.is_realized, f"never realized {w1.grad}" + self.assertEqual(w1.grad.uop.base.buffer.dtype, dtypes.float32) self.assertEqual(len(sched), 10) @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") diff --git a/test/test_linearizer.py b/test/test_linearizer.py index cc23517167..6a9dfe2049 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -75,7 +75,7 @@ class TestLinearizer(unittest.TestCase): lowered = [x[1] for x in lower_schedule(c.schedule())] for ei in lowered: ei.run() rawbufs = lowered[-1].bufs - assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized} + assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized} np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:]) np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4) @@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase): def test_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(32, dtype=dtypes.float).realize() - st_x = x.lazydata.st + st_x = x.uop.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((1, 32)).expand((32, 32))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,))) @@ -172,7 +172,7 @@ class TestLinearizer(unittest.TestCase): def test_mid_dim_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() - st_x = x.lazydata.st + st_x = x.uop.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) @@ -232,12 +232,12 @@ class TestLinearizer(unittest.TestCase): x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)] - first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5))),)) + first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x0.uop.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) - second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5))),)) + second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x1.uop.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5))),)) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5))) second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,))) - third_x = UOp(Ops.LOAD, dtypes.float, (g3.view(x2.lazydata.st.reshape((27, 32, 1, 1, 5))),)) + third_x = UOp(Ops.LOAD, dtypes.float, (g3.view(x2.uop.st.reshape((27, 32, 1, 1, 5))),)) mul = (third_x*second_reduce) third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 1, 5))), third_reduce)) @@
-253,7 +253,7 @@ class TestLinearizer(unittest.TestCase): def test_double_reduce_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize() - st = x.lazydata.st + st = x.uop.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop())) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5))) @@ -302,9 +302,9 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),)) + first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) - second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 15, 1, 5))),)) + second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((27, 15, 1, 5))),)) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce)) @@ -329,11 +329,11 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(4, 32, dtype=dtypes.float).realize() x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] - first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),)) - first_x_p = UOp(Ops.LOAD, dtypes.float, (g2.view(x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),)) + first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((4, 1, 32)).expand((4, 32, 32))),)) + first_x_p = UOp(Ops.LOAD, dtypes.float, (g2.view(x_p.uop.st.reshape((4, 1, 32)).expand((4, 32, 32))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,))) - second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 32, 1))),)) + second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((4, 32, 1))),)) diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1))) second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((4, 1, 1))), second_reduce)) @@ -361,9 +361,9 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] - first_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),)) + first_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) - second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),)) + second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.uop.st.reshape((27, 15, 1, 5))),)) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) second_reduce = 
UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce)) @@ -383,9 +383,9 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] - first_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),)) + first_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) - second_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),)) + second_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.uop.st.reshape((27, 15, 1, 5))),)) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce)) @@ -402,9 +402,9 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),)) + first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) - second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),)) + second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 3, 1, 5))),)) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5))) second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce)) @@ -418,9 +418,9 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),)) + first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) - second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),)) + second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 3, 1, 5))),)) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5))) second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce)) @@ -437,9 +437,9 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop())) + first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.uop.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop())) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, 
(2,))) - second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop())) + second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.uop.st.reshape((27, 12, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5))) second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) @@ -453,10 +453,10 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))),)) + first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) - second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 35, 1))),)) + second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 35, 1))),)) squares = (second_x+neg_mean)*(second_x+neg_mean) squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1)) @@ -471,10 +471,10 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))),)) + first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35)) - second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35))),)) + second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 1, 35))),)) squares = (second_x+neg_mean)*(second_x+neg_mean) squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,))) variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35)) @@ -491,10 +491,10 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] - first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) + first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.uop.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) - second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop())) + second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.uop.st.reshape((15, 25, 35, 1)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1)) @@ -514,12 +514,12 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(3, 27, 32, 
dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD - first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))),)) + first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))),)) first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1)) # store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean)) # verify_lazyop(store) - second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 32, 1))),)) + second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((3, 27, 32, 1))),)) squares = (second_x+neg_mean)*(second_x+neg_mean) squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1)) @@ -535,9 +535,9 @@ class TestLinearizer(unittest.TestCase): def test_softmax_multireduce(self): x = Tensor.rand(4, 32).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32))),)) + first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((4, 1, 32,)).expand((4, 32, 32))),)) max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,))) - second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 32, 1,))),)) + second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((4, 32, 1,))),)) centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1)) exp_x = centered_x.alu(Ops.EXP2) sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,))) @@ -553,7 +553,7 @@ class TestLinearizer(unittest.TestCase): dataset = Tensor.rand(16384, 256).realize() idxs = Tensor([0,3,5,6]).realize() with Context(FUSE_ARANGE=1): - sink = dataset[idxs].contiguous().kernelize().lazydata.base.src[1].arg.ast + sink = dataset[idxs].contiguous().kernelize().uop.base.src[1].arg.ast real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1) helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index]) @@ -656,16 +656,16 @@ class TestLinearizer(unittest.TestCase): ] g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),)) - x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),)) + x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((1, N, N)).expand((N,N,N))),)) + x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N))),)) r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,))) r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts) - x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),)) - x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),)) + x_ld0 = UOp(Ops.LOAD, dtypes.float, 
src=(g1.view(x.uop.st.reshape((N, 1, N)).expand((N,N,N))),)) + x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, N, 1))),)) r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,))) r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1)) @@ -683,16 +683,16 @@ class TestLinearizer(unittest.TestCase): ] g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] - x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),)) - x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),)) + x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((1, N, N)).expand((N,N,N))),)) + x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N))),)) r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,))) r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts) - x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),)) - x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),)) + x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N)).expand((N,N,N))),)) + x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, N, 1))),)) r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,))) r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,))) store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1)) @@ -711,8 +711,8 @@ class TestLinearizer(unittest.TestCase): opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],] wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501 - ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)) - ld1 = x.lazydata.st.reshape((N, N, 1)) + ld0 = x.uop.st.reshape((N, 1, N)).expand((N,N,N)) + ld1 = x.uop.st.reshape((N, N, 1)) ast = UOp(Ops.SINK, src=( UOp(Ops.STORE, src=( UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 @@ -742,8 +742,8 @@ class TestLinearizer(unittest.TestCase): ast_const(dtypes.float, 1.0, (N, 1, 1)),)),)),)) helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output]) - ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N)) - ld1 = x.lazydata.st.reshape((N, 1, N)) + ld0 = x.uop.st.reshape((1, N, N)).expand((N,N,N)) + ld1 = x.uop.st.reshape((N, 1, N)) wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501 ast = UOp(Ops.SINK, src=( UOp(Ops.STORE, src=( @@ -776,8 +776,8 @@ class TestLinearizer(unittest.TestCase): # pad reduce axis helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output]) - ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N)) - ld1 = x.lazydata.st.reshape((N,N,1,1)) + 
ld0 = x.uop.st.reshape((1,1,N,N)).expand((N,N,N,N)) + ld1 = x.uop.st.reshape((N,N,1,1)) wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501 ast = UOp(Ops.SINK, src=( UOp(Ops.STORE, src=( @@ -1794,7 +1794,7 @@ class TestHandCodedOpts(unittest.TestCase): def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs): assert isinstance(ast, UOp), "ast must be UOp" - inbufs = [x.lazydata.base.buffer for x in inputs] + inbufs = [x.uop.base.buffer for x in inputs] outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() \ for out in ast.src] return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs) diff --git a/test/test_masked_st.py b/test/test_masked_st.py index c518d5b20e..ce88a710a1 100644 --- a/test/test_masked_st.py +++ b/test/test_masked_st.py @@ -7,7 +7,7 @@ class TestMaskedShapeTracker(unittest.TestCase): b = Tensor([1,1]).pad(((0,3),)) c = a*b assert c.shape == a.shape - #assert c.lazydata.st.views[0].mask is not None + #assert c.uop.st.views[0].mask is not None ret = c.data() assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0] @@ -16,7 +16,7 @@ class TestMaskedShapeTracker(unittest.TestCase): b = Tensor([1,1]).pad(((0,3),)) c = a*b assert c.shape == a.shape - #assert c.lazydata.st.views[0].mask is not None + #assert c.uop.st.views[0].mask is not None ret = c.data() assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0] @@ -24,7 +24,7 @@ class TestMaskedShapeTracker(unittest.TestCase): a = Tensor([1,1]).pad(((0,2),)) b = Tensor([1,1]).pad(((0,2),)) c = a+b - #assert c.lazydata.st.views[0].mask is not None + #assert c.uop.st.views[0].mask is not None ret = c.data() assert ret.tolist() == [2.0, 2.0, 0.0, 0.0] diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 22d7fa3c26..f06dc866a6 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -29,7 +29,7 @@ N = 128 def _test_allreduce(t:Tensor): aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize() ts = t.shard(devices_4, 0).realize() - b = Tensor(UOp.allreduce(ts.lazydata, Ops.ADD, ts.device)) + b = Tensor(UOp.allreduce(ts.uop, Ops.ADD, ts.device)) b.realize() return aa, b @@ -50,7 +50,7 @@ class TestMultiTensor(unittest.TestCase): def test_shard(self): X = Tensor.ones(256).contiguous().realize() X.shard_(devices_2, 0) - for lb in X.lazydata.src: + for lb in X.uop.src: assert lb.shape == (128,) (X + X).realize() @@ -61,12 +61,12 @@ class TestMultiTensor(unittest.TestCase): def test_tensor_from_multi(self): X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0) - Y = Tensor(X.lazydata) + Y = Tensor(X.uop) self.assertEqual(Y.device, Device.DEFAULT) np.testing.assert_equal(X.numpy(), Y.numpy()) with self.assertRaises(AssertionError): - _ = Tensor(X.lazydata, dtype=dtypes.float) + _ = Tensor(X.uop, dtype=dtypes.float) def test_sharded_arange(self): sharded_arange = Tensor.arange(1000).shard(devices_2, 0) @@ -247,9 +247,9 @@ class TestMultiTensor(unittest.TestCase): shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))]) t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0) with Context(RING=0): - a = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device)) + a = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device)) with Context(RING=2): - b = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device)) + b = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device)) diff = 
a - b mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy() max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy() @@ -590,21 +590,21 @@ class TestMultiTensor(unittest.TestCase): t4 = t2.reshape((26, 105,)) for t in [t0, t1, t2, t3, t4]: - assert t.lazydata.axis == 1 + assert t.uop.axis == 1 np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten()) # test shape-one axis t5 = t4.reshape((26, 1, 105)) - assert t5.lazydata.axis == 2 + assert t5.uop.axis == 2 np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten()) # test split and rejoin to the right and reshape to the left t5 = t0.reshape((2, 13, 3, 5, 7)) t6 = t0.reshape((13, 2, 3, 7, 5)) t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5)) - assert t5.lazydata.axis == 2 - assert t6.lazydata.axis == 2 - assert t7.lazydata.axis == 3 + assert t5.uop.axis == 2 + assert t6.uop.axis == 2 + assert t7.uop.axis == 3 np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten()) np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten()) np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten()) @@ -616,7 +616,7 @@ class TestMultiTensor(unittest.TestCase): @unittest.skip("no longer supports uneven shard") def test_reshape_on_axis_uneven(self): def reshape_helper(t0, t, t_axis): - assert t.lazydata.axis == t_axis + assert t.uop.axis == t_axis np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy()) t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21]) @@ -687,24 +687,24 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.shape, t2.shape) self.assertEqual(t.device, t2.device) self.assertEqual(t.dtype, t2.dtype) - self.assertEqual(t.lazydata.axis, t2.lazydata.axis) + self.assertEqual(t.uop.axis, t2.uop.axis) def test_rand_like_from_alu(self): a = Tensor.ones(4, 4).shard(devices_4, axis=0) aa = a + a self.assertEqual(aa.device, devices_4) - self.assertEqual(aa.lazydata.axis, 0) + self.assertEqual(aa.uop.axis, 0) raa = aa.rand_like() self.assertEqual(raa.device, devices_4) - self.assertEqual(raa.lazydata.axis, 0) + self.assertEqual(raa.uop.axis, 0) b = Tensor.empty(4, 4).shard(devices_4, axis=None) ab = a + b self.assertEqual(ab.device, devices_4) - self.assertEqual(ab.lazydata.axis, 0) + self.assertEqual(ab.uop.axis, 0) rab = ab.rand_like() self.assertEqual(rab.device, devices_4) - self.assertEqual(rab.lazydata.axis, 0) + self.assertEqual(rab.uop.axis, 0) @unittest.skip("no longer supports uneven shard") def test_rand_like_uneven_shard(self): @@ -713,8 +713,8 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.shape, t2.shape) self.assertEqual(t.device, t2.device) self.assertEqual(t.dtype, t2.dtype) - self.assertEqual(t.lazydata.axis, t2.lazydata.axis) - assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src)) + self.assertEqual(t.uop.axis, t2.uop.axis) + assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.uop.src, t2.uop.src)) def test_rand_like_none_shard(self): t = Tensor.empty((16, 16)).shard(devices_2) @@ -722,7 +722,7 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.shape, t2.shape) self.assertEqual(t.device, t2.device) self.assertEqual(t.dtype, t2.dtype) - self.assertEqual(t.lazydata.axis, t2.lazydata.axis) + self.assertEqual(t.uop.axis, t2.uop.axis) def test_rand_like_arg_dtype(self): t = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1) @@ -771,7 +771,7 @@ class TestMultiTensor(unittest.TestCase): devices = (d0, d1, d2, d3) t = Tensor.zeros(16, 
16).contiguous() t.shard_(devices, axis=0).realize() - assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src]) + assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src]) @unittest.skip("this is unreliable on OSX") def test_clone(self): @@ -928,7 +928,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): to_add.append((Tensor.ones(2, 8) * i).shard(devices)) added:list[Tensor] = [] - for bound, a in zip(x.lazydata.bounds, to_add): + for bound, a in zip(x.uop.bounds, to_add): added.append(x[bound[0]:bound[1]] + a) output = added[0].cat(*added[1:]) @@ -1043,7 +1043,7 @@ class TestBatchNorm(unittest.TestCase): bns.append(bn) bn_ts = [] - for bound, bn in zip(x.lazydata.bounds, bns): + for bound, bn in zip(x.uop.bounds, bns): bni = bn(x[bound[0]:bound[1]]) bn_ts.append(bni) diff --git a/test/test_nn.py b/test/test_nn.py index f99b2d2cdb..000e459164 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -596,9 +596,9 @@ class TestNN(unittest.TestCase): # sharded model shards the state_dict self.assertEqual(layer.weight.device, devices) - self.assertEqual(layer.weight.lazydata.axis, 3) + self.assertEqual(layer.weight.uop.axis, 3) self.assertEqual(layer.bias.device, devices) - self.assertEqual(layer.bias.lazydata.axis, None) + self.assertEqual(layer.bias.uop.axis, None) np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy()) np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy()) @@ -634,9 +634,9 @@ class TestNN(unittest.TestCase): load_state_dict(layer, state_dict) self.assertEqual(layer.weight.device, devices) - self.assertEqual(layer.weight.lazydata.axis, 3) + self.assertEqual(layer.weight.uop.axis, 3) self.assertEqual(layer.bias.device, devices) - self.assertEqual(layer.bias.lazydata.axis, None) + self.assertEqual(layer.bias.uop.axis, None) np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy()) np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy()) @@ -658,9 +658,9 @@ class TestNN(unittest.TestCase): # NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this? 
self.assertEqual(layer.weight.device, devices) - self.assertEqual(layer.weight.lazydata.axis, None) + self.assertEqual(layer.weight.uop.axis, None) self.assertEqual(layer.bias.device, devices5) - self.assertEqual(layer.bias.lazydata.axis, 0) + self.assertEqual(layer.bias.uop.axis, 0) np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy()) np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy()) diff --git a/test/test_pickle.py b/test/test_pickle.py index e9d88e1817..93758eef68 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -53,7 +53,7 @@ class TestPickle(unittest.TestCase): def test_pickle_realized_tensor_alt2(self): print("** init") t = Tensor.rand(10, 10).to("CPU").realize() - tensor_uop = t.lazydata + tensor_uop = t.uop assert tensor_uop.is_realized, f"expected {tensor_uop} to be realized" t_values = t.numpy() # pickle @@ -63,13 +63,13 @@ class TestPickle(unittest.TestCase): del tensor_uop print("** post pickle") t2:Tensor = pickle.loads(st) - assert t2.lazydata.is_realized, f"expected {t2.lazydata} to be realized" + assert t2.uop.is_realized, f"expected {t2.uop} to be realized" np.testing.assert_equal(t_values, t2.numpy()) # NOTE: currently Buffer exists on the uop, not tensor def test_pickle_buffer_uop(self): t = Tensor.arange(4).realize() - a = t.lazydata + a = t.uop assert a.op is Ops.BUFFER self.assertIsNotNone(buffer:=a.realized) s = pickle.dumps(a) @@ -98,12 +98,12 @@ class TestPickle(unittest.TestCase): def test_pickle_buffer_view(self): t = Tensor.arange(10, device="CPU").contiguous().realize() vt = t[3:5].contiguous().realize() - assert hasattr(vt.lazydata.buffer, 'base') + assert hasattr(vt.uop.buffer, 'base') ref_value = vt.tolist() st = pickle.dumps(vt) del t, vt vt2 = pickle.loads(st) - assert hasattr(vt2.lazydata.buffer, 'base') + assert hasattr(vt2.uop.buffer, 'base') assert ref_value == vt2.tolist() def test_pickle_numpy(self): diff --git a/test/test_profiler.py b/test/test_profiler.py index 989f79f423..c64c271167 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -36,12 +36,12 @@ class TestProfiler(unittest.TestCase): si = self.b.schedule()[-1] TestProfiler.runner = get_runner(TestProfiler.d0.device, si.ast) - TestProfiler.b.lazydata.buffer.allocate() + TestProfiler.b.uop.buffer.allocate() def test_profile_kernel_run(self): runner_name = TestProfiler.runner._prg.name with helper_collect_profile(TestProfiler.d0) as profile: - TestProfiler.runner([TestProfiler.b.lazydata.buffer, TestProfiler.a.lazydata.buffer], var_vals={}) + TestProfiler.runner([TestProfiler.b.uop.buffer, TestProfiler.a.uop.buffer], var_vals={}) profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)] @@ -66,7 +66,7 @@ class TestProfiler(unittest.TestCase): with helper_collect_profile(TestProfiler.d0) as profile: buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) - TestProfiler.runner([buf1, TestProfiler.a.lazydata.buffer], var_vals={}) + TestProfiler.runner([buf1, TestProfiler.a.uop.buffer], var_vals={}) buf1.copyout(memoryview(bytearray(buf1.nbytes))) profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index d04aecd6d4..3dd795ead6 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -23,7 +23,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None): uops = 
dedup(flatten(_recursive_add(st) for st in stores)) outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \ initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE] - inbufs = [cast(UOp,x.lazydata).base.buffer for x in inputs] + inbufs = [cast(UOp,x.uop).base.buffer for x in inputs] src = Device[Device.DEFAULT].renderer.render(uops) ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size)) ei.exec(outbufs+inbufs) diff --git a/test/test_rewrite_tracked_childen.py b/test/test_rewrite_tracked_childen.py index a1dea7955d..688e2efa4c 100644 --- a/test/test_rewrite_tracked_childen.py +++ b/test/test_rewrite_tracked_childen.py @@ -25,7 +25,7 @@ class TestRewriteTrackedChildren(unittest.TestCase): a = Tensor(2) b = Tensor(3) c = a + b - sink = c.lazydata.sink() + sink = c.uop.sink() sink = graph_rewrite(sink, rewrite, track_children=True) def test_simple_child(self): @@ -35,8 +35,8 @@ class TestRewriteTrackedChildren(unittest.TestCase): a = Tensor(2) b = Tensor(3) c = a + b - sink = c.lazydata - view_w_child = a.lazydata.src[0] + sink = c.uop + view_w_child = a.uop.src[0] print([x().arg for x in view_w_child.children]) print([x.arg for x in sink.get_children_map()[view_w_child]]) self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((2,3))) @@ -57,7 +57,7 @@ class TestRewriteTrackedChildren(unittest.TestCase): extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)]) a = Tensor.empty(3, 3) r = (a+0).sum() - graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True) + graph_rewrite(r.uop, merge_views+sym+extra, track_children=True) if __name__ == '__main__': unittest.main() diff --git a/test/test_schedule.py b/test/test_schedule.py index 557ddc45dd..92d5728060 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -188,7 +188,7 @@ class TestSchedule(unittest.TestCase): # pre-realize shared seed Tensor._device_rng_counters[x.device].realize() # run custom kernelized kernel - sched_sink = graph_rewrite(x.lazydata, create_kernels, ctx={u:None for u in x.lazydata.toposort() if u.op is Ops.COPY}, bottom_up=True) + sched_sink = graph_rewrite(x.uop, create_kernels, ctx={u:None for u in x.uop.toposort() if u.op is Ops.COPY}, bottom_up=True) y = Tensor(graph_rewrite(sched_sink, create_ast, bottom_up=True)) run_schedule(check_schedule(y, 1)) # compare against reference @@ -198,15 +198,15 @@ class TestSchedule(unittest.TestCase): def test_empty_is_not_realized(self): a = Tensor.empty(10) child = a+2 - assert not a.lazydata.is_realized + assert not a.uop.is_realized child.realize() - assert a.lazydata.is_realized + assert a.uop.is_realized # NOTE: because empty does not have an ExecItem, if realize is called on a childless empty, it never gets allocated.
def test_childless_empty_never_allocates(self): a = Tensor.empty(10) a.realize() - assert not a.lazydata.is_realized + assert not a.uop.is_realized def test_simplify_padded_const(self): a = Tensor.empty(1022).cummax(axis=0) @@ -404,7 +404,7 @@ class TestSchedule(unittest.TestCase): sched = check_schedule([a, b], 1) run_schedule(sched) # a and b share the same underlying device memory - self.assertIs(a.lazydata.realized, b.lazydata.realized) + self.assertIs(a.uop.realized, b.uop.realized) def test_clone_doesnt_dedup(self): src = Tensor.ones(4).contiguous().realize() @@ -413,7 +413,7 @@ class TestSchedule(unittest.TestCase): sched = check_schedule([a, b], 2, filter_sink=False) run_schedule(sched) # a and b are assigned to the same device Buffer - self.assertIsNot(a.lazydata.realized, b.lazydata.realized) + self.assertIsNot(a.uop.realized, b.uop.realized) # EMPTY is assigned to a unique device Buffer @@ -422,7 +422,7 @@ class TestSchedule(unittest.TestCase): b = Tensor.empty((4,)) # NOTE: empty does not have any schedule check_schedule([a, b], 0, filter_sink=False) - self.assertIsNot(a.lazydata.buffer, b.lazydata.buffer) + self.assertIsNot(a.uop.buffer, b.uop.buffer) def test_dedup_outputs(self): a = Tensor.full((4, 4), 1.).contiguous().realize() @@ -712,13 +712,13 @@ class TestSchedule(unittest.TestCase): prev_a = (a+1).contiguous() a.assign(Tensor([2])) a.kernelize(prev_a) - assert prev_a.lazydata in a.lazydata.src, "contiguous usage must run before assign" + assert prev_a.uop in a.uop.src, "contiguous usage must run before assign" self.assertEqual((prev_a+a*3).item(), 1+2*3) def test_multioutput_ast(self): - a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata - b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata - c = Tensor.arange(4).realize().lazydata + a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop + b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop + c = Tensor.arange(4).realize().uop kernel = UOp(Ops.KERNEL, src=(a, b, c.base), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2))) assert all(s.op is Ops.BUFFER for s in kernel.src), f"views are not allowed here {kernel}" kernel = graph_rewrite(kernel, create_ast) @@ -1623,7 +1623,7 @@ class TestSchedule(unittest.TestCase): @unittest.skip("disabling subbuffer manually isn't supported anymore") def test_bitcast_disable_subbufer(self): - x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata) + x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().uop) a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False) b = x.cast(dtypes.int32, True, allow_buffer_view=False) b = a.alu(Ops.ADD, b) @@ -1660,11 +1660,11 @@ class TestSchedule(unittest.TestCase): self.assertEqual(GlobalCounters.mem_used-base, 1024) def test_const_schedule(self): - constv = Tensor.empty(2, 2).lazydata.const_like(10) + constv = Tensor.empty(2, 2).uop.const_like(10) check_schedule(constv, 0) def test_const_schedule_contig(self): - constv = Tensor.empty(2, 2).lazydata.const_like(10).contiguous() + constv = Tensor.empty(2, 2).uop.const_like(10).contiguous() check_schedule(constv, 1) @unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU") @@ -1676,8 +1676,8 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 3)) np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4) self.assertIsInstance(out.dtype, ImageDType) - self.assertIsNotNone(out.lazydata.base.realized) - 
self.assertIsInstance(out.lazydata.base.realized.dtype, ImageDType) + self.assertIsNotNone(out.uop.base.realized) + self.assertIsInstance(out.uop.base.realized.dtype, ImageDType) def _test_fusion(self, shapes, f, cnt): with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes] @@ -1707,9 +1707,9 @@ class TestSchedule(unittest.TestCase): a = Tensor.arange(4).reshape(1, 4) casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float) casted_view.realize() - self.assertEqual(casted_view.lazydata.base.realized.size, 4) + self.assertEqual(casted_view.uop.base.realized.size, 4) realized_view = casted_view.contiguous().realize() - self.assertEqual(realized_view.lazydata.base.realized.size, 8) + self.assertEqual(realized_view.uop.base.realized.size, 8) self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]]) # NOTE: we only reorder CAST if it's an EXPAND @@ -1717,16 +1717,16 @@ class TestSchedule(unittest.TestCase): a = Tensor.arange(4).reshape(1, 4) casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float) casted_view.realize() - self.assertEqual(casted_view.lazydata.base.realized.size, 2) + self.assertEqual(casted_view.uop.base.realized.size, 2) realized_view = casted_view.contiguous().realize() - self.assertEqual(realized_view.lazydata.base.realized.size, 2) + self.assertEqual(realized_view.uop.base.realized.size, 2) self.assertListEqual(realized_view.tolist(), [[0, 1]]) def test_cast_const_view(self): a = Tensor.ones((4, 4), dtype=dtypes.float32) casted_view = a.cast(dtypes.int32) run_schedule(check_schedule(casted_view, 0)) - self.assertIsNone(casted_view.lazydata.base.realized) + self.assertIsNone(casted_view.uop.base.realized) realized_const_view = casted_view.contiguous() run_schedule(check_schedule(realized_const_view, 1)) self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]) @@ -1998,7 +1998,7 @@ class TestIndexing(unittest.TestCase): def test_recursive_swizzle(self): a = Tensor([1,2,3,4]).realize() for _ in range(24): a = a + a - new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1))) + new_uop = swizzle_rewrite(a.uop.reshape((4, 1))) self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1))) self.assertEqual(swizzle_cnt(new_uop), 0) @@ -2012,13 +2012,13 @@ class TestIndexing(unittest.TestCase): a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32) ast = a.schedule()[0].ast self.assertEqual(ast.shape, (32, 1)) - self.assertEqual(a.lazydata.shape, (1, 32)) + self.assertEqual(a.uop.shape, (1, 32)) def test_no_reshape_reduceop(self): a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous() ast = a.schedule()[0].ast self.assertEqual(ast.shape, (32, 1)) - self.assertEqual(a.lazydata.shape, (32,)) + self.assertEqual(a.uop.shape, (32,)) @track_rewrites(named=True) def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right) @@ -2161,7 +2161,7 @@ class TestView(unittest.TestCase): assert b.shape == (10, 10) sched = check_schedule(b.contiguous(), 1) self.assertEqual(store_val(sched[-1]).op, Ops.LOAD) - self.assertEqual(store_val(sched[-1]).st_arg, b.lazydata.st) + self.assertEqual(store_val(sched[-1]).st_arg, b.uop.st) run_schedule(sched) np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:]) @@ -2175,9 +2175,9 @@ class TestView(unittest.TestCase): late_mul = a*bv check_schedule(late_mul, 0) # the arange doesn't realize - self.assertIsNone(b.lazydata.base.realized) + 
self.assertIsNone(b.uop.base.realized) # mul doesn't realize - self.assertIsNone(late_mul.lazydata.base.realized) + self.assertIsNone(late_mul.uop.base.realized) self.assertEqual(late_mul.tolist(), [0, 0]) # SINK has two branches: @@ -2192,13 +2192,13 @@ class TestView(unittest.TestCase): other_child = b+2 s = check_schedule([late_mul, other_child], 2) # the arange becomes a BUFFER - self.assertIs(b.lazydata.base.op, Ops.BUFFER) + self.assertIs(b.uop.base.op, Ops.BUFFER) # mul still collapses - self.assertIs(late_mul.lazydata.base.op, Ops.CONST) + self.assertIs(late_mul.uop.base.op, Ops.CONST) run_schedule(s) self.assertEqual(other_child.tolist(), [2, 3, 4]) -def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, merge_views+symbolic_simple) +def tensor_rewrite(t) -> UOp: return graph_rewrite(t.uop.base, merge_views+symbolic_simple) class TestSimplifier(unittest.TestCase): def test_sink_childless_const(self): x = Tensor(0) @@ -2222,8 +2222,8 @@ class TestSimplifier(unittest.TestCase): a = Tensor.empty(4, 4, dtype=dtypes.int) sink = tensor_rewrite(a*0) assert UPat(Ops.CONST, arg=0).match(sink, {}) - self.assertIs(tensor_rewrite(a*1).base, a.lazydata.base) - self.assertIs(tensor_rewrite(a+0).base, a.lazydata.base) + self.assertIs(tensor_rewrite(a*1).base, a.uop.base) + self.assertIs(tensor_rewrite(a+0).base, a.uop.base) def test_cast_folding(self): a = Tensor(1.0).cast(dtypes.int) @@ -2257,14 +2257,14 @@ class TestConst(unittest.TestCase): def test_tensor_const(self): a = Tensor(1) - print(a.lazydata) - self.assertTrue(tensor_const_pm.rewrite(a.lazydata)) + print(a.uop) + self.assertTrue(tensor_const_pm.rewrite(a.uop)) def test_tensor_variable(self): vv = UOp.variable("a", 0, 10).bind(1) a = Tensor(vv) - print(a.lazydata) - self.assertTrue(tensor_const_pm.rewrite(a.lazydata)) + print(a.uop) + self.assertTrue(tensor_const_pm.rewrite(a.uop)) def test_const_schedule(self): a = Tensor.ones((4, 4)) @@ -2311,7 +2311,7 @@ class TestConst(unittest.TestCase): sched = add.schedule() self.assertEqual(len(sched), 0) # b+0 and b share the same underlying device memory - self.assertIs(add.lazydata.buffer, b.lazydata.buffer) + self.assertIs(add.uop.buffer, b.uop.buffer) self.assertListEqual(add.tolist(), [2, 2, 2, 2]) def test_src_masked_const_folding(self): @@ -2324,7 +2324,7 @@ class TestConst(unittest.TestCase): self.assertEqual(len(sched), 1) run_schedule(sched) # add gets assigned to a new buffer - self.assertIsNot(add.lazydata.base.realized, b.lazydata.base.realized) + self.assertIsNot(add.uop.base.realized, b.uop.base.realized) self.assertListEqual(add.tolist(), [4, 2, 2, 2, 2, 4]) # ** part 3: Tensor variable bindings @@ -2359,15 +2359,15 @@ class TestCopyFolding(unittest.TestCase): self.assertListEqual(b.tolist(), [0, 0, 0]) def test_alu_after_copy(self): - a = Tensor.ones((4,)).to("CPU").lazydata - b = Tensor.empty(4, device="CPU").lazydata + a = Tensor.ones((4,)).to("CPU").uop + b = Tensor.empty(4, device="CPU").uop add = a+b add = schedule_graph_rewrite(add) assert all_same([x.device for x in add.src]), f"ALU has different devices! 
{[x.device for x in add.src]}" @unittest.skip("this is just clone now") def test_copy_to_same_device(self): - a = Tensor.empty(4).lazydata + a = Tensor.empty(4).uop b = a.copy_to_device(a.device) check_schedule(b, 0, filter_sink=False) b = schedule_graph_rewrite(b) @@ -2377,7 +2377,7 @@ class TestCopyFolding(unittest.TestCase): @unittest.skip("this is just clone now") def test_copy_to_same_device_alt(self): - a = Tensor.empty(4, 4).lazydata + a = Tensor.empty(4, 4).uop b = a.copy_to_device(a.device) check_schedule(b, 0, filter_sink=False) b = schedule_graph_rewrite(b) @@ -2394,8 +2394,8 @@ class TestCopyFolding(unittest.TestCase): b = view.clone() # NOTE: this was sort of a bug making this 2 run_schedule(check_schedule(b, 2, filter_sink=False)) - self.assertEqual(b.lazydata.base.buffer.size, 2) - self.assertEqual(b.lazydata.size, 2) + self.assertEqual(b.uop.base.buffer.size, 2) + self.assertEqual(b.uop.size, 2) self.assertListEqual(b.tolist(), [0, 1]) def test_expanded_copy(self): @@ -2403,8 +2403,8 @@ class TestCopyFolding(unittest.TestCase): view = a.reshape(2, 1).expand(2, 2) b = view.clone() run_schedule(check_schedule(b, 2, filter_sink=False)) - self.assertEqual(b.lazydata.base.buffer.size, 4) - self.assertEqual(b.lazydata.size, 4) + self.assertEqual(b.uop.base.buffer.size, 4) + self.assertEqual(b.uop.size, 4) self.assertListEqual(b.tolist(), [[0, 0], [1, 1]]) def test_permuted_copy(self): @@ -2414,7 +2414,7 @@ class TestCopyFolding(unittest.TestCase): self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) def test_permute_on_disk(self): - with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer()) + with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer()) a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}") b = a.reshape(2, 2).permute(1, 0).to("CPU") b.realize() @@ -2430,7 +2430,7 @@ class TestCopyFolding(unittest.TestCase): # TODO: this is wrong because of the permute @unittest.expectedFailure def test_permute_after_shrink_on_disk(self): - with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer()) + with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer()) a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}") b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU") b.realize() @@ -2442,13 +2442,13 @@ class TestTensorUOpSpec(unittest.TestCase): unsafe_push_views = PatternMatcher([ (UPat.cvar("root").view(name="view"), lambda root,view: root.replace(src=tuple(x.view(view.st) for x in root.src))), ]) - a.lazydata = graph_rewrite(a.lazydata.sink(), merge_views+merge_views+unsafe_push_views) + a.uop = graph_rewrite(a.uop.sink(), merge_views+merge_views+unsafe_push_views) with self.assertRaisesRegex(RuntimeError, "UOp verification failed"): a.schedule() def test_expanded_const_ok(self): a = Tensor.ones((4, 4)) - t = graph_rewrite(a.lazydata.sink(), merge_views+merge_views) + t = graph_rewrite(a.uop.sink(), merge_views+merge_views) create_schedule_with_vars(t) # NOTE: changing symbolic CONST VIEWs is not allowed @@ -2456,69 +2456,69 @@ class TestTensorUOpSpec(unittest.TestCase): def test_symbolic_shape_ok(self): a = Tensor.ones(4) vi = UOp.variable("i", 1, 10).bind(4) - a.lazydata = graph_rewrite(a.reshape(vi).sum().lazydata, merge_views+merge_views) + a.uop = 
graph_rewrite(a.reshape(vi).sum().uop, merge_views+merge_views) a.schedule() class TestBufferUOp(unittest.TestCase): # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,) def test_buffer_has_buffer(self): buf = Tensor.empty(10) - self.assertIsNotNone(buf.lazydata.buffer) - self.assertEqual(buf.lazydata.st, ShapeTracker.from_shape((10,))) + self.assertIsNotNone(buf.uop.buffer) + self.assertEqual(buf.uop.st, ShapeTracker.from_shape((10,))) # the device Buffer remains unallocated until we run the schedule - self.assertFalse(buf.lazydata.buffer.is_allocated()) + self.assertFalse(buf.uop.buffer.is_allocated()) add = buf+1 sched = add.schedule() - self.assertFalse(buf.lazydata.buffer.is_allocated()) + self.assertFalse(buf.uop.buffer.is_allocated()) run_schedule(sched) - self.assertTrue(buf.lazydata.buffer.is_allocated()) + self.assertTrue(buf.uop.buffer.is_allocated()) def test_buffer_has_unique_buffer(self): buf = Tensor.empty(10) - buf1 = buf.lazydata.buffer - buf2 = buf.lazydata.buffer + buf1 = buf.uop.buffer + buf2 = buf.uop.buffer self.assertIs(buf1, buf2) # we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous def test_buffer_view_allowed(self): add = Tensor.empty(1, 1)+Tensor.empty(1, 1) add.realize() - self.assertIsNotNone(add.lazydata.buffer) - self.assertEqual(add.lazydata.shape, (1, 1)) + self.assertIsNotNone(add.uop.buffer) + self.assertEqual(add.uop.shape, (1, 1)) def test_buffer_view_not_allowed(self): permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1) - merged = graph_rewrite(permuted_view.lazydata, merge_views) + merged = graph_rewrite(permuted_view.uop, merge_views) with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"): merged.buffer # cannot access Buffer of a non-contiguous VIEW def test_buffer_only_after_realize(self): a = Tensor([1])+Tensor([2]) # accessing realized will return None - self.assertIsNone(a.lazydata.realized) + self.assertIsNone(a.uop.realized) # accessing Buffer will assert with self.assertRaisesRegex(AssertionError, "must be BUFFER"): - a.lazydata.buffer # there is no BUFFER on an unrealized ADD + a.uop.buffer # there is no BUFFER on an unrealized ADD # Buffer only exists once we realize it a.realize() - self.assertIsNotNone(a.lazydata.buffer) + self.assertIsNotNone(a.uop.buffer) def test_const_does_not_realize(self): a = Tensor(1)+Tensor(2) run_schedule(check_schedule(a, 0)) - self.assertIsNone(a.lazydata.base.realized) + self.assertIsNone(a.uop.base.realized) def test_var_does_not_realize(self): a = Tensor(UOp.variable("a", 0, 10).bind(1)) run_schedule(check_schedule(a, 0)) - self.assertIsNone(a.lazydata.base.realized) + self.assertIsNone(a.uop.base.realized) def test_view_does_not_realize(self): a = Tensor.randn(1, 4).expand(4, 4) a.realize() - self.assertEqual(a.lazydata.base.realized.size, 4) + self.assertEqual(a.uop.base.realized.size, 4) a2 = a.contiguous().realize() - self.assertEqual(a2.lazydata.base.realized.size, 16) + self.assertEqual(a2.uop.base.realized.size, 16) class TestContiguous(unittest.TestCase): def test_contiguous_buffer(self): @@ -2550,13 +2550,13 @@ class TestContiguous(unittest.TestCase): a = Tensor.empty(4) b = a.expand((4, 4)) check_schedule(b, 0) - self.assertEqual(b.lazydata.base.buffer.size, 4) + self.assertEqual(b.uop.base.buffer.size, 4) def test_contiguous_view_realizes(self): a = Tensor.empty(4) b = a.expand((4, 4)).contiguous() check_schedule(b, 1) - self.assertEqual(b.lazydata.base.buffer.size, 16) +
self.assertEqual(b.uop.base.buffer.size, 16) class TestUOpBecome(unittest.TestCase): # the simplest case, if we create a new BUFFER for this tensor UOp @@ -2566,21 +2566,21 @@ class TestUOpBecome(unittest.TestCase): def test_new_buffer(self): a = Tensor.empty(4, 4) b = Tensor.empty(4, 4) add = a+b check_schedule(add, 1) # NOTE: realized base is always a flat buffer - assert UPat(Ops.BUFFER).match(add.lazydata.base, {}) + assert UPat(Ops.BUFFER).match(add.uop.base, {}) # the Tensor UOp can optionally stack a VIEW on top of the BUFFER, in this case to preserve the (4, 4) shape of the tensor - assert add.lazydata is not add.lazydata.base - self.assertEqual(add.lazydata.size, 16) - self.assertEqual(add.lazydata.shape, (4, 4)) + assert add.uop is not add.uop.base + self.assertEqual(add.uop.size, 16) + self.assertEqual(add.uop.shape, (4, 4)) def test_new_buffer_view(self): a = Tensor.empty(4, 4) b = Tensor.empty(4, 4) add = (a+b).reshape(8, 2) check_schedule(add, 1) - assert UPat(Ops.BUFFER).match(add.lazydata.base, {}) + assert UPat(Ops.BUFFER).match(add.uop.base, {}) # the shape is preserved in the becomes_map. - self.assertEqual(add.lazydata.shape, (8, 2)) - assert add.lazydata is not add.lazydata.base + self.assertEqual(add.uop.shape, (8, 2)) + assert add.uop is not add.uop.base def test_new_flat_buffer(self): a = Tensor.empty(4,) b = Tensor.empty(4,) add = a+b check_schedule(add, 1) # BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER - assert UPat(Ops.BUFFER).match(add.lazydata, {}) + assert UPat(Ops.BUFFER).match(add.uop, {}) # sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer @@ -2597,8 +2597,8 @@ class TestUOpBecome(unittest.TestCase): a = Tensor.empty(4, 1) b = a.expand(4, 4).reciprocal() check_schedule(b, 1) - self.assertEqual(b.lazydata.base.buffer.size, 16) - self.assertEqual(b.lazydata.st, ShapeTracker.from_shape((4, 4))) + self.assertEqual(b.uop.base.buffer.size, 16) + self.assertEqual(b.uop.st, ShapeTracker.from_shape((4, 4))) def test_reorder_expand_alt(self): x = Tensor.empty(4, 1) @@ -2610,95 +2610,95 @@ class TestUOpBecome(unittest.TestCase): def test_become_existing_buffer(self): a = Tensor.empty(4, 4) b = a*1 - assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul check_schedule(b, 0) - assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling merges all MovementOps into a single VIEW - self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.uop, {}) # scheduling merges all MovementOps into a single VIEW + self.assertIs(a.uop.base.buffer, b.uop.base.buffer) def test_become_buf_with_mops(self): a = Tensor.empty(2, 4, 2) noop = a.shrink(((1, 2), (0, 4), (0, 2))).reshape(4, 2)*1+0 # before realizing, this tensor is base - assert noop.lazydata is noop.lazydata.base + assert noop.uop is noop.uop.base noop.realize() # it becomes a realized view after realize - assert noop.lazydata is not noop.lazydata.base - assert noop.lazydata.base.op is Ops.BUFFER + assert noop.uop is not noop.uop.base + assert noop.uop.base.op is Ops.BUFFER late_add = noop+2 late_add.realize() def test_become_const_in_base(self): a = Tensor.empty(4) b = a*0 - assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul check_schedule(b, 0) - assert UPat(Ops.CONST,
 
   def test_become_const_in_view(self):
     # if we shrink the base down to size 0, only the VIEW becomes CONST, base is unchanged.
     add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
     b = add.shrink(((0, 1), (0, 0)))
     check_schedule(b, 0)
-    assert UPat(Ops.CONST, arg=0).match(b.lazydata, {})
+    assert UPat(Ops.CONST, arg=0).match(b.uop, {})
     self.assertEqual(b.shape, (1, 0))
     # the base is untouched.
-    assert UPat(Ops.ADD).match(add.lazydata, {})
+    assert UPat(Ops.ADD).match(add.uop, {})
 
   def test_become_const_from_const(self):
     const_add = Tensor(1)+Tensor(2)
-    assert UPat(Ops.ADD).match(const_add.lazydata, {})
+    assert UPat(Ops.ADD).match(const_add.uop, {})
     check_schedule(const_add, 0)
-    assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
+    assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {})
 
   # tensors can become another realized tensor source
   def test_become_existing_buf_simple(self):
     a = Tensor.empty(4, 4)
     b = a+0
     check_schedule(b, 0)
-    assert b.lazydata.base.op is Ops.BUFFER
-    self.assertIs(a.lazydata, b.lazydata)
+    assert b.uop.base.op is Ops.BUFFER
+    self.assertIs(a.uop, b.uop)
 
   # they can also chain other movement ops on top of the tensor source
   def test_become_existing_buf_view(self):
     a = Tensor.empty(4, 4)
     b = a.permute((1, 0))+0
     check_schedule(b, 0)
-    self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).st)
+    self.assertEqual(b.uop.st, a.uop.permute((1, 0)).st)
 
   def test_become_existing_buf_view_alt(self):
     a = Tensor.empty(4, 4)
     b = a.permute((1, 0)).reshape((8, 2))+0
     check_schedule(b, 0)
-    self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)
+    self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
 
   # they can also have other base parents that simplified away; in that case we just backtrack to the chained mops
   def test_become_existing_buf_complex(self):
     a = Tensor.empty(4, 4)
     b = (a.permute((1, 0))+0).reshape((8, 2))+0
     check_schedule(b, 0)
-    self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)
-    assert b.lazydata.base.op is Ops.BUFFER
+    self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
+    assert b.uop.base.op is Ops.BUFFER
 
   def test_become_multiple_choices(self):
     a = Tensor.empty(16)
     b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
     c = (a.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
     check_schedule([b, c], 0)
-    assert all_same([x.lazydata.base.realized for x in [a,b,c]])
+    assert all_same([x.uop.base.realized for x in [a,b,c]])
     # these movement ops result in the same ShapeTracker
-    assert b.lazydata.st == c.lazydata.st
-    assert b.lazydata is c.lazydata
-    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})
+    assert b.uop.st == c.uop.st
+    assert b.uop is c.uop
+    assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.uop, {})
 
   def test_setitem_becomes_subbuffer(self):
     a = Tensor.full((4,), 2.).contiguous().realize()
     b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
     b.realize()
-    assert a.lazydata.is_realized
-    assert a.lazydata.buffer._base is None
+    assert a.uop.is_realized
+    assert a.uop.buffer._base is None
     # b is a subbuffer of a
-    assert b.lazydata.op is Ops.BUFFER_VIEW
-    assert b.lazydata.src[0] is a.lazydata
+    assert b.uop.op is Ops.BUFFER_VIEW
+    assert b.uop.src[0] is a.uop
 
   def test_setitem_offset(self):
a = Tensor.full((16,), 0.).contiguous().realize() diff --git a/test/test_search.py b/test/test_search.py index 573290860e..cf430ef092 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -44,7 +44,7 @@ class TestTimeLinearizer(unittest.TestCase): # Ensure that the kernel count is not incremented by time_linearizer when clearing l2 def test_kernel_count(self): - ast = Tensor.zeros(16).contiguous().kernelize().lazydata.src[1].arg.ast + ast = Tensor.zeros(16).contiguous().kernelize().uop.src[1].arg.ast lin = Kernel(ast) bufs = bufs_from_lin(lin) diff --git a/test/test_setitem.py b/test/test_setitem.py index d3f066bad4..d6de6e93fb 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -50,7 +50,7 @@ class TestSetitem(unittest.TestCase): def test_setitem_into_noncontiguous(self): t = Tensor.ones(4) - self.assertFalse(t.lazydata.st.contiguous) + self.assertFalse(t.uop.st.contiguous) with self.assertRaises(RuntimeError): t[1] = 5 @unittest.skip("TODO: flaky") diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index 9c754e3746..4cd542feb4 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -49,11 +49,11 @@ class TestSymbolic(unittest.TestCase): j = Variable("j", 1, 5).bind(3) k = Variable("k", 1, 5).bind(3) t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) - st = t.lazydata.st + st = t.uop.st self.assert_tuple_equal(st.shape, (i+j+k, 4)) assert st.real_strides() == (4, 1) t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0) - st = t.lazydata.st + st = t.uop.st self.assert_tuple_equal(st.shape, (2*i+3, 3)) assert st.real_strides() == (3, 1) @@ -62,7 +62,7 @@ class TestSymbolic(unittest.TestCase): j = Variable("j", 1, 5).bind(4) k = Variable("k", 1, 5).bind(4) t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) - st = t.lazydata.st + st = t.uop.st self.assert_tuple_equal(st.shape, (3, i+j+k)) self.assert_tuple_equal(st.real_strides(), (i+j+k, 1)) @@ -113,7 +113,7 @@ class TestShapeTrackerUnbind(unittest.TestCase): v = Variable("v", 1, 100) bv = Variable("v", 1, 100).bind(3) t = Tensor.rand(3, 4).reshape(bv, 4) - unbound_st, var_val = t.lazydata.st.unbind() + unbound_st, var_val = t.uop.st.unbind() assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),)) assert var_val == {v: 3} @@ -121,7 +121,7 @@ class TestShapeTrackerUnbind(unittest.TestCase): v = Variable("v", 1, 100) bv = Variable("v", 1, 100).bind(2) t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4))) - unbound_st, var_val = t.lazydata.st.unbind() + unbound_st, var_val = t.uop.st.unbind() assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),)) assert var_val == {v: 2} @@ -180,8 +180,8 @@ class TestSymbolicReshapeFromNonContiguous(unittest.TestCase): vi = Variable("i", 1, 5).bind(4) t = Tensor.ones(3, 4).reshape(3, vi) assert t.shape == (3, vi) - assert not t.lazydata.st.contiguous - assert len(t.lazydata.st.views) == 1 + assert not t.uop.st.contiguous + assert len(t.uop.st.views) == 1 def test_reshape_not_allowed(self): vi = Variable("i", 1, 5).bind(4) @@ -195,12 +195,12 @@ class TestSymbolicReshapeFromNonContiguous(unittest.TestCase): def test_reshape_from_padded(self): vi = Variable("i", 1, 5).bind(4) t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3))) - st = t.lazydata.st + 
st = t.uop.st assert len(st.views) == 1 view = st.views[0] assert view.shape == (4, 3, 2) t = t.reshape(vi, 3, 2) - st2 = t.lazydata.st + st2 = t.uop.st assert len(st2.views) == 1 view2 = st2.views[0] # check only shape changed. strides, offset, mask, contiguous remained the same @@ -237,7 +237,7 @@ class TestSymbolicPad(unittest.TestCase): v = Variable("v", 1, 100).bind(5) t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9) assert t.shape == (9,) - st = t.lazydata.st + st = t.uop.st print(st) if __name__ == '__main__': diff --git a/test/test_tensor.py b/test/test_tensor.py index 3e887ccf75..704f69740d 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -565,17 +565,17 @@ class TestZeroShapeTensor(unittest.TestCase): t = Tensor.empty(3, 2, 0) assert t.shape == (3, 2, 0) # numpy has stride 0, 0, 0; torch has stride 2, 1, 1 - assert t.lazydata.st.real_strides() == (0, 0, 0) + assert t.uop.st.real_strides() == (0, 0, 0) t = Tensor.empty(3, 0, 2) assert t.shape == (3, 0, 2) # numpy has stride 0, 0, 0; torch has stride 2, 2, 1 - assert t.lazydata.st.real_strides() == (0, 0, 0) + assert t.uop.st.real_strides() == (0, 0, 0) t = Tensor.empty(0, 0, 0) assert t.shape == (0, 0, 0) # numpy has stride 0, 0, 0; torch has stride 1, 1, 1 - assert t.lazydata.st.real_strides() == (0, 0, 0) + assert t.uop.st.real_strides() == (0, 0, 0) def test_rand(self): t = Tensor.rand(3, 2, 0) @@ -690,24 +690,24 @@ class TestZeroShapeTensor(unittest.TestCase): a = Tensor.rand(16, 16).realize() b = a.clone() np.testing.assert_allclose(a.numpy(), b.numpy()) - self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) + self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer) a = Tensor.rand(16, 16).mul(5.0).add(5.0).realize() b = a.clone() np.testing.assert_allclose(a.numpy(), b.numpy()) - self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) + self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer) def test_clone_with_shrink(self): a = Tensor.rand(16, 16) b = a.shrink(((2, 10), None)).clone() b.realize() - self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) + self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer) def test_clone_with_shrink_realized(self): a = Tensor.rand(16, 16).realize() b = a.shrink(((2, 10), None)).clone() b.realize() - self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) + self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer) def test_clone_with_grad(self): a = Tensor.rand(16, 16, requires_grad=True) @@ -780,7 +780,7 @@ class TestTensorMetadata(unittest.TestCase): @unittest.skip("why would this be true?") def test_exclude_noop_metadata(self): a = Tensor.rand(4, 4)*1 - self.assertEqual(a.lazydata.metadata[0].name, "__mul__") + self.assertEqual(a.uop.metadata[0].name, "__mul__") k = a.schedule()[-1] self.assertEqual([m.name for m in k.metadata], ["rand"]) @@ -797,7 +797,7 @@ class TestTensorMetadata(unittest.TestCase): x = Tensor.rand(3, requires_grad=True) W = Tensor.rand(3, 3, requires_grad=True) out = x.matmul(W) - self.assertEqual(out.lazydata.metadata[0].name, "matmul") + self.assertEqual(out.uop.metadata[0].name, "matmul") si = out.schedule()[-1] self.assertEqual(len(si.metadata), 1) self.assertEqual(si.metadata[0].name, "matmul") @@ -805,7 +805,7 @@ class TestTensorMetadata(unittest.TestCase): def test_relu(self): x = Tensor.rand(3, requires_grad=True) out = x.relu() - self.assertEqual(out.lazydata.metadata[0].name, "relu") + self.assertEqual(out.uop.metadata[0].name, "relu") si = out.schedule()[-1] self.assertEqual(len(si.metadata), 1) 
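The metadata asserted in this test is attached to the tensor's UOp and carried through to the ScheduleItem. A rough sketch of the same flow (names exactly as used in these tests; behavior may vary by version):

```py
from tinygrad import Tensor

x = Tensor.rand(3, requires_grad=True)
out = x.relu()
print(out.uop.metadata[0].name)        # "relu": the op that created this UOp
si = out.schedule()[-1]
print([m.name for m in si.metadata])   # the kernel keeps the originating op names
```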
     self.assertEqual(si.metadata[0].name, "relu")
 
@@ -814,9 +814,9 @@ class TestTensorMetadata(unittest.TestCase):
     x = Tensor.rand(3, requires_grad=True)
     y = Tensor.rand(3, requires_grad=True)
     out = x.relu() * y.sigmoid()
-    self.assertEqual(out.lazydata.metadata[0].name, "__mul__")
-    self.assertEqual(out.lazydata.src[0].metadata[0].name, "relu")
-    self.assertEqual(out.lazydata.src[1].metadata[0].name, "sigmoid")
+    self.assertEqual(out.uop.metadata[0].name, "__mul__")
+    self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
+    self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
     si = out.schedule()[-1]
     self.assertEqual(len(si.metadata), 3)
     self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@@ -825,12 +825,12 @@ class TestTensorMetadata(unittest.TestCase):
     x = Tensor.rand(3, requires_grad=True).realize()
     y = Tensor.rand(3, requires_grad=True).realize()
     out = (x.relu() * y.sigmoid()).sum()
-    self.assertEqual(out.lazydata.metadata[0].name, "sum")
+    self.assertEqual(out.uop.metadata[0].name, "sum")
     out.backward()
-    self.assertEqual(x.grad.lazydata.metadata[0].name, "relu")
-    self.assertTrue(x.grad.lazydata.metadata[0].backward)
-    self.assertEqual(y.grad.lazydata.metadata[0].name, "sigmoid")
-    self.assertTrue(y.grad.lazydata.metadata[0].backward)
+    self.assertEqual(x.grad.uop.metadata[0].name, "relu")
+    self.assertTrue(x.grad.uop.metadata[0].backward)
+    self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
+    self.assertTrue(y.grad.uop.metadata[0].backward)
     si = Tensor.schedule(out, x.grad, y.grad)[-1]
     self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
     self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
diff --git a/test/test_tensor_uop.py b/test/test_tensor_uop.py
index 42acc51c7a..19541ead38 100644
--- a/test/test_tensor_uop.py
+++ b/test/test_tensor_uop.py
@@ -9,7 +9,7 @@ class TestTensorUOp(unittest.TestCase):
   def test_fromcpu_shape_tracker(self):
     def helper(a: np.ndarray):
       print(a.shape, a.strides, a.flags.c_contiguous)
-      b = Tensor(a).lazydata
+      b = Tensor(a).uop
       #assert b.st.contiguous == a.flags.c_contiguous
       assert b.st.shape == a.shape
       np.testing.assert_equal(a, Tensor(b).numpy())
@@ -60,11 +60,11 @@ class TestTensorUOp(unittest.TestCase):
     np.testing.assert_allclose(c.numpy(), np.concatenate((a.numpy(), b.numpy()), axis=1))
 
   def test_const_dtype(self):
-    lb: UOp = Tensor([1], dtype=dtypes.int).lazydata
+    lb: UOp = Tensor([1], dtype=dtypes.int).uop
     assert lb.const_like(1).base.arg == 1
     assert type(lb.const_like(1).base.arg) is int
 
-    lb: UOp = Tensor([1], dtype=dtypes.float).lazydata
+    lb: UOp = Tensor([1], dtype=dtypes.float).uop
     assert lb.const_like(1).base.arg == 1.0
     assert type(lb.const_like(1).base.arg) is float
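A quick sketch of the `const_like` behavior tested above (assuming the top-level `Tensor` and `dtypes` exports; nothing here beyond what the test itself exercises):

```py
from tinygrad import Tensor, dtypes

lb = Tensor([1], dtype=dtypes.int).uop
print(type(lb.const_like(1).base.arg))    # <class 'int'>: the const inherits the source dtype
lb = Tensor([1], dtype=dtypes.float).uop
print(type(lb.const_like(1).base.arg))    # <class 'float'>
```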
diff --git a/test/test_uops.py b/test/test_uops.py
index f2572ae4e6..779cac713d 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -476,7 +476,7 @@ class TestUOpStr(unittest.TestCase):
     assert str(eval(str(device))) == str(device)
 
   def test_reduceop_arg(self):
-    sum_uop = Tensor.empty(32, 32).sum().lazydata
+    sum_uop = Tensor.empty(32, 32).sum().uop
     assert str(eval(str(sum_uop))) == str(sum_uop)
 
 @unittest.skip("uop no longer has order like this")
@@ -549,9 +549,9 @@ class TestShapeSpec(unittest.TestCase):
   # ** CONST is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
 
   def test_expanded_const(self):
-    a = Tensor(1).lazydata
+    a = Tensor(1).uop
     self.assertEqual(a.st, ShapeTracker.from_shape(()))
-    a = Tensor.ones((4, 4)).lazydata
+    a = Tensor.ones((4, 4)).uop
     self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4)))
 
   def test_padded_const(self):
@@ -569,12 +569,12 @@ class TestShapeSpec(unittest.TestCase):
   # NOTE: CONST ShapeTracker comes from its source
   def test_scalar_const(self):
-    a = Tensor(0).lazydata
+    a = Tensor(0).uop
     self.assertEqual(a.st, ShapeTracker.from_shape(()))
 
   def test_scalar_var(self):
     vv = UOp.variable("a", 1, 4).bind(2)
-    t = Tensor(vv).lazydata
+    t = Tensor(vv).uop
     self.assertEqual(t.st, ShapeTracker.from_shape(()))
 
   # ** ASSIGN is ASSIGN(VIEW(BUFFER), new_val)
@@ -583,7 +583,7 @@ class TestShapeSpec(unittest.TestCase):
     buffer = Tensor.arange(4).realize()
     a = buffer.assign(Tensor.zeros((4,), dtype=dtypes.int))
     assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat()))
-    assert assign_pattern.match(a.lazydata, {})
+    assert assign_pattern.match(a.uop, {})
     a.realize()
     self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
@@ -597,7 +597,7 @@ class TestShapeSpec(unittest.TestCase):
     buffer = Tensor.ones((4,)).contiguous().realize()
     a = buffer.reshape((2, 2)).assign(Tensor.zeros((2, 2)))
     assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))), UPat()))
-    assert assign_pattern.match(a.lazydata, {})
+    assert assign_pattern.match(a.uop, {})
     a.realize()
     self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
@@ -606,13 +606,13 @@ class TestShapeSpec(unittest.TestCase):
     a = Tensor.ones((4,)).contiguous().realize()
     assign = a.shrink(((1, 2),)).assign(Tensor.zeros((1,)))
     # the ASSIGN UOp has size=1
-    self.assertEqual(assign.lazydata.size, 1)
+    self.assertEqual(assign.uop.size, 1)
     # the ASSIGN views the buffer with a shrunk st
-    self.assertEqual(assign.lazydata.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
+    self.assertEqual(assign.uop.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
     # the underlying BUFFER has a size=4
-    self.assertEqual(assign.lazydata.buf_uop.size, 4)
+    self.assertEqual(assign.uop.buf_uop.size, 4)
     # NOTE: output shape is different from the BUFFER shape
-    self.assertNotEqual(assign.lazydata.shape, a.lazydata.shape)
+    self.assertNotEqual(assign.uop.shape, a.uop.shape)
     assign.realize()
     self.assertEqual(a.tolist(), [1, 0, 1, 1])
@@ -622,13 +622,13 @@ class TestShapeSpec(unittest.TestCase):
   def test_ops_st(self):
     # view / mop
-    a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).lazydata
+    a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).uop
     self.assertEqual(a.st, ShapeTracker.from_shape((4, 2, 1)).permute((1, 2, 0)))
     # alu / reduce
     alu = a*2
     self.assertEqual(alu.st, ShapeTracker.from_shape((2, 1, 4)))
     r = Tensor.empty(4, 4).sum(axis=1)
-    self.assertEqual(r.lazydata.st, ShapeTracker.from_shape((4,)))
+    self.assertEqual(r.uop.st, ShapeTracker.from_shape((4,)))
 
   def test_st_wmma_none(self):
     A = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('a', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 1)))
diff --git a/test/test_zero_copy.py b/test/test_zero_copy.py
index 6f2b2cda0b..c8625780dd 100644
--- a/test/test_zero_copy.py
+++ b/test/test_zero_copy.py
@@ -6,7 +6,7 @@ def time_tensor_numpy(out:Tensor):
   times = []
   for _ in range(5):
     st = time.perf_counter()
-    out.lazydata.base.realized.as_buffer(allow_zero_copy=True)
+    out.uop.base.realized.as_buffer(allow_zero_copy=True)
     et = time.perf_counter() - st
     times.append(et)
   return min(times)
diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py
index af2cce8505..2739fe9006 100644
--- a/test/unit/test_gradient.py
+++ b/test/unit/test_gradient.py
@@ -113,16 +113,16 @@ class TestTensorGradient(unittest.TestCase):
 class
TestRealizeMeansRealize(unittest.TestCase): def test_randn_realizes(self): x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize() - assert x.lazydata is not x.lazydata.base - assert x.lazydata.is_realized + assert x.uop is not x.uop.base + assert x.uop.is_realized #@unittest.expectedFailure # update: passing after delete_forced_realize def test_uniform_realizes(self): x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize() - print(x.lazydata) - assert x.lazydata is not x.lazydata.base - assert x.lazydata.is_realized + print(x.uop) + assert x.uop is not x.uop.base + assert x.uop.is_realized # NOTE: even though it doesn't realize, this seems fine def test_uniform_gradient(self): diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index c19672bf0d..23fda92598 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -836,29 +836,29 @@ class TestConsecutive(unittest.TestCase): self.ones = Tensor.ones(2, 4) def test_unmodified(self): - assert self.t.lazydata.st.consecutive - assert self.t.reshape(4, 2).lazydata.st.consecutive - assert self.t.reshape(1, 8).lazydata.st.consecutive + assert self.t.uop.st.consecutive + assert self.t.reshape(4, 2).uop.st.consecutive + assert self.t.reshape(1, 8).uop.st.consecutive def test_sliced(self): - assert self.t[0].lazydata.st.consecutive - assert self.t[0, 1:2].lazydata.st.consecutive - assert self.t[1].lazydata.st.consecutive - assert not self.t[:, 0].lazydata.st.consecutive - assert not self.t[:, 1].lazydata.st.consecutive + assert self.t[0].uop.st.consecutive + assert self.t[0, 1:2].uop.st.consecutive + assert self.t[1].uop.st.consecutive + assert not self.t[:, 0].uop.st.consecutive + assert not self.t[:, 1].uop.st.consecutive def test_padded(self): - assert not self.t.pad(((1, 1), None)).lazydata.st.consecutive - assert not self.t.pad((None, (1, 1))).lazydata.st.consecutive + assert not self.t.pad(((1, 1), None)).uop.st.consecutive + assert not self.t.pad((None, (1, 1))).uop.st.consecutive def test_const(self): - assert self.const.lazydata.st.consecutive + assert self.const.uop.st.consecutive def test_ones(self): - assert not self.ones.lazydata.st.consecutive - assert not self.ones[0, :].lazydata.st.consecutive + assert not self.ones.uop.st.consecutive + assert not self.ones[0, :].uop.st.consecutive # consecutive if sliced into size 1 - assert self.ones[0, 0].lazydata.st.consecutive + assert self.ones[0, 0].uop.st.consecutive class TestRender(unittest.TestCase): def test_render(self): diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index 9156d2b6fb..a2dacee16a 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -8,21 +8,21 @@ realized_pattern = UPat(Ops.BUFFER) buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)) const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),))) def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}" -def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat) +def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat) class TestTensorMutates(unittest.TestCase): def test_mutate_add(self): a = Tensor([1,2,3]) b = Tensor([4,5,6]) ret = a+b - pa = a.lazydata - pb = b.lazydata - pr = ret.lazydata + pa = a.uop + pb = b.uop + pr = ret.uop ret.schedule() - self.assertIsNot(pa, a.lazydata) - self.assertIsNot(pb, b.lazydata) - self.assertIsNot(pr, ret.lazydata) - for t in [a,b,ret]: 
is_pattern_uop(t.lazydata.base, realized_pattern) + self.assertIsNot(pa, a.uop) + self.assertIsNot(pb, b.uop) + self.assertIsNot(pr, ret.uop) + for t in [a,b,ret]: is_pattern_uop(t.uop.base, realized_pattern) def test_reshape_is_same_parent(self): a = Tensor([1,2,3]) @@ -30,11 +30,11 @@ class TestTensorMutates(unittest.TestCase): c = a+b d = (a+b).reshape(3,1) d.realize() - is_pattern_uop(d.lazydata.base, realized_pattern) - is_pattern_uop(c.lazydata.base, realized_pattern) + is_pattern_uop(d.uop.base, realized_pattern) + is_pattern_uop(c.uop.base, realized_pattern) # NOTE: we keep movement ops on top of the buffer view - is_pattern_uop(c.lazydata, UPat(Ops.BUFFER)) - is_pattern_uop(d.lazydata, UPat(Ops.VIEW, src=(realized_pattern,))) + is_pattern_uop(c.uop, UPat(Ops.BUFFER)) + is_pattern_uop(d.uop, UPat(Ops.VIEW, src=(realized_pattern,))) def test_reshape_is_same_child(self): a = Tensor([1,2,3]) @@ -42,41 +42,41 @@ class TestTensorMutates(unittest.TestCase): c = a+b d = (a+b).reshape(3,1) c.realize() - is_pattern_uop(c.lazydata.base, realized_pattern) - is_pattern_uop(d.lazydata.base, realized_pattern) + is_pattern_uop(c.uop.base, realized_pattern) + is_pattern_uop(d.uop.base, realized_pattern) class TestTensorUopRepresentation(unittest.TestCase): def test_realized(self): a = Tensor([1.,2,3]).realize() - print(a.lazydata) - is_pattern_uop(a.lazydata.base, realized_pattern) + print(a.uop) + is_pattern_uop(a.uop.base, realized_pattern) def test_add_realized(self): a = Tensor([1.,2,3]).realize() b = Tensor([4.,5,6]).realize() c = a+b - print(c.lazydata) + print(c.uop) is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern))) def test_const_pattern(self): a = Tensor(1) - print(a.lazydata) + print(a.uop) is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src is_pattern(a, UPat.cvar("x")) # even cvar works! def test_consts_do_not_realize(self): a = Tensor(1) - print(a.lazydata) - pre_realize = a.lazydata + print(a.uop) + pre_realize = a.uop a.realize() - assert a.lazydata is pre_realize + assert a.uop is pre_realize def test_viewed_consts_do_not_realize(self): a = Tensor.ones(10, 10) - print(a.lazydata) + print(a.uop) a.realize() is_pattern(a, const_pattern) - self.assertEqual(a.lazydata.shape, (10, 10)) + self.assertEqual(a.uop.shape, (10, 10)) # currently, CONSTs have a "fake" BUFFER. 
this should be fixed # current: @@ -93,8 +93,8 @@ class TestTensorUopRepresentation(unittest.TestCase): # UOp(Ops.DEVICE, dtypes.void, arg="METAL", src=()),)),)),)) def test_consts_dont_have_buffers(self): a = Tensor.ones(10, 10) - print(a.lazydata) - buffers_in_parents = [x.op for x in a.lazydata.toposort() if x.op is Ops.BUFFER] + print(a.uop) + buffers_in_parents = [x.op for x in a.uop.toposort() if x.op is Ops.BUFFER] self.assertEqual(len(buffers_in_parents), 0) # currently, COPY has an extra BUFFER on the output @@ -112,7 +112,7 @@ class TestTensorUopRepresentation(unittest.TestCase): def test_copyin(self): a = Tensor([1.,2,3]).realize() c = a.to("TEST") # NOTE: this isn't checked - print(c.lazydata) + print(c.uop) is_pattern(c, UPat(Ops.COPY, src=(realized_pattern, UPat(Ops.DEVICE)))) def test_empty_buf(self): @@ -121,7 +121,7 @@ class TestTensorUopRepresentation(unittest.TestCase): vi = UOp.variable("i", 1, 3).bind(1) a = Tensor.empty(3, vi) is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))) - self.assertEqual(a.lazydata.base.buffer.size, 9) + self.assertEqual(a.uop.base.buffer.size, 9) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 23f24250f7..c63810f101 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -176,7 +176,7 @@ class CapturedJit(Generic[ReturnType]): self.__post_init__() # reset the graph state def replan_buffers_memory_layout(self): - blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)] + blacklist = [t.uop.buffer for t in get_parameters(self.ret)] asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True) self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache] for old, new in asgn.items(): @@ -210,9 +210,9 @@ class CapturedJit(Generic[ReturnType]): def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors] - if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors) + if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors) # TODO: this multi unpack stuff is not well tested. 
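The input handling in `_prepare_jit_inputs` (realize any unrealized input tensors, then unpack MULTI tensors into per-device buffers) is observable from user code. A hedged sketch using the public `TinyJit` wrapper:

```py
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).contiguous()

a = Tensor.rand(4)           # lazy: no realized buffer yet
step(a).realize()            # the JIT realizes unrealized inputs before capture
print(a.uop.is_realized)     # True: the input now has a concrete Buffer
```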
- lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors]) + lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors]) input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb] for lb in lbs if lb.base.realized is not None]) assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT" diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 9ab16ca738..6b56a45b6c 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -155,7 +155,7 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.') if isinstance(v.device, tuple): if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]) - else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)) + else: v.replace(state_dict[k].shard(v.device, v.uop.axis)) else: v.replace(state_dict[k].to(v.device)) if realize: v.realize() if consume: del state_dict[k] diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 6e0f038378..544ddf8398 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -20,7 +20,7 @@ from tinygrad.engine.grouper import get_kernelize_map all_tensors: set[weakref.ref[Tensor]] = set() def _find_all_tensors_for_uops(all_uops: set[UOp]) -> list[Tensor]: - return [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops] + return [t for tref in all_tensors if (t:=tref()) is not None and t.uop in all_uops] def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None: # get all children of keys in applied_map @@ -36,13 +36,13 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> Non # NOTE: this uses all_tensors, but it's fast if len(fixed_tensors := _find_all_tensors_for_uops(all_uops)): # potentially rewrite all the discovered Tensors - sink = UOp.sink(*[t.lazydata for t in fixed_tensors]) + sink = UOp.sink(*[t.uop for t in fixed_tensors]) new_sink = sink.substitute(applied_map, name=name) - # set the relevant lazydata to the realized UOps + # set the relevant uop to the realized UOps for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src): if s is ns: continue - t.lazydata = ns + t.uop = ns # **** Tensor helper functions **** @@ -119,7 +119,7 @@ class Tensor(MathTrait): np.set_printoptions(precision=4) ``` """ - __slots__ = "lazydata", "requires_grad", "grad" + __slots__ = "uop", "requires_grad", "grad" training: ClassVar[bool] = False def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821 @@ -150,7 +150,7 @@ class Tensor(MathTrait): if dtype is None: if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True - if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).lazydata + if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).uop else: data = _frompy(data, dtype) elif str(type(data)) == "": import numpy as np @@ -165,19 +165,19 @@ class Tensor(MathTrait): if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") # data 
might be on a different device - if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device) + if isinstance(device, str): self.uop:UOp = data if data.device == device else data.copy_to_device(device) # if device is a tuple, we should have/construct a MultiLazyBuffer - elif isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata + elif isinstance(data.device, str): self.uop = Tensor(data).shard(device).uop else: assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}" - self.lazydata = data + self.uop = data # add to all_tensors after construction succeeds all_tensors.add(weakref.ref(self)) def __del__(self): all_tensors.discard(weakref.ref(self)) def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor: - new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs) + new_uop: UOp = fxn(*[t.uop for t in (self,)+x], **kwargs) if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,) needs_input_grad = [t.requires_grad for t in (self,)+x] return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False) @@ -196,9 +196,9 @@ class Tensor(MathTrait): def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev def __repr__(self): - ld = self.lazydata + ld = self.uop ld_repr = f"" - return f"" + return f"" # Python has a non moving GC, so this should be okay def __hash__(self): return id(self) @@ -210,13 +210,13 @@ class Tensor(MathTrait): return self.shape[0] @property - def device(self) -> str|tuple[str, ...]: return self.lazydata.device + def device(self) -> str|tuple[str, ...]: return self.uop.device @property - def shape(self) -> tuple[sint, ...]: return self.lazydata.shape + def shape(self) -> tuple[sint, ...]: return self.uop.shape @property - def dtype(self) -> DType: return self.lazydata.dtype + def dtype(self) -> DType: return self.uop.dtype # ***** data handlers **** @@ -226,7 +226,7 @@ class Tensor(MathTrait): NOTE: Kernelize can be called multiple times on a Tensor """ - big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst]) + big_sink = UOp.sink(*[x.uop for x in (self,)+lst]) # verify Tensors match the spec if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec) @@ -243,7 +243,7 @@ class Tensor(MathTrait): """ st = time.perf_counter() self.kernelize(*lst) - sink = UOp.sink(*[x.lazydata for x in (self,)+lst]) + sink = UOp.sink(*[x.uop for x in (self,)+lst]) # remove all ASSIGNs, after scheduling, the tensors are just buffers remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN} @@ -272,31 +272,31 @@ class Tensor(MathTrait): """ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype) assert self.shape == x.shape or allow_shape_mismatch, f"replace shape mismatch {self.shape} != {x.shape}" - self.lazydata = x.lazydata + self.uop = x.uop return self def assign(self, x) -> Tensor: # TODO: this is a hack for writing to DISK. 
remove with working assign if isinstance(self.device, str) and self.device.startswith("DISK"): if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype) - cast(Buffer, self.contiguous().realize().lazydata.base.buffer).ensure_allocated().copyin(x._data()) + cast(Buffer, self.contiguous().realize().uop.base.buffer).ensure_allocated().copyin(x._data()) return self if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype) - if self.lazydata is x.lazydata: return self # a self assign is a NOOP + if self.uop is x.uop: return self # a self assign is a NOOP # NOTE: we allow cross device assign assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}" assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" - self.lazydata = self.lazydata.assign(x.lazydata) + self.uop = self.uop.assign(x.uop) return self def detach(self) -> Tensor: """ Returns a new tensor with the same data as this tensor, but detached from the autograd graph. """ - return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False) + return Tensor(self.uop.detach(), device=self.device, requires_grad=False) - def _buffer(self) -> Buffer: return cast(Buffer, self.cast(self.dtype.base).contiguous().to("CPU").realize().lazydata.base.buffer) + def _buffer(self) -> Buffer: return cast(Buffer, self.cast(self.dtype.base).contiguous().to("CPU").realize().uop.base.buffer) def _data(self) -> memoryview: return self._buffer().as_buffer() def data(self) -> memoryview: @@ -372,7 +372,7 @@ class Tensor(MathTrait): device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) if device == self.device: return self if not isinstance(device, str): return self.shard(device) - ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad) + ret = Tensor(self.uop, device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.to(device) return ret @@ -390,12 +390,12 @@ class Tensor(MathTrait): ```python exec="true" source="above" session="tensor" result="python" t = Tensor.empty(2, 4) - print(t.shard((t.device, t.device), axis=1).lazydata) + print(t.shard((t.device, t.device), axis=1).uop) ``` """ assert isinstance(self.device, str), "can't shard a MultiLazyBuffer" devices = tuple(Device.canonicalize(x) for x in devices) - mlb = self.lazydata.shard(devices, self._resolve_dim(axis)) if axis is not None else self.lazydata.copy_to_device(devices) + mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices) return Tensor(mlb, device=devices, requires_grad=self.requires_grad) def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor: @@ -444,7 +444,7 @@ class Tensor(MathTrait): """ r = Tensor.empty(*shape, **kwargs) assert isinstance(r.device, str) - cast(Buffer, r.lazydata.buffer).allocate(external_ptr=ptr) + cast(Buffer, r.uop.buffer).allocate(external_ptr=ptr) return r @staticmethod @@ -719,7 +719,7 @@ class Tensor(MathTrait): dtype = kwargs.pop("dtype", self.dtype) if isinstance(self.device, tuple): if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") - return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.lazydata.axis) + return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.uop.axis) 
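The multi-device branch above keeps the new random tensor sharded the same way as its source, reading the axis off `self.uop`. A small sketch (using `t.device` twice so it runs on any backend, per the `shard` docstring):

```py
from tinygrad import Tensor

t = Tensor.empty(2, 4)
s = t.shard((t.device, t.device), axis=1)   # MULTI UOp, split along axis 1
r = s.rand_like()                           # created sharded on the same axis as s
print(r.uop.axis)                           # 1
```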
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs) # ***** rng hlops ***** @@ -923,13 +923,13 @@ class Tensor(MathTrait): assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor" if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient") if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) - target_uops = [x.lazydata for x in targets] - grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops)) + target_uops = [x.uop for x in targets] + grads = compute_gradient(self.uop, gradient.uop, set(target_uops)) ret = [] for x in target_uops: if (y:=grads.get(x)) is None: if materialize_grads: y = x.const_like(0) - else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}") + else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}") ret.append(y) # create returned Tensors return [Tensor(u, device=t.device) for t,u in zip(targets, ret)] @@ -944,9 +944,9 @@ class Tensor(MathTrait): print(t.grad.numpy()) ``` """ - all_uops = self.lazydata.toposort() + all_uops = self.uop.toposort() tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \ - t.lazydata in all_uops and t.requires_grad] + t.uop in all_uops and t.requires_grad] # clear contexts for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)): assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" @@ -1253,14 +1253,14 @@ class Tensor(MathTrait): self.realize()._getitem(indices).assign(v) return # NOTE: check that setitem target is valid first - if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous") + if not unwrap(self.uop.st).contiguous: raise RuntimeError("setitem target needs to be contiguous") if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype) if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor") if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported") res = self.realize()._getitem(indices, v) # if shapes match and data is not shared it's a copy and we assign to self - if res.shape == self.shape and res.lazydata is not self.lazydata: + if res.shape == self.shape and res.uop is not self.uop: self.assign(res).realize() else: # no copy, basic setitem v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()