rename lazydata to uop (#10698)

George Hotz
2025-06-08 08:42:22 -07:00
committed by GitHub
parent 8e3f337075
commit 32e9949052
57 changed files with 485 additions and 486 deletions

View File

@@ -36,7 +36,7 @@ optim.schedule_step() # this will step the optimizer without running realize
# 3. Create a schedule.
# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
# l1.lazydata and l2.lazydata define a computation graph
# l1.uop and l2.uop define a computation graph
from tinygrad.engine.schedule import ScheduleItem
schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)
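
For reference, a minimal sketch of the renamed property in action (assuming a tinygrad checkout at or after this commit):

```py
from tinygrad import Tensor

l1 = Tensor.ones(4).contiguous()
l2 = l1 + 1
print(l2.uop)  # the lazy computation graph; nothing has run yet
sched = Tensor.schedule(l1, l2)  # one ScheduleItem per kernel/copy to run
print(len(sched))
```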

View File

@@ -34,7 +34,7 @@ print(out) # <Tensor <UOp METAL (1,) int (<Ops.ASSIGN: 66>, None)> on METAL with
The multiply Tensor stays the same because it is fused. The output Tensor's UOp becomes a new ASSIGN UOp:
```py
print(out.lazydata)
print(out.uop)
```
The first source is the output BUFFER:
@@ -72,7 +72,7 @@ Once a Tensor is kernelized, all children will LOAD its BUFFER, instead of fusin
```py
child = out+2
child.kernelize()
print(child.lazydata.src[1].arg.ast)
print(child.uop.src[1].arg.ast)
```
```
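
A hedged, self-contained sketch of the snippet above (the exact AST repr depends on the device):

```py
from tinygrad import Tensor

out = Tensor([1, 2]) * Tensor([3, 4])
out.kernelize()                  # out.uop becomes an ASSIGN into its output BUFFER
child = out + 2
child.kernelize()
print(child.uop.src[1].arg.ast)  # child's kernel LOADs out's BUFFER instead of fusing the multiply
```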

View File

@@ -39,8 +39,8 @@ assert t.shape == (4,)
print(t)
# <Tensor <UOp METAL (4,) int (<Ops.COPY: 7>, None)> on METAL with grad None>
# the ".lazydata" property on Tensor contains the specification of how to compute it
print(t.lazydata)
# the ".uop" property on Tensor contains the specification of how to compute it
print(t.uop)
"""
UOp(Ops.COPY, dtypes.int, arg=None, src=(
UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
@@ -57,21 +57,21 @@ UOp(Ops.COPY, dtypes.int, arg=None, src=(
t.realize()
# if we want to "realize" a tensor, we can do so with the "realize" method
# now when we look at the lazydata, it's changed
print(t.lazydata)
# now when we look at the uop, it's changed
print(t.uop)
"""
UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
UOp(Ops.UNIQUE, dtypes.void, arg=1, src=()),
UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),))
"""
# the copy was actually run, and now the "lazydata" of the Tensor is just a BUFFER
# the copy was actually run, and now the "uop" of the Tensor is just a BUFFER
# if you run this script with DEBUG=2 in the environment, you can see the copy happen
# *** METAL 1 copy 16, METAL <- PYTHON ...
# now let's do some compute
# we look at the lazydata to see the specification of the compute
# we look at the uop to see the specification of the compute
t_times_2 = t * 2
print(t_times_2.lazydata)
print(t_times_2.uop)
"""
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
@@ -90,24 +90,24 @@ UOp(Ops.MUL, dtypes.int, arg=None, src=(
assert t_times_2.tolist() == [2, 4, 6, 8]
# UOps are both immutable and globally unique
# if I multiply the Tensor by 4 twice, the resulting Tensors will have the same lazydata specification
# if I multiply the Tensor by 4 twice, the resulting Tensors will have the same uop specification
t_times_4_try_1 = t * 4
t_times_4_try_2 = t * 4
assert t_times_4_try_1.lazydata is t_times_4_try_2.lazydata
assert t_times_4_try_1.uop is t_times_4_try_2.uop
# the specification isn't just the same, it's the exact same Python object
assert t_times_4_try_1 is not t_times_4_try_2
# the Tensor is a different Python object
# if we realize `t_times_4_try_1` ...
t_times_4_try_1.realize()
print(t_times_4_try_2.lazydata)
print(t_times_4_try_2.uop)
"""
UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
UOp(Ops.UNIQUE, dtypes.void, arg=4, src=()),
UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),))
"""
# ... `t_times_4_try_2` also becomes the same BUFFER
assert t_times_4_try_1.lazydata is t_times_4_try_2.lazydata
assert t_times_4_try_1.uop is t_times_4_try_2.uop
# so this print doesn't require any computation, just a copy back to the CPU so we can print it
print("** only the copy start")
print(t_times_4_try_2.tolist()) # [4, 8, 12, 16]
@@ -120,7 +120,7 @@ t_float = Tensor([3.0])
t_log = t_float.log()
t_log_grad, = t_log.sum().gradient(t_float)
# due to how log is implemented, this gradient contains a lot of UOps
print(t_log_grad.lazydata)
print(t_log_grad.uop)
# ...not shown here...
# but if you run with DEBUG=4 (CPU=1 used here for simpler code), you can see the generated code
"""
@@ -144,7 +144,7 @@ t = Tensor([1,2,3,4])
# NOTE: the APIs here are subject to change
t_plus_3_plus_4 = t + 3 + 4
print(t_plus_3_plus_4.lazydata)
print(t_plus_3_plus_4.uop)
"""
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
@@ -166,7 +166,7 @@ UOp(Ops.ADD, dtypes.int, arg=None, src=(
# but by the time we are actually running the code, it's adding 7
# `kernelize` will simplify and group the operations in the graph into kernels
t_plus_3_plus_4.kernelize()
print(t_plus_3_plus_4.lazydata)
print(t_plus_3_plus_4.uop)
"""
UOp(Ops.ASSIGN, dtypes.int, arg=None, src=(
x0:=UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
@@ -181,7 +181,7 @@ UOp(Ops.ASSIGN, dtypes.int, arg=None, src=(
# ASSIGN has two srcs, src[0] is the BUFFER that's assigned to, and src[1] is the thing to assign
# src[1] is the GPU Kernel that's going to be run
# we can get the ast of the Kernel as follows
kernel_ast = t_plus_3_plus_4.lazydata.src[1].arg.ast
kernel_ast = t_plus_3_plus_4.uop.src[1].arg.ast
# almost everything in tinygrad functions as a rewrite of the UOps
# the codegen rewrites the ast to a simplified form ready for "rendering"

View File

@@ -240,7 +240,6 @@ class LLaMa:
#elif k.endswith('.weight'): v.shard_(device, axis=-1)
#elif 'norm.' in k: v.shard_(device, axis=-1)
else: v.shard_(device, axis=None)
#print(k, v.shape, v.lazydata.axis)
# replace weights in model
load_state_dict(model, weights, strict=False, consume=True)
@@ -446,7 +445,7 @@ After you are done speaking, output [EOS]. You are not Chad.
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize, device=device)
param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(llama.model))
param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(llama.model))
outputted = pre_prompt if chatbot else args.prompt
start_pos, toks = 0, [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted)
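
The param_bytes expression works because x.uop.size is the element count of the tensor's specification; a small sketch:

```py
from tinygrad import Tensor

t = Tensor.empty(3, 4)                # float32 by default
print(t.uop.size)                     # 12 elements
print(t.uop.size * t.dtype.itemsize)  # 48 bytes
```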

View File

@@ -284,7 +284,7 @@ if __name__ == "__main__":
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device)
param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))
param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(model))
if not args.no_api and not args.benchmark:
from bottle import Bottle, request, response, HTTPResponse, abort, static_file

View File

@@ -16,7 +16,7 @@ if __name__ == "__main__":
#model.load_pretrained()
for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained
#early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)])
#early_sched = create_schedule([x.uop for x in nn.state.get_parameters(model)])
#print(f"built model {len(early_sched)}")
#B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
@@ -56,7 +56,7 @@ if __name__ == "__main__":
state_dict.update({'X': X, 'Y': Y, 'loss': loss})
grad_state_dict = {}
for k,v in state_dict.items():
if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}")
if v.uop.base.buffer not in used_buffers: print(f"UNUSED: {k}")
if v.grad is not None: grad_state_dict['grad_'+k] = v.grad
state_dict.update(grad_state_dict)
state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr})
@@ -65,7 +65,7 @@ if __name__ == "__main__":
nm = inverse_state_dict[p]
state_dict["adam_m_"+nm] = m
state_dict["adam_v_"+nm] = v
named_buffers = {v.lazydata.base.buffer:k.replace(".", "_") for k,v in state_dict.items()}
named_buffers = {v.uop.base.buffer:k.replace(".", "_") for k,v in state_dict.items()}
c_code = ["#include <stdlib.h>", "#include <tgmath.h>", "#include <stdbool.h>"]
if TIMING: c_code += ["#include <stdio.h>", "#include <time.h>"]

View File

@@ -71,7 +71,7 @@ def loader_process(q_in, q_out, X:Tensor, seed):
#storage_tensor._copyin(img_tensor.numpy())
# faster
X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
# ideal
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
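
The "faster" path blits raw bytes straight into the realized device buffer. A hedged sketch of the same pattern in isolation (assumes the device allocator supports zero-copy host mapping, as the loader's shared-memory device does):

```py
import numpy as np
from tinygrad import Tensor, dtypes

row = Tensor.empty(8, dtype=dtypes.uint8).realize()
img = np.arange(8, dtype=np.uint8)
# write into the mapped buffer directly, skipping any staging copy
row.uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
print(row.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
```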
@@ -262,8 +262,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens
x = random_brightness_augmentation(x)
x = gaussian_noise(x)
X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
Y[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()
X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
Y[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()
queue_out.put(idx)
queue_out.put(None)
@@ -377,12 +377,12 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
clipped_match_idxs = np.clip(match_idxs, 0, None)
clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs]
boxes[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
labels[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
matches[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
anchors[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()
boxes[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
labels[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
matches[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
anchors[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()
imgs[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
imgs[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
queue_out.put(idx)
queue_out.put(None)

View File

@@ -19,8 +19,8 @@ if __name__ == "__main__":
inputs = run_onnx.get_empty_input_data("npy")
out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu')
root = out.lazydata
targets = [x.lazydata for x in inputs.values()]
root = out.uop
targets = [x.uop for x in inputs.values()]
print(targets)
# TODO: abstract this from gradient?
@@ -42,7 +42,7 @@ if __name__ == "__main__":
print("**** real ****")
GlobalCounters.reset()
out.lazydata = root.substitute(kernelized).substitute(becomes_map)
out.uop = root.substitute(kernelized).substitute(becomes_map)
out.kernelize()
# realize
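
Tensor.uop is a settable property, which is what lets this script splice a rewritten graph back in; a minimal sketch:

```py
from tinygrad import Tensor

a = Tensor([1, 2, 3, 4])
out = a * 2
out.uop = (a * 4).uop  # swap in a different specification; out now computes a*4
print(out.tolist())    # [4, 8, 12, 16]
```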

View File

@@ -66,7 +66,7 @@ if __name__ == "__main__":
model_path = Path(args.weights) if args.weights else download_weights(model_info["total_num_weights"])
transformer = load_model(model_path, model_info["model_params"])
tokenizer = AutoTokenizer.from_pretrained(model_info["tokenizer"])
param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(transformer))
param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(transformer))
outputted = args.prompt
start_pos, toks = 0, tokenizer(outputted)["input_ids"]

View File

@@ -13,8 +13,8 @@ def prepare_browser_chunks(model):
chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
metadata = {}
# We won't export cache_kv bytes (because we start inference on the client at start_pos=0), but we will tell the client how big cache_kv needs to be
t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]
t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
empty_t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]
split_t_infos = []
for size, name, dtype in t_infos:
@@ -48,7 +48,7 @@ def prepare_browser_chunks(model):
weight_metadata = metadata.get(name, default)
weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
metadata[name] = weight_metadata
data = bytes(state_dict[name].lazydata.base.realized.as_buffer())
data = bytes(state_dict[name].uop.base.realized.as_buffer())
data = data if not offsets else data[offsets[0]:offsets[1]]
writer.write(data)
cursor += size

View File

@@ -114,7 +114,7 @@ if __name__ == "__main__":
run, special_names = jit_model(step, *step.input)
functions, statements, bufs, _ = compile_net(run, special_names)
state = get_state_dict(model)
weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
weights = {id(x.uop.base.realized): name for name, x in state.items()}
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
input_names = [name for _,name in special_names.items() if "input" in name]

View File

@@ -48,13 +48,13 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
# hack to put the inputs back
for (j,i),idx in run.input_replace.items():
realized_input = args[idx].lazydata.base.realized
realized_input = args[idx].uop.base.realized
run.jit_cache[j].bufs[i] = realized_input
special_names[id(realized_input)] = f'input{idx}'
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
for i, output in enumerate(the_output):
special_names[id(output.lazydata.base.realized)] = f'output{i}'
special_names[id(output.uop.base.realized)] = f'output{i}'
return run, special_names
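
Both here and in export_model below, realized Buffers are keyed by Python identity; a hedged standalone sketch of that mapping (Tiny is a made-up module):

```py
from tinygrad import Tensor
from tinygrad.nn.state import get_state_dict

class Tiny:
  def __init__(self): self.w = Tensor.ones(2, 2).contiguous().realize()

state = get_state_dict(Tiny())
weight_names = {id(x.uop.base.realized): name for name, x in state.items()}
print(weight_names)  # {<id of w's Buffer>: 'w'}
```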
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
@@ -242,7 +242,7 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
state = get_state_dict(model)
weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
weight_names = {id(x.uop.base.realized): name for name, x in state.items()}
input_names = [name for _,name in special_names.items() if "input" in name]
output_names = [name for _,name in special_names.items() if "output" in name]

View File

@@ -47,7 +47,7 @@ if __name__ == "__main__":
a = Tensor([0.,1.,2.], device="KFD").realize()
b = a + 7
b.lazydata.buffer.allocate()
b.uop.buffer.allocate()
si = b.schedule()[-1]
runner = dev.get_runner(*si.ast)
prg: AMDProgram = runner.clprg
@@ -69,8 +69,8 @@ if __name__ == "__main__":
#scratch = dev._gpu_alloc(0x10000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
ka = to_mv(dev.kernargs_ptr, 0x10).cast("Q")
ka[0] = b.lazydata.buffer._buf.va_addr
ka[1] = a.lazydata.buffer._buf.va_addr
ka[0] = b.uop.buffer._buf.va_addr
ka[1] = a.uop.buffer._buf.va_addr
compute_read_pointer = to_mv(compute_queue.read_pointer_address, 8).cast("Q")
compute_write_pointer = to_mv(compute_queue.write_pointer_address, 8).cast("Q")

View File

@@ -23,7 +23,7 @@ def explicit_shard_W_axis_1(X, W):
x = x.reshape(N, 1, N).expand(N, N, N)
w = w.T.reshape(1, N, N).expand(N, N, N)
m = x*w
assert m.lazydata.st.views[0].mask is not None
assert m.uop.st.views[0].mask is not None
ret = m.sum(2)
return ret
#Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])]

View File

@@ -89,7 +89,7 @@ def get_output(s:str, n_threads:int=1):
"s_waitcnt 0",
"global_store_b32 v0, v1, s[0:1]",
"s_nop 0", "s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm"])
test = Tensor.zeros((n_threads,), dtype=dtypes.uint32).contiguous().realize().lazydata.buffer
test = Tensor.zeros((n_threads,), dtype=dtypes.uint32).contiguous().realize().uop.buffer
prg = get_prg(code, 32, 32)
prg(test._buf, global_size=(1, 1, 1), local_size=(n_threads, 1, 1), wait=True)
return test.numpy()

View File

@@ -65,12 +65,12 @@ for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "p
# in place operations with views
def realize_with_views(self: Tensor, views: Tensor):
if not self.lazydata.st.contiguous: self.replace(self.contiguous())
if not self.uop.st.contiguous: self.replace(self.contiguous())
self.replace(self.clone().realize())
for v in views:
if v.lazydata.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view
if v.uop.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view
ret = self
st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
st = ShapeTracker(self.uop.st.views + v.uop.st.views) # TODO: is this right?
for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
v.replace(ret)
def maybe_realize_storage(self: Tensor) -> bool:
@@ -178,7 +178,7 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
# multiple as_strided do not compound
base = canonical_base(tensor)
# TODO: this is heavyweight
st = ShapeTracker(base.lazydata.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
ret = base
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
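
as_strided stacks one more View on top of whatever views the base already has; a hedged sketch of that composition (import paths are assumptions based on the current tinygrad layout):

```py
from tinygrad import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

base = Tensor.empty(4, 4)
st = ShapeTracker(base.uop.st.views + (View.create((2, 2), (4, 1), 0),))
print(st.shape, st.real_strides())  # (2, 2) with the requested strides
```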
@@ -315,7 +315,7 @@ def _copy_from(src: torch.Tensor, dest, non_blocking=False):
to_device = _from_torch_device(dest.device)
src,dest = unwrap(src),unwrap(dest)
# TODO we need to properly match dest shape and strides, not blindly assign
if dest.lazydata.st.contiguous or dest.lazydata.is_realized: src = src.contiguous() # this only solves some cases
if dest.uop.st.contiguous or dest.uop.is_realized: src = src.contiguous() # this only solves some cases
dest.assign(src.cast(cast_dtype).to(to_device))
if realize: Tensor.realize(dest)
elif src.is_tiny and dest.is_cpu:
@@ -493,7 +493,7 @@ def wrap_out(f):
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
if out.lazydata.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
if out.uop.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
return out.assign(assigned)
return _wrap_out

View File

@@ -116,7 +116,7 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceInd
// TODO: we have to get the dtype and the shape from the tinygrad Tensor
std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
py::list views = py_obj.attr("lazydata").attr("st").attr("views");
py::list views = py_obj.attr("uop").attr("st").attr("views");
std::vector<int64_t> strides = views[views.size() - 1].attr("strides").cast<std::vector<int64_t>>();
int64_t storage_offset = 0;
for (auto& v: views) {

View File

@@ -113,7 +113,7 @@ class DispatchLog(TorchDispatchMode):
_ = tiny_x.cpu().numpy()
if torch.is_tensor(tiny_x) and tiny_x.device.type == "tiny":
tt = tiny_torch.unwrap(tiny_x)
try: out_addr = tt.lazydata.buffer._buf.value
try: out_addr = tt.uop.buffer._buf.value
except Exception: pass
tiny_events = hook_cuda.collect_events(clear=True)
print_events(tiny_events, colored("tiny", "magenta"), out_addr)

View File

@@ -13,6 +13,6 @@ if __name__ == "__main__":
(Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v
]
for t0, t1 in tensors:
print(f"{t0.shape=}, {t0.lazydata.st.real_strides()=}, {t1.shape=}, {t1.lazydata.st.real_strides()=}")
print(f"{t0.shape=}, {t0.uop.st.real_strides()=}, {t1.shape=}, {t1.uop.st.real_strides()=}")
for _ in range(5):
t0.dot(t1, dtype=acc_dtype).realize()

View File

@@ -21,8 +21,8 @@ if __name__ == "__main__":
with Timing("CPU creation: ", on_exit=lambda x: f", {(sz*4*2)/x:.2f} GB/sec"):
c0 = (Tensor.ones(sz, device="CPU")/2).realize()
c1 = (Tensor.ones(sz, device="CPU")/4).realize()
print(c0.lazydata.base.realized)
print(c1.lazydata.base.realized)
print(c0.uop.base.realized)
print(c1.uop.base.realized)
with Timing("CPU -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
a0 = c0.to(d0).realize()

View File

@@ -9,18 +9,18 @@ class TestAMD(unittest.TestCase):
TestAMD.d0: AMDDevice = Device["AMD"]
TestAMD.a = Tensor([0.,1.], device="AMD").realize()
TestAMD.b = self.a + 1
si = create_schedule([self.b.lazydata])[-1]
si = create_schedule([self.b.uop])[-1]
TestAMD.d0_runner = TestAMD.d0.get_runner(*si.ast)
TestAMD.b.lazydata.buffer.allocate()
TestAMD.b.uop.buffer.allocate()
def test_amd_ring_64bit_doorbell(self):
TestAMD.d0.pm4_write_pointer[0] = TestAMD.d0.pm4_write_pointer[0] + (2 << 32) - TestAMD.d0.pm4_ring.size // 4
for _ in range(2000):
TestAMD.d0_runner.clprg(TestAMD.b.lazydata.buffer._buf, TestAMD.a.lazydata.buffer._buf,
TestAMD.d0_runner.clprg(TestAMD.b.uop.buffer._buf, TestAMD.a.uop.buffer._buf,
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
TestAMD.d0_runner.clprg(TestAMD.a.lazydata.buffer._buf, TestAMD.b.lazydata.buffer._buf,
TestAMD.d0_runner.clprg(TestAMD.a.uop.buffer._buf, TestAMD.b.uop.buffer._buf,
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
val = TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0]
val = TestAMD.a.uop.buffer.as_buffer().cast("f")[0]
assert val == 4000.0, f"got val {val}"
if __name__ == "__main__":

View File

@@ -22,10 +22,10 @@ class TestHCQ(unittest.TestCase):
TestHCQ.b = self.a + 1
si = self.b.schedule()[-1]
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
TestHCQ.b.lazydata.buffer.allocate()
TestHCQ.b.uop.buffer.allocate()
# wow that's a lot of abstraction layers
TestHCQ.addr = struct.pack("QQ", TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr)
TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr)
TestHCQ.addr = struct.pack("QQ", TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr)
TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr)
TestHCQ.kernargs_off = TestHCQ.runner._prg.kernargs_offset
TestHCQ.kernargs_size = TestHCQ.runner._prg.kernargs_alloc_size
ctypes.memmove(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_off, TestHCQ.addr, len(TestHCQ.addr))
@@ -45,8 +45,8 @@ class TestHCQ(unittest.TestCase):
def setUp(self):
TestHCQ.d0.synchronize()
TestHCQ.a.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
TestHCQ.b.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
TestHCQ.a.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
TestHCQ.b.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
TestHCQ.d0.synchronize() # wait for copyins to complete
def test_run_1000_times_one_submit(self):
@@ -65,7 +65,7 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
assert val == 2000.0, f"got val {val}"
def test_run_1000_times(self):
@@ -81,7 +81,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
assert val == 2000.0, f"got val {val}"
def test_run_to_3(self):
@@ -95,7 +95,7 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 3.0, f"got val {val}"
def test_update_exec(self):
@@ -106,9 +106,9 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
@@ -126,7 +126,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
assert val == 2000.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
@@ -141,9 +141,9 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipIf(CI, "Can't handle async update on CPU")
@@ -174,7 +174,7 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_submit_empty_queues(self):
@@ -206,13 +206,13 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_copy_1000_times(self):
q = TestHCQ.copy_queue()
q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
q.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8)
q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8)
for _ in range(1000):
q.submit(TestHCQ.d0)
TestHCQ.copy_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
@@ -221,24 +221,24 @@ class TestHCQ(unittest.TestCase):
# confirm the signal didn't exceed the put value
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}"
def test_copy(self):
q = TestHCQ.copy_queue()
q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8)
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 1.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
def test_bind_copy(self):
q = TestHCQ.copy_queue()
q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
q.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8)
q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8)
q.bind(TestHCQ.d0)
for _ in range(1000):
q.submit(TestHCQ.d0)
@@ -248,7 +248,7 @@ class TestHCQ(unittest.TestCase):
# confirm the signal didn't exceed the put value
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}"
def test_copy_bandwidth(self):
@@ -281,14 +281,14 @@ class TestHCQ(unittest.TestCase):
q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) # b = [1, 2]
q.signal(sig:=TestHCQ.d0._alloc_signal(value=0), value=1)
qc.wait(sig, value=1)
qc.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
qc.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8)
qc.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
qc.submit(TestHCQ.d0)
time.sleep(0.02) # give it time for the wait to fail
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_cross_device_signal(self):
@@ -319,7 +319,7 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
if __name__ == "__main__":

View File

@@ -10,7 +10,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
def test_hip_compile(self):
a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
out = a + b
lin = Kernel(create_schedule([out.lazydata])[-1].ast[0])
lin = Kernel(create_schedule([out.uop])[-1].ast[0])
lin.linearize()
reference = """

View File

@@ -21,8 +21,8 @@ class TestNV(unittest.TestCase):
TestNV.b = self.a + 1
si = self.b.schedule()[-1]
TestNV.d0_runner = get_runner(TestNV.d0.device, si.ast)
TestNV.b.lazydata.buffer.allocate()
TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
TestNV.b.uop.buffer.allocate()
TestNV.addr = struct.pack("QQ", TestNV.b.uop.buffer._buf.va_addr, TestNV.a.uop.buffer._buf.va_addr)
def test_oor_kernels(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
@@ -44,8 +44,8 @@ class TestNV(unittest.TestCase):
TestNV.along = Tensor([105615], device="NV").realize()
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
temp_runner = get_runner(TestNV.d0.device, (ast,))
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
temp_runner([TestNV.b.uop.buffer, TestNV.along.uop.buffer], var_vals={})
val = TestNV.b.uop.buffer.as_buffer().cast("f")[0]
assert abs(val - 0.80647) < 0.001, f"got val {val}"
def test_kernargs_no_oob_access(self):
@@ -59,7 +59,7 @@ class TestNV(unittest.TestCase):
q.signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value).submit(TestNV.d0)
TestNV.d0._wait_signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value)
TestNV.d0.timeline_value += 1
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestNV.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
if __name__ == "__main__":

View File

@@ -28,7 +28,7 @@ def alloc_rawbuffer(device, fill=False):
if fill:
with Context(DEBUG=0):
data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer())
rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer())
return rawbuf
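
The copyin line seeds the raw buffer with another realized Tensor's bytes; a hedged standalone sketch:

```py
import numpy as np
from tinygrad import Tensor

src = Tensor(np.arange(4, dtype=np.float32)).realize()
dst = Tensor.empty(4).realize()
dst.uop.base.realized.copyin(src.uop.base.realized.as_buffer())
print(dst.tolist())  # [0.0, 1.0, 2.0, 3.0]
```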
def gen_kernel_ji(device, deps):

View File

@@ -73,7 +73,7 @@ def get_fuzz_rawbufs(lin):
data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
else:
data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().lazydata.base.realized.as_buffer())
rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().uop.base.realized.as_buffer())
return rawbufs
def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, force_device=None):

View File

@@ -21,18 +21,18 @@ def consec(shape, start=1):
# creates a strided tensor with its base set to the reference tensor's base, equivalent to torch.set_()
def set_(reference: Tensor, shape, strides, offset):
raise NotImplementedError("need to implement without calling lazydata.view")
if reference.lazydata.base.realized is None: reference.realize()
assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
strided = Tensor(reference.lazydata.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
raise NotImplementedError("need to implement without calling uop.view")
if reference.uop.base.realized is None: reference.realize()
assert reference.uop.base.realized, "base has to be realized before setting it to strided's base"
strided = Tensor(reference.uop.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
assert strided.uop.st.real_strides() == strides, "real_strides should equal strides for strided"
return strided
def clone(original:Tensor): return original.clone()
def copy_(src:Tensor, other:Tensor) -> Tensor: return src.clone()
# this is fine for the tested use cases since, as geohotstan understands it,
# data_ptr is only used to compare whether the operations needed between tensors are the same
def data_ptr(tensor:Tensor): return tensor.lazydata
def data_ptr(tensor:Tensor): return tensor.uop
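
Because UOps are globally deduplicated (see the docs change earlier in this commit), identity on tensor.uop behaves like a stable data pointer; a small sketch:

```py
from tinygrad import Tensor

a = Tensor([1, 2, 3, 4])
assert (a * 4).uop is (a * 4).uop      # same specification -> the exact same object
assert (a * 4).uop is not (a * 3).uop  # different specification -> different object
```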
# https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html
def index_put_(tensor:Tensor, indices, values, accumulate) -> Tensor:
@@ -971,9 +971,9 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper((2, 0, 4), z.shape)
# this isn't technically necessary, but matches NumPy stride calculations.
# NOTE: this is empty and shouldn't have strides
#numpy_testing_assert_equal_helper((60, 20, 5), z.lazydata.st.real_strides())
#numpy_testing_assert_equal_helper((60, 20, 5), z.uop.st.real_strides())
# NOTE tinygrad's int slicing implementation makes this not contiguous
# self.assertTrue(z.lazydata.st.contiguous)
# self.assertTrue(z.uop.st.contiguous)
@unittest.skip("bool indexing not supported")
def test_index_getitem_copy_bools_slices(self):

View File

@@ -20,7 +20,7 @@ class TestArange(unittest.TestCase):
p = k.to_program()
print(p.name)
#print(p.src)
ExecItem(CompiledRunner(p), [tt.lazydata.buffer]).run()
ExecItem(CompiledRunner(p), [tt.uop.buffer]).run()
np.testing.assert_equal(tt.numpy(), np.arange(N))
return p.estimates.ops

View File

@@ -13,11 +13,11 @@ class TestAssign(unittest.TestCase):
b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
a.realize()
b.realize()
ba1 = a.lazydata.base.realized
bb1 = b.lazydata.base.realized
ba1 = a.uop.base.realized
bb1 = b.uop.base.realized
a += b
a.realize()
ba2 = a.lazydata.base.realized
ba2 = a.uop.base.realized
assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
@@ -259,13 +259,13 @@ class TestAssign(unittest.TestCase):
b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
a.realize()
b.realize()
ba1 = a.lazydata.base.realized
bb1 = b.lazydata.base.realized
ba1 = a.uop.base.realized
bb1 = b.uop.base.realized
with self.assertRaises((RuntimeError, AssertionError)):
a = a.permute(1,0)
a += b
a.realize()
ba2 = a.lazydata.base.realized
ba2 = a.uop.base.realized
assert ba1 != ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
@@ -275,12 +275,12 @@ class TestAssign(unittest.TestCase):
a.realize()
b.realize()
#GlobalCounters.cache = []
ba1 = a.lazydata.base.realized # noqa: F841
bb1 = b.lazydata.base.realized # noqa: F841
ba1 = a.uop.base.realized # noqa: F841
bb1 = b.uop.base.realized # noqa: F841
with self.assertRaisesRegex(RuntimeError, "contiguous"):
a.assign(a.permute(1,0) + b) # this should not work!
a.realize()
ba2 = a.lazydata.base.realized # noqa: F841
ba2 = a.uop.base.realized # noqa: F841
# NOTE: don't test that it's assigned
#assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
@@ -383,10 +383,10 @@ class TestAssign(unittest.TestCase):
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
a.realize()
oba1 = a.lazydata.base.output_buffer
oba1 = a.uop.base.output_buffer
a.assign(a.cast(dtypes.int32).realize())
a.realize()
oba2 = a.lazydata.base.output_buffer
oba2 = a.uop.base.output_buffer
assert oba1 is None and oba2 is None
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))

View File

@@ -221,7 +221,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
# contiguous folded const can still schedule
a = Tensor.empty(1, 0).sum().contiguous()
_check_ast_count(2, a+2)
self.assertIs(a.lazydata.base.op, Ops.BUFFER)
self.assertIs(a.uop.base.op, Ops.BUFFER)
np.testing.assert_equal((Tensor.empty(1, 0).sum().contiguous()+2).numpy(), 2)
# otherwise we just fuse it
_check_ast_count(1, (Tensor.empty(1, 0).sum()+2).contiguous())

View File

@@ -32,7 +32,7 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]:
def _test_to_np(a:Tensor, np_dtype, target):
if DEBUG >= 2: print(a)
na = a.numpy()
if DEBUG >= 2: print(na, na.dtype, a.lazydata.base.realized)
if DEBUG >= 2: print(na, na.dtype, a.uop.base.realized)
try:
assert na.dtype == np_dtype
np.testing.assert_allclose(na, target)

View File

@@ -73,7 +73,7 @@ class TestGC(unittest.TestCase):
x = Tensor.ones(4,4).contiguous().realize()+1
self.assertEqual(bufs_allocated()-init, 1)
# try commenting this part out, it's green!
x.lazydata.toposort()
x.uop.toposort()
del x
if bufs_allocated()-init != 0:
print(inspect.getclosurevars(UOp.toposort().fget))
@@ -84,11 +84,11 @@ class TestGC(unittest.TestCase):
a = Tensor.empty(10)
self.assertEqual(bufs_allocated()-init, 0)
a.realize()
real_buf = a.lazydata.buffer
real_buf = a.uop.buffer
# after the Tensor UOp is deleted there shouldn't be any references to the Buffer
self.assertEqual(real_buf.lb_refcount, 1)
self.assertEqual(bufs_allocated()-init, 1)
del a.lazydata
del a.uop
self.assertEqual(real_buf.lb_refcount, 0)
self.assertEqual(bufs_allocated()-init, 1) # keep the buffer alive
del real_buf
@@ -98,10 +98,10 @@ class TestGC(unittest.TestCase):
init = bufs_allocated()
a = Tensor.full((4,), 1.).contiguous()
a.realize()
real_buf = a.lazydata.buffer
real_buf = a.uop.buffer
self.assertEqual(real_buf.lb_refcount, 1)
a.assign(Tensor.full((4,), 2.))
self.assertIs(a.lazydata.src[0].buffer, real_buf)
self.assertIs(a.uop.src[0].buffer, real_buf)
# NOTE: this is still 1, we don't count the ASSIGN
self.assertEqual(real_buf.lb_refcount, 1)
a.realize()

View File

@@ -36,7 +36,7 @@ def helper_alloc_rawbuffer(device, fill=False):
if fill:
with Context(DEBUG=0):
data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer())
rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer())
return rawbuf
def helper_create_offset_rawbuffer(base, offset=0):

View File

@@ -20,15 +20,15 @@ class TestHCQ(unittest.TestCase):
si = self.b.schedule()[-1]
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
TestHCQ.b.lazydata.buffer.allocate()
TestHCQ.b.uop.buffer.allocate()
TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.lazydata.buffer._buf, TestHCQ.a.lazydata.buffer._buf])
TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.lazydata.buffer._buf, TestHCQ.b.lazydata.buffer._buf])
TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf])
TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf])
def setUp(self):
TestHCQ.d0.synchronize()
TestHCQ.a.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
TestHCQ.b.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
TestHCQ.a.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
TestHCQ.b.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
TestHCQ.d0.synchronize() # wait for copyins to complete
# Test signals
@@ -117,7 +117,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_exec_2_kernels_100_times(self):
@@ -133,7 +133,7 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
assert val == 200.0, f"got val {val}"
def test_exec_update(self):
@@ -148,9 +148,9 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
def test_exec_update_fuzz(self):
@@ -192,13 +192,13 @@ class TestHCQ(unittest.TestCase):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8) \
.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 1.0, f"got val {val}"
def test_copy_long(self):
@@ -252,12 +252,12 @@ class TestHCQ(unittest.TestCase):
.copy(virt_dest_addr, virt_src_addr, 8) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr})
q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.uop.buffer._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
assert val == 1.0, f"got val {val}"
def test_update_copy_long(self):

View File

@@ -13,7 +13,7 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
class TestImageCopy(unittest.TestCase):
def test_image_copyout_1x1(self, img_type=dtypes.imagef):
it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
buf = it.lazydata.buffer
buf = it.uop.buffer
out = buf.as_buffer()
np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4))
@@ -27,18 +27,18 @@ class TestImageCopy(unittest.TestCase):
def test_image_copyout_2x3(self):
it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize()
buf = it.lazydata.buffer
buf = it.uop.buffer
out = buf.as_buffer()
np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4))
def test_image_roundtrip(self):
sz = (4,2,4)
it = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
buf = it.lazydata.buffer
buf = it.uop.buffer
out = buf.as_buffer()
it2 = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
buf2 = it2.lazydata.buffer
buf2 = it2.uop.buffer
buf2.copyin(out)
assert (it == it2).sum().item() == prod(sz)
@@ -49,7 +49,7 @@ class TestImageDType(unittest.TestCase):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
assert isinstance(it.lazydata.base.realized.dtype, ImageDType)
assert isinstance(it.uop.base.realized.dtype, ImageDType)
np.testing.assert_equal(tst, it.numpy())
@unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType
@@ -58,14 +58,14 @@ class TestImageDType(unittest.TestCase):
tst = data.numpy()
it = data.cast(dtypes.imagef((9,27,4))).realize()
# the underlying UOp is identical
self.assertIs(it.lazydata.base.realized, data.lazydata.base.realized)
self.assertIs(it.uop.base.realized, data.uop.base.realized)
np.testing.assert_equal(tst, it.numpy())
def test_image_and_back_wrong_shape(self):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()
it = data.cast(dtypes.imagef((9,12,4))).realize()
assert not isinstance(it.lazydata.base.realized.dtype, ImageDType)
assert not isinstance(it.uop.base.realized.dtype, ImageDType)
np.testing.assert_equal(tst, it.numpy())
def test_shrink_load_float(self):
@@ -77,7 +77,7 @@ class TestImageDType(unittest.TestCase):
# NOTE: contiguous is needed otherwise this folds
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize()
out = (it*2).realize()
assert isinstance(out.lazydata.base.realized.dtype, ImageDType)
assert isinstance(out.uop.base.realized.dtype, ImageDType)
def test_sum(self):
it = Tensor.rand(8).cast(dtypes.imagef((1,2,4))).realize()
@@ -98,26 +98,26 @@ class TestImageDType(unittest.TestCase):
def test_lru_alloc(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).realize()
b1 = it.lazydata.base.realized._buf
b1 = it.uop.base.realized._buf
del it
it = data.cast(dtypes.imagef((9,27,4))).realize()
assert it.lazydata.base.realized._buf == b1
assert it.uop.base.realized._buf == b1
def test_no_lru_alloc(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
b1 = it.lazydata.base.realized._buf
b1 = it.uop.base.realized._buf
del it
it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize()
assert it.lazydata.base.realized._buf != b1
assert it.uop.base.realized._buf != b1
def test_no_lru_alloc_dtype(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
b1 = it.lazydata.base.realized._buf
b1 = it.uop.base.realized._buf
del it
it = data.cast(dtypes.imageh((9,27,4))).realize()
assert it.lazydata.base.realized._buf != b1
assert it.uop.base.realized._buf != b1
# issue caused by: don't realize image to image casts. this is part of a larger problem
#@unittest.expectedFailure
@@ -137,8 +137,8 @@ class TestImageDType(unittest.TestCase):
print(lst)
assert not np.any(np.isnan(lst))
# NOTE: the w1 grad must realize to a separate kernel
assert w1.grad.lazydata.is_realized, f"never realized {w1.grad}"
self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32)
assert w1.grad.uop.is_realized, f"never realized {w1.grad}"
self.assertEqual(w1.grad.uop.base.buffer.dtype, dtypes.float32)
self.assertEqual(len(sched), 10)
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")

View File

@@ -75,7 +75,7 @@ class TestLinearizer(unittest.TestCase):
lowered = [x[1] for x in lower_schedule(c.schedule())]
for ei in lowered: ei.run()
rawbufs = lowered[-1].bufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
@@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(32, dtype=dtypes.float).realize()
st_x = x.lazydata.st
st_x = x.uop.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((1, 32)).expand((32, 32))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,)))
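
These tests pull the ShapeTracker off the realized input as x.uop.st, then reshape/expand it to build the LOAD views; a minimal sketch of that property:

```py
from tinygrad import Tensor

x = Tensor.randn(32).realize()
st = x.uop.st
print(st.shape)                              # (32,)
print(st.reshape((1, 32)).expand((32, 32)))  # the broadcast view used in test_multireduce
```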
@@ -172,7 +172,7 @@ class TestLinearizer(unittest.TestCase):
def test_mid_dim_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
st_x = x.lazydata.st
st_x = x.uop.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
@@ -232,12 +232,12 @@ class TestLinearizer(unittest.TestCase):
x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x0.uop.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x1.uop.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,)))
third_x = UOp(Ops.LOAD, dtypes.float, (g3.view(x2.lazydata.st.reshape((27, 32, 1, 1, 5))),))
third_x = UOp(Ops.LOAD, dtypes.float, (g3.view(x2.uop.st.reshape((27, 32, 1, 1, 5))),))
mul = (third_x*second_reduce)
third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 1, 5))), third_reduce))
@@ -253,7 +253,7 @@ class TestLinearizer(unittest.TestCase):
def test_double_reduce_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize()
st = x.lazydata.st
st = x.uop.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5)))
@@ -302,9 +302,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((27, 15, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -329,11 +329,11 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2.view(x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2.view(x_p.uop.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 32, 1))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((4, 32, 1))),))
diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((4, 1, 1))), second_reduce))
@@ -361,9 +361,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.uop.st.reshape((27, 15, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -383,9 +383,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.uop.st.reshape((27, 15, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -402,9 +402,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 3, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -418,9 +418,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 3, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -437,9 +437,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.uop.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.uop.st.reshape((27, 12, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
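Note the two encodings of a LOAD's ShapeTracker that alternate through these hunks: most tests fold it into the pointer with `g.view(st)`, while this one passes it as an explicit extra source via `st.to_uop()`. A minimal sketch of both forms; the import paths are assumptions, not taken from this commit:

```py
from tinygrad.ops import UOp, Ops                # assumed import paths
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker

g = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
st = ShapeTracker.from_shape((27, 12, 5)).reshape((27, 1, 12, 5))
ld_view = UOp(Ops.LOAD, dtypes.float, (g.view(st),))     # ShapeTracker folded into the pointer
ld_src  = UOp(Ops.LOAD, dtypes.float, (g, st.to_uop()))  # ShapeTracker as an explicit VIEW src
```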
@@ -453,10 +453,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 35, 1))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 35, 1))),))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
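The std/variance hunks fold the 1/N divisions into the `ast_const` multipliers: `-1/35` turns the first sum into a negated mean, and `1/35` turns the sum of squares into a (population) variance. A numpy sketch of the computation above:

```py
import numpy as np
x = np.random.randn(15, 25, 35).astype(np.float32)
neg_mean = x.sum(axis=2, keepdims=True) * (-1/35)       # first_reduce * ast_const(-1/35)
squares = (x + neg_mean) * (x + neg_mean)               # (second_x + neg_mean) squared
variance = squares.sum(axis=2, keepdims=True) * (1/35)  # squares_sum * ast_const(1/35)
```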
@@ -471,10 +471,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 1, 35))),))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,)))
variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
@@ -491,10 +491,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.uop.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.uop.st.reshape((15, 25, 35, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
@@ -514,12 +514,12 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
# store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
# verify_lazyop(store)
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 32, 1))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((3, 27, 32, 1))),))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
@@ -535,9 +535,9 @@ class TestLinearizer(unittest.TestCase):
def test_softmax_multireduce(self):
x = Tensor.rand(4, 32).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((4, 1, 32,)).expand((4, 32, 32))),))
max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 32, 1,))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((4, 32, 1,))),))
centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
exp_x = centered_x.alu(Ops.EXP2)
sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,)))
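The softmax AST is the standard numerically stable formulation: subtract the row max (an Ops.MAX first reduce) before exponentiating, so the exponentials cannot overflow. Roughly, in numpy:

```py
import numpy as np
x = np.random.rand(4, 32).astype(np.float32)
m = x.max(axis=1, keepdims=True)       # max_x, the Ops.MAX first reduce
e = np.exp2(x - m)                     # centered_x.alu(Ops.EXP2)
denom = e.sum(axis=1, keepdims=True)   # sum_exp_x, the Ops.ADD second reduce
softmax = e / denom                    # the full softmax then divides by the denominator
```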
@@ -553,7 +553,7 @@ class TestLinearizer(unittest.TestCase):
dataset = Tensor.rand(16384, 256).realize()
idxs = Tensor([0,3,5,6]).realize()
with Context(FUSE_ARANGE=1):
sink = dataset[idxs].contiguous().kernelize().lazydata.base.src[1].arg.ast
sink = dataset[idxs].contiguous().kernelize().uop.base.src[1].arg.ast
real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index])
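As the line above shows, after `kernelize()` the kernel's AST hangs off `src[1].arg` of the kernelized tensor's uop. A hedged sketch of pulling it out (the `Context` import path is an assumption):

```py
from tinygrad import Tensor
from tinygrad.helpers import Context   # assumed import path for Context

dataset = Tensor.rand(64, 8).realize()
idxs = Tensor([0, 3, 5, 6]).realize()
with Context(FUSE_ARANGE=1):           # fuse the arange-based gather into one kernel
    out = dataset[idxs].contiguous().kernelize()
ast = out.uop.base.src[1].arg.ast      # the KERNEL sits on src[1] of the kernelized uop
print(ast)
```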
@@ -656,16 +656,16 @@ class TestLinearizer(unittest.TestCase):
]
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, N, 1))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1))
@@ -683,16 +683,16 @@ class TestLinearizer(unittest.TestCase):
]
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, N, 1))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1))
@@ -711,8 +711,8 @@ class TestLinearizer(unittest.TestCase):
opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],]
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501
ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))
ld1 = x.lazydata.st.reshape((N, N, 1))
ld0 = x.uop.st.reshape((N, 1, N)).expand((N,N,N))
ld1 = x.uop.st.reshape((N, N, 1))
ast = UOp(Ops.SINK, src=(
UOp(Ops.STORE, src=(
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
@@ -742,8 +742,8 @@ class TestLinearizer(unittest.TestCase):
ast_const(dtypes.float, 1.0, (N, 1, 1)),)),)),))
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])
ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N))
ld1 = x.lazydata.st.reshape((N, 1, N))
ld0 = x.uop.st.reshape((1, N, N)).expand((N,N,N))
ld1 = x.uop.st.reshape((N, 1, N))
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
ast = UOp(Ops.SINK, src=(
UOp(Ops.STORE, src=(
@@ -776,8 +776,8 @@ class TestLinearizer(unittest.TestCase):
# pad reduce axis
helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output])
ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N))
ld1 = x.lazydata.st.reshape((N,N,1,1))
ld0 = x.uop.st.reshape((1,1,N,N)).expand((N,N,N,N))
ld1 = x.uop.st.reshape((N,N,1,1))
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
ast = UOp(Ops.SINK, src=(
UOp(Ops.STORE, src=(
@@ -1794,7 +1794,7 @@ class TestHandCodedOpts(unittest.TestCase):
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
assert isinstance(ast, UOp), "ast must be UOp"
inbufs = [x.lazydata.base.buffer for x in inputs]
inbufs = [x.uop.base.buffer for x in inputs]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() \
for out in ast.src]
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)


@@ -7,7 +7,7 @@ class TestMaskedShapeTracker(unittest.TestCase):
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.lazydata.st.views[0].mask is not None
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
@@ -16,7 +16,7 @@ class TestMaskedShapeTracker(unittest.TestCase):
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.lazydata.st.views[0].mask is not None
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
@@ -24,7 +24,7 @@ class TestMaskedShapeTracker(unittest.TestCase):
a = Tensor([1,1]).pad(((0,2),))
b = Tensor([1,1]).pad(((0,2),))
c = a+b
#assert c.lazydata.st.views[0].mask is not None
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [2.0, 2.0, 0.0, 0.0]
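These masked-ShapeTracker tests lean on `pad` producing a view whose out-of-bounds region reads as zeros without writing any data. A small sketch (the mask may be None after view simplification, which is presumably why the asserts above stay commented out):

```py
from tinygrad import Tensor

a = Tensor([1, 1]).pad(((0, 2),))   # pad right by 2: a masked view, no zeros are stored
print(a.uop.st.views[0].mask)       # the mask records which region is backed by real data
assert a.tolist() == [1, 1, 0, 0]   # reads outside the mask come back as 0
```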


@@ -29,7 +29,7 @@ N = 128
def _test_allreduce(t:Tensor):
aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
ts = t.shard(devices_4, 0).realize()
b = Tensor(UOp.allreduce(ts.lazydata, Ops.ADD, ts.device))
b = Tensor(UOp.allreduce(ts.uop, Ops.ADD, ts.device))
b.realize()
return aa, b
@@ -50,7 +50,7 @@ class TestMultiTensor(unittest.TestCase):
def test_shard(self):
X = Tensor.ones(256).contiguous().realize()
X.shard_(devices_2, 0)
for lb in X.lazydata.src:
for lb in X.uop.src:
assert lb.shape == (128,)
(X + X).realize()
@@ -61,12 +61,12 @@ class TestMultiTensor(unittest.TestCase):
def test_tensor_from_multi(self):
X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
Y = Tensor(X.lazydata)
Y = Tensor(X.uop)
self.assertEqual(Y.device, Device.DEFAULT)
np.testing.assert_equal(X.numpy(), Y.numpy())
with self.assertRaises(AssertionError):
_ = Tensor(X.lazydata, dtype=dtypes.float)
_ = Tensor(X.uop, dtype=dtypes.float)
def test_sharded_arange(self):
sharded_arange = Tensor.arange(1000).shard(devices_2, 0)
@@ -247,9 +247,9 @@ class TestMultiTensor(unittest.TestCase):
shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
with Context(RING=0):
a = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device))
a = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
with Context(RING=2):
b = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device))
b = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
diff = a - b
mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
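A condensed sketch of the check above: the RING context variable selects the allreduce strategy, and both strategies must agree up to float error. The `devices` tuple and import paths here are assumptions modeled on the test fixtures:

```py
from tinygrad import Tensor, Device
from tinygrad.ops import UOp, Ops       # assumed import paths
from tinygrad.helpers import Context

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))  # assumed 4-way fixture
t = Tensor.rand(8, 8).shard(devices, 0)
with Context(RING=0):   # force the naive allreduce
    a = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
with Context(RING=2):   # force the ring allreduce
    b = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
assert (a - b).abs().max().numpy() < 1e-4
```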
@@ -590,21 +590,21 @@ class TestMultiTensor(unittest.TestCase):
t4 = t2.reshape((26, 105,))
for t in [t0, t1, t2, t3, t4]:
assert t.lazydata.axis == 1
assert t.uop.axis == 1
np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
# test shape-one axis
t5 = t4.reshape((26, 1, 105))
assert t5.lazydata.axis == 2
assert t5.uop.axis == 2
np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
# test split and rejoin to the right and reshape to the left
t5 = t0.reshape((2, 13, 3, 5, 7))
t6 = t0.reshape((13, 2, 3, 7, 5))
t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
assert t5.lazydata.axis == 2
assert t6.lazydata.axis == 2
assert t7.lazydata.axis == 3
assert t5.uop.axis == 2
assert t6.uop.axis == 2
assert t7.uop.axis == 3
np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
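These reshape tests show that `.uop.axis` (formerly `.lazydata.axis`) tracks the sharded dimension through reshapes. A minimal sketch, assuming two instances of the default device as in the test fixtures:

```py
from tinygrad import Tensor, Device

devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # assumed fixture
t = Tensor.rand(2, 8).shard(devices_2, axis=1)
assert t.uop.axis == 1
t2 = t.reshape(2, 1, 8)   # inserting a unit dim before the sharded axis
assert t2.uop.axis == 2   # .axis shifts to follow the sharded dimension
```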
@@ -616,7 +616,7 @@ class TestMultiTensor(unittest.TestCase):
@unittest.skip("no longer supports uneven shard")
def test_reshape_on_axis_uneven(self):
def reshape_helper(t0, t, t_axis):
assert t.lazydata.axis == t_axis
assert t.uop.axis == t_axis
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])
@@ -687,24 +687,24 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
self.assertEqual(t.uop.axis, t2.uop.axis)
def test_rand_like_from_alu(self):
a = Tensor.ones(4, 4).shard(devices_4, axis=0)
aa = a + a
self.assertEqual(aa.device, devices_4)
self.assertEqual(aa.lazydata.axis, 0)
self.assertEqual(aa.uop.axis, 0)
raa = aa.rand_like()
self.assertEqual(raa.device, devices_4)
self.assertEqual(raa.lazydata.axis, 0)
self.assertEqual(raa.uop.axis, 0)
b = Tensor.empty(4, 4).shard(devices_4, axis=None)
ab = a + b
self.assertEqual(ab.device, devices_4)
self.assertEqual(ab.lazydata.axis, 0)
self.assertEqual(ab.uop.axis, 0)
rab = ab.rand_like()
self.assertEqual(rab.device, devices_4)
self.assertEqual(rab.lazydata.axis, 0)
self.assertEqual(rab.uop.axis, 0)
@unittest.skip("no longer supports uneven shard")
def test_rand_like_uneven_shard(self):
@@ -713,8 +713,8 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src))
self.assertEqual(t.uop.axis, t2.uop.axis)
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.uop.src, t2.uop.src))
def test_rand_like_none_shard(self):
t = Tensor.empty((16, 16)).shard(devices_2)
@@ -722,7 +722,7 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
self.assertEqual(t.uop.axis, t2.uop.axis)
def test_rand_like_arg_dtype(self):
t = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1)
@@ -771,7 +771,7 @@ class TestMultiTensor(unittest.TestCase):
devices = (d0, d1, d2, d3)
t = Tensor.zeros(16, 16).contiguous()
t.shard_(devices, axis=0).realize()
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src])
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src])
@unittest.skip("this is unreliable on OSX")
def test_clone(self):
@@ -928,7 +928,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
to_add.append((Tensor.ones(2, 8) * i).shard(devices))
added:list[Tensor] = []
for bound, a in zip(x.lazydata.bounds, to_add):
for bound, a in zip(x.uop.bounds, to_add):
added.append(x[bound[0]:bound[1]] + a)
output = added[0].cat(*added[1:])
@@ -1043,7 +1043,7 @@ class TestBatchNorm(unittest.TestCase):
bns.append(bn)
bn_ts = []
for bound, bn in zip(x.lazydata.bounds, bns):
for bound, bn in zip(x.uop.bounds, bns):
bni = bn(x[bound[0]:bound[1]])
bn_ts.append(bni)


@@ -596,9 +596,9 @@ class TestNN(unittest.TestCase):
# sharded model shards the state_dict
self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, 3)
self.assertEqual(layer.weight.uop.axis, 3)
self.assertEqual(layer.bias.device, devices)
self.assertEqual(layer.bias.lazydata.axis, None)
self.assertEqual(layer.bias.uop.axis, None)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
@@ -634,9 +634,9 @@ class TestNN(unittest.TestCase):
load_state_dict(layer, state_dict)
self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, 3)
self.assertEqual(layer.weight.uop.axis, 3)
self.assertEqual(layer.bias.device, devices)
self.assertEqual(layer.bias.lazydata.axis, None)
self.assertEqual(layer.bias.uop.axis, None)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
@@ -658,9 +658,9 @@ class TestNN(unittest.TestCase):
# NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this?
self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, None)
self.assertEqual(layer.weight.uop.axis, None)
self.assertEqual(layer.bias.device, devices5)
self.assertEqual(layer.bias.lazydata.axis, 0)
self.assertEqual(layer.bias.uop.axis, 0)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())


@@ -53,7 +53,7 @@ class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor_alt2(self):
print("** init")
t = Tensor.rand(10, 10).to("CPU").realize()
tensor_uop = t.lazydata
tensor_uop = t.uop
assert tensor_uop.is_realized, f"expected {tensor_uop} to be realized"
t_values = t.numpy()
# pickle
@@ -63,13 +63,13 @@ class TestPickle(unittest.TestCase):
del tensor_uop
print("** post pickle")
t2:Tensor = pickle.loads(st)
assert t2.lazydata.is_realized, f"expected {t2.lazydata} to be realized"
assert t2.uop.is_realized, f"expected {t2.uop} to be realized"
np.testing.assert_equal(t_values, t2.numpy())
# NOTE: currently Buffer exists on the uop, not tensor
def test_pickle_buffer_uop(self):
t = Tensor.arange(4).realize()
a = t.lazydata
a = t.uop
assert a.op is Ops.BUFFER
self.assertIsNotNone(buffer:=a.realized)
s = pickle.dumps(a)
@@ -98,12 +98,12 @@ class TestPickle(unittest.TestCase):
def test_pickle_buffer_view(self):
t = Tensor.arange(10, device="CPU").contiguous().realize()
vt = t[3:5].contiguous().realize()
assert hasattr(vt.lazydata.buffer, 'base')
assert hasattr(vt.uop.buffer, 'base')
ref_value = vt.tolist()
st = pickle.dumps(vt)
del t, vt
vt2 = pickle.loads(st)
assert hasattr(vt2.lazydata.buffer, 'base')
assert hasattr(vt2.uop.buffer, 'base')
assert ref_value == vt2.tolist()
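The pickle tests above hinge on the Buffer living on the uop rather than on the Tensor, so a realized tensor round-trips with its device memory intact. A minimal sketch:

```py
import pickle
from tinygrad import Tensor

t = Tensor.arange(4).realize()
assert t.uop.is_realized              # the uop is a BUFFER with real device memory
t2 = pickle.loads(pickle.dumps(t))
assert t2.uop.is_realized             # the buffer survives the round-trip
assert t2.tolist() == [0, 1, 2, 3]
```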
def test_pickle_numpy(self):


@@ -36,12 +36,12 @@ class TestProfiler(unittest.TestCase):
si = self.b.schedule()[-1]
TestProfiler.runner = get_runner(TestProfiler.d0.device, si.ast)
TestProfiler.b.lazydata.buffer.allocate()
TestProfiler.b.uop.buffer.allocate()
def test_profile_kernel_run(self):
runner_name = TestProfiler.runner._prg.name
with helper_collect_profile(TestProfiler.d0) as profile:
TestProfiler.runner([TestProfiler.b.lazydata.buffer, TestProfiler.a.lazydata.buffer], var_vals={})
TestProfiler.runner([TestProfiler.b.uop.buffer, TestProfiler.a.uop.buffer], var_vals={})
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
@@ -66,7 +66,7 @@ class TestProfiler(unittest.TestCase):
with helper_collect_profile(TestProfiler.d0) as profile:
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
TestProfiler.runner([buf1, TestProfiler.a.lazydata.buffer], var_vals={})
TestProfiler.runner([buf1, TestProfiler.a.uop.buffer], var_vals={})
buf1.copyout(memoryview(bytearray(buf1.nbytes)))
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)


@@ -23,7 +23,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
uops = dedup(flatten(_recursive_add(st) for st in stores))
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [cast(UOp,x.lazydata).base.buffer for x in inputs]
inbufs = [cast(UOp,x.uop).base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render(uops)
ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
ei.exec(outbufs+inbufs)


@@ -25,7 +25,7 @@ class TestRewriteTrackedChildren(unittest.TestCase):
a = Tensor(2)
b = Tensor(3)
c = a + b
sink = c.lazydata.sink()
sink = c.uop.sink()
sink = graph_rewrite(sink, rewrite, track_children=True)
def test_simple_child(self):
@@ -35,8 +35,8 @@ class TestRewriteTrackedChildren(unittest.TestCase):
a = Tensor(2)
b = Tensor(3)
c = a + b
sink = c.lazydata
view_w_child = a.lazydata.src[0]
sink = c.uop
view_w_child = a.uop.src[0]
print([x().arg for x in view_w_child.children])
print([x.arg for x in sink.get_children_map()[view_w_child]])
self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((2,3)))
@@ -57,7 +57,7 @@ class TestRewriteTrackedChildren(unittest.TestCase):
extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
a = Tensor.empty(3, 3)
r = (a+0).sum()
graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True)
graph_rewrite(r.uop, merge_views+sym+extra, track_children=True)
if __name__ == '__main__':
unittest.main()


@@ -188,7 +188,7 @@ class TestSchedule(unittest.TestCase):
# pre-realize shared seed
Tensor._device_rng_counters[x.device].realize()
# run custom kernelized kernel
sched_sink = graph_rewrite(x.lazydata, create_kernels, ctx={u:None for u in x.lazydata.toposort() if u.op is Ops.COPY}, bottom_up=True)
sched_sink = graph_rewrite(x.uop, create_kernels, ctx={u:None for u in x.uop.toposort() if u.op is Ops.COPY}, bottom_up=True)
y = Tensor(graph_rewrite(sched_sink, create_ast, bottom_up=True))
run_schedule(check_schedule(y, 1))
# compare against reference
@@ -198,15 +198,15 @@ class TestSchedule(unittest.TestCase):
def test_empty_is_not_realized(self):
a = Tensor.empty(10)
child = a+2
assert not a.lazydata.is_realized
assert not a.uop.is_realized
child.realize()
assert a.lazydata.is_realized
assert a.uop.is_realized
# NOTE: because empty does not have an ExecItem, if realize is called on a childless empty, it never gets allocated.
def test_childless_empty_never_allocates(self):
a = Tensor.empty(10)
a.realize()
assert not a.lazydata.is_realized
assert not a.uop.is_realized
def test_simplify_padded_const(self):
a = Tensor.empty(1022).cummax(axis=0)
@@ -404,7 +404,7 @@ class TestSchedule(unittest.TestCase):
sched = check_schedule([a, b], 1)
run_schedule(sched)
# a and b share the same underlying device memory
self.assertIs(a.lazydata.realized, b.lazydata.realized)
self.assertIs(a.uop.realized, b.uop.realized)
def test_clone_doesnt_dedup(self):
src = Tensor.ones(4).contiguous().realize()
@@ -413,7 +413,7 @@ class TestSchedule(unittest.TestCase):
sched = check_schedule([a, b], 2, filter_sink=False)
run_schedule(sched)
# a and b are assigned to the same device Buffer
self.assertIsNot(a.lazydata.realized, b.lazydata.realized)
self.assertIsNot(a.uop.realized, b.uop.realized)
# EMPTY is assigned to a unique device Buffer
@@ -422,7 +422,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty((4,))
# NOTE: empty does not have any schedule
check_schedule([a, b], 0, filter_sink=False)
self.assertIsNot(a.lazydata.buffer, b.lazydata.buffer)
self.assertIsNot(a.uop.buffer, b.uop.buffer)
def test_dedup_outputs(self):
a = Tensor.full((4, 4), 1.).contiguous().realize()
@@ -712,13 +712,13 @@ class TestSchedule(unittest.TestCase):
prev_a = (a+1).contiguous()
a.assign(Tensor([2]))
a.kernelize(prev_a)
assert prev_a.lazydata in a.lazydata.src, "contiguous usage must run before assign"
assert prev_a.uop in a.uop.src, "contiguous usage must run before assign"
self.assertEqual((prev_a+a*3).item(), 1+2*3)
def test_multioutput_ast(self):
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata
c = Tensor.arange(4).realize().lazydata
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
c = Tensor.arange(4).realize().uop
kernel = UOp(Ops.KERNEL, src=(a, b, c.base), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
assert all(s.op is Ops.BUFFER for s in kernel.src), f"views are not allowed here {kernel}"
kernel = graph_rewrite(kernel, create_ast)
@@ -1623,7 +1623,7 @@ class TestSchedule(unittest.TestCase):
@unittest.skip("disabling subbuffer manually isn't supported anymore")
def test_bitcast_disable_subbufer(self):
x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().uop)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
b = x.cast(dtypes.int32, True, allow_buffer_view=False)
b = a.alu(Ops.ADD, b)
@@ -1660,11 +1660,11 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(GlobalCounters.mem_used-base, 1024)
def test_const_schedule(self):
constv = Tensor.empty(2, 2).lazydata.const_like(10)
constv = Tensor.empty(2, 2).uop.const_like(10)
check_schedule(constv, 0)
def test_const_schedule_contig(self):
constv = Tensor.empty(2, 2).lazydata.const_like(10).contiguous()
constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
check_schedule(constv, 1)
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
@@ -1676,8 +1676,8 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
self.assertIsInstance(out.dtype, ImageDType)
self.assertIsNotNone(out.lazydata.base.realized)
self.assertIsInstance(out.lazydata.base.realized.dtype, ImageDType)
self.assertIsNotNone(out.uop.base.realized)
self.assertIsInstance(out.uop.base.realized.dtype, ImageDType)
def _test_fusion(self, shapes, f, cnt):
with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes]
@@ -1707,9 +1707,9 @@ class TestSchedule(unittest.TestCase):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
casted_view.realize()
self.assertEqual(casted_view.lazydata.base.realized.size, 4)
self.assertEqual(casted_view.uop.base.realized.size, 4)
realized_view = casted_view.contiguous().realize()
self.assertEqual(realized_view.lazydata.base.realized.size, 8)
self.assertEqual(realized_view.uop.base.realized.size, 8)
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
# NOTE: we only reorder CAST if it's an EXPAND
@@ -1717,16 +1717,16 @@ class TestSchedule(unittest.TestCase):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
casted_view.realize()
self.assertEqual(casted_view.lazydata.base.realized.size, 2)
self.assertEqual(casted_view.uop.base.realized.size, 2)
realized_view = casted_view.contiguous().realize()
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
self.assertEqual(realized_view.uop.base.realized.size, 2)
self.assertListEqual(realized_view.tolist(), [[0, 1]])
def test_cast_const_view(self):
a = Tensor.ones((4, 4), dtype=dtypes.float32)
casted_view = a.cast(dtypes.int32)
run_schedule(check_schedule(casted_view, 0))
self.assertIsNone(casted_view.lazydata.base.realized)
self.assertIsNone(casted_view.uop.base.realized)
realized_const_view = casted_view.contiguous()
run_schedule(check_schedule(realized_const_view, 1))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
@@ -1998,7 +1998,7 @@ class TestIndexing(unittest.TestCase):
def test_recursive_swizzle(self):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1)))
new_uop = swizzle_rewrite(a.uop.reshape((4, 1)))
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)
@@ -2012,13 +2012,13 @@ class TestIndexing(unittest.TestCase):
a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (1, 32))
self.assertEqual(a.uop.shape, (1, 32))
def test_no_reshape_reduceop(self):
a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (32,))
self.assertEqual(a.uop.shape, (32,))
@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
@@ -2161,7 +2161,7 @@ class TestView(unittest.TestCase):
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(store_val(sched[-1]).op, Ops.LOAD)
self.assertEqual(store_val(sched[-1]).st_arg, b.lazydata.st)
self.assertEqual(store_val(sched[-1]).st_arg, b.uop.st)
run_schedule(sched)
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
@@ -2175,9 +2175,9 @@ class TestView(unittest.TestCase):
late_mul = a*bv
check_schedule(late_mul, 0)
# the arange doesn't realize
self.assertIsNone(b.lazydata.base.realized)
self.assertIsNone(b.uop.base.realized)
# mul doesn't realize
self.assertIsNone(late_mul.lazydata.base.realized)
self.assertIsNone(late_mul.uop.base.realized)
self.assertEqual(late_mul.tolist(), [0, 0])
# SINK has two branches:
@@ -2192,13 +2192,13 @@ class TestView(unittest.TestCase):
other_child = b+2
s = check_schedule([late_mul, other_child], 2)
# the arange becomes a BUFFER
self.assertIs(b.lazydata.base.op, Ops.BUFFER)
self.assertIs(b.uop.base.op, Ops.BUFFER)
# mul still collapses
self.assertIs(late_mul.lazydata.base.op, Ops.CONST)
self.assertIs(late_mul.uop.base.op, Ops.CONST)
run_schedule(s)
self.assertEqual(other_child.tolist(), [2, 3, 4])
def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, merge_views+symbolic_simple)
def tensor_rewrite(t) -> UOp: return graph_rewrite(t.uop.base, merge_views+symbolic_simple)
class TestSimplifier(unittest.TestCase):
def test_sink_childless_const(self):
x = Tensor(0)
@@ -2222,8 +2222,8 @@ class TestSimplifier(unittest.TestCase):
a = Tensor.empty(4, 4, dtype=dtypes.int)
sink = tensor_rewrite(a*0)
assert UPat(Ops.CONST, arg=0).match(sink, {})
self.assertIs(tensor_rewrite(a*1).base, a.lazydata.base)
self.assertIs(tensor_rewrite(a+0).base, a.lazydata.base)
self.assertIs(tensor_rewrite(a*1).base, a.uop.base)
self.assertIs(tensor_rewrite(a+0).base, a.uop.base)
def test_cast_folding(self):
a = Tensor(1.0).cast(dtypes.int)
@@ -2257,14 +2257,14 @@ class TestConst(unittest.TestCase):
def test_tensor_const(self):
a = Tensor(1)
print(a.lazydata)
self.assertTrue(tensor_const_pm.rewrite(a.lazydata))
print(a.uop)
self.assertTrue(tensor_const_pm.rewrite(a.uop))
def test_tensor_variable(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)
print(a.lazydata)
self.assertTrue(tensor_const_pm.rewrite(a.lazydata))
print(a.uop)
self.assertTrue(tensor_const_pm.rewrite(a.uop))
def test_const_schedule(self):
a = Tensor.ones((4, 4))
@@ -2311,7 +2311,7 @@ class TestConst(unittest.TestCase):
sched = add.schedule()
self.assertEqual(len(sched), 0)
# b+0 and b share the same underlying device memory
self.assertIs(add.lazydata.buffer, b.lazydata.buffer)
self.assertIs(add.uop.buffer, b.uop.buffer)
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
def test_src_masked_const_folding(self):
@@ -2324,7 +2324,7 @@ class TestConst(unittest.TestCase):
self.assertEqual(len(sched), 1)
run_schedule(sched)
# add gets assigned to a new buffer
self.assertIsNot(add.lazydata.base.realized, b.lazydata.base.realized)
self.assertIsNot(add.uop.base.realized, b.uop.base.realized)
self.assertListEqual(add.tolist(), [4, 2, 2, 2, 2, 4])
# ** part 3: Tensor variable bindings
@@ -2359,15 +2359,15 @@ class TestCopyFolding(unittest.TestCase):
self.assertListEqual(b.tolist(), [0, 0, 0])
def test_alu_after_copy(self):
a = Tensor.ones((4,)).to("CPU").lazydata
b = Tensor.empty(4, device="CPU").lazydata
a = Tensor.ones((4,)).to("CPU").uop
b = Tensor.empty(4, device="CPU").uop
add = a+b
add = schedule_graph_rewrite(add)
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
@unittest.skip("this is just clone now")
def test_copy_to_same_device(self):
a = Tensor.empty(4).lazydata
a = Tensor.empty(4).uop
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
@@ -2377,7 +2377,7 @@ class TestCopyFolding(unittest.TestCase):
@unittest.skip("this is just clone now")
def test_copy_to_same_device_alt(self):
a = Tensor.empty(4, 4).lazydata
a = Tensor.empty(4, 4).uop
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
@@ -2394,8 +2394,8 @@ class TestCopyFolding(unittest.TestCase):
b = view.clone()
# NOTE: this was sort of a bug making this 2
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 2)
self.assertEqual(b.lazydata.size, 2)
self.assertEqual(b.uop.base.buffer.size, 2)
self.assertEqual(b.uop.size, 2)
self.assertListEqual(b.tolist(), [0, 1])
def test_expanded_copy(self):
@@ -2403,8 +2403,8 @@ class TestCopyFolding(unittest.TestCase):
view = a.reshape(2, 1).expand(2, 2)
b = view.clone()
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 4)
self.assertEqual(b.lazydata.size, 4)
self.assertEqual(b.uop.base.buffer.size, 4)
self.assertEqual(b.uop.size, 4)
self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
def test_permuted_copy(self):
@@ -2414,7 +2414,7 @@ class TestCopyFolding(unittest.TestCase):
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_on_disk(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
b = a.reshape(2, 2).permute(1, 0).to("CPU")
b.realize()
@@ -2430,7 +2430,7 @@ class TestCopyFolding(unittest.TestCase):
# TODO: this is wrong because of the permute
@unittest.expectedFailure
def test_permute_after_shrink_on_disk(self):
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
b.realize()
@@ -2442,13 +2442,13 @@ class TestTensorUOpSpec(unittest.TestCase):
unsafe_push_views = PatternMatcher([
(UPat.cvar("root").view(name="view"), lambda root,view: root.replace(src=tuple(x.view(view.st) for x in root.src))),
])
a.lazydata = graph_rewrite(a.lazydata.sink(), merge_views+merge_views+unsafe_push_views)
a.uop = graph_rewrite(a.uop.sink(), merge_views+merge_views+unsafe_push_views)
with self.assertRaisesRegex(RuntimeError, "UOp verification failed"):
a.schedule()
def test_expanded_const_ok(self):
a = Tensor.ones((4, 4))
t = graph_rewrite(a.lazydata.sink(), merge_views+merge_views)
t = graph_rewrite(a.uop.sink(), merge_views+merge_views)
create_schedule_with_vars(t)
# NOTE: changing symbolic CONST VIEWs is not allowed
@@ -2456,69 +2456,69 @@ class TestTensorUOpSpec(unittest.TestCase):
def test_symbolic_shape_ok(self):
a = Tensor.ones(4)
vi = UOp.variable("i", 1, 10).bind(4)
a.lazydata = graph_rewrite(a.reshape(vi).sum().lazydata, merge_views+merge_views)
a.uop = graph_rewrite(a.reshape(vi).sum().uop, merge_views+merge_views)
a.schedule()
class TestBufferUOp(unittest.TestCase):
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
def test_buffer_has_buffer(self):
buf = Tensor.empty(10)
self.assertIsNotNone(buf.lazydata.buffer)
self.assertEqual(buf.lazydata.st, ShapeTracker.from_shape((10,)))
self.assertIsNotNone(buf.uop.buffer)
self.assertEqual(buf.uop.st, ShapeTracker.from_shape((10,)))
# the device Buffer remains unallocated until we run the schedule
self.assertFalse(buf.lazydata.buffer.is_allocated())
self.assertFalse(buf.uop.buffer.is_allocated())
add = buf+1
sched = add.schedule()
self.assertFalse(buf.lazydata.buffer.is_allocated())
self.assertFalse(buf.uop.buffer.is_allocated())
run_schedule(sched)
self.assertTrue(buf.lazydata.buffer.is_allocated())
self.assertTrue(buf.uop.buffer.is_allocated())
def test_buffer_has_unique_buffer(self):
buf = Tensor.empty(10)
buf1 = buf.lazydata.buffer
buf2 = buf.lazydata.buffer
buf1 = buf.uop.buffer
buf2 = buf.uop.buffer
self.assertIs(buf1, buf2)
# we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
def test_buffer_view_allowed(self):
add = Tensor.empty(1, 1)+Tensor.empty(1, 1)
add.realize()
self.assertIsNotNone(add.lazydata.buffer)
self.assertEqual(add.lazydata.shape, (1, 1))
self.assertIsNotNone(add.uop.buffer)
self.assertEqual(add.uop.shape, (1, 1))
def test_buffer_view_not_allowed(self):
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
merged = graph_rewrite(permuted_view.lazydata, merge_views)
merged = graph_rewrite(permuted_view.uop, merge_views)
with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
merged.buffer # cannot access Buffer of a non contiguous VIEW
def test_buffer_only_after_realize(self):
a = Tensor([1])+Tensor([2])
# accessing realized will return None
self.assertIsNone(a.lazydata.realized)
self.assertIsNone(a.uop.realized)
# accessing Buffer will assert
with self.assertRaisesRegex(AssertionError, "must be BUFFER"):
a.lazydata.buffer # there is no BUFFER on an unrealized ADD
a.uop.buffer # there is no BUFFER on an unrealized ADD
# Buffer only exists once we realize it
a.realize()
self.assertIsNotNone(a.lazydata.buffer)
self.assertIsNotNone(a.uop.buffer)
def test_const_does_not_realize(self):
a = Tensor(1)+Tensor(2)
run_schedule(check_schedule(a, 0))
self.assertIsNone(a.lazydata.base.realized)
self.assertIsNone(a.uop.base.realized)
def test_var_does_not_realize(self):
a = Tensor(UOp.variable("a", 0, 10).bind(1))
run_schedule(check_schedule(a, 0))
self.assertIsNone(a.lazydata.base.realized)
self.assertIsNone(a.uop.base.realized)
def test_view_does_not_realize(self):
a = Tensor.randn(1, 4).expand(4, 4)
a.realize()
self.assertEqual(a.lazydata.base.realized.size, 4)
self.assertEqual(a.uop.base.realized.size, 4)
a2 = a.contiguous().realize()
self.assertEqual(a2.lazydata.base.realized.size, 16)
self.assertEqual(a2.uop.base.realized.size, 16)
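The size-4 vs size-16 distinction above is the same one numpy draws between a broadcast view and an actual copy; expand is a zero-stride view, and only `.contiguous()` materializes the full buffer:

```py
import numpy as np

a = np.arange(4, dtype=np.float32).reshape(1, 4)
v = np.broadcast_to(a, (4, 4))   # a view: stride 0 on axis 0, still only 4 floats stored
assert v.strides[0] == 0
c = np.ascontiguousarray(v)      # like .contiguous(): materializes all 16 values
assert c.strides == (16, 4)
```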
class TestContiguous(unittest.TestCase):
def test_contiguous_buffer(self):
@@ -2550,13 +2550,13 @@ class TestContiguous(unittest.TestCase):
a = Tensor.empty(4)
b = a.expand((4, 4))
check_schedule(b, 0)
self.assertEqual(b.lazydata.base.buffer.size, 4)
self.assertEqual(b.uop.base.buffer.size, 4)
def test_contiguous_view_realizes(self):
a = Tensor.empty(4)
b = a.expand((4, 4)).contiguous()
check_schedule(b, 1)
self.assertEqual(b.lazydata.base.buffer.size, 16)
self.assertEqual(b.uop.base.buffer.size, 16)
class TestUOpBecome(unittest.TestCase):
# the simplest case: we create a new BUFFER for this tensor UOp
@@ -2566,21 +2566,21 @@ class TestUOpBecome(unittest.TestCase):
add = a+b
check_schedule(add, 1)
# NOTE: realized base is always a flat buffer
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
assert UPat(Ops.BUFFER).match(add.uop.base, {})
# the Tensor UOp can optionally stack a VIEW on top of the BUFFER, in this case to preserve the (4, 4) shape of the tensor
assert add.lazydata is not add.lazydata.base
self.assertEqual(add.lazydata.size, 16)
self.assertEqual(add.lazydata.shape, (4, 4))
assert add.uop is not add.uop.base
self.assertEqual(add.uop.size, 16)
self.assertEqual(add.uop.shape, (4, 4))
def test_new_buffer_view(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = (a+b).reshape(8, 2)
check_schedule(add, 1)
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
assert UPat(Ops.BUFFER).match(add.uop.base, {})
# the shape is preserved in the becomes_map.
self.assertEqual(add.lazydata.shape, (8, 2))
assert add.lazydata is not add.lazydata.base
self.assertEqual(add.uop.shape, (8, 2))
assert add.uop is not add.uop.base
def test_new_flat_buffer(self):
a = Tensor.empty(4,)
@@ -2588,7 +2588,7 @@ class TestUOpBecome(unittest.TestCase):
add = a+b
check_schedule(add, 1)
# BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER
assert UPat(Ops.BUFFER).match(add.lazydata, {})
assert UPat(Ops.BUFFER).match(add.uop, {})
# sometimes we prefer to perform an op before movement ops; in this case we should stack the mops on top of the new buffer
@@ -2597,8 +2597,8 @@ class TestUOpBecome(unittest.TestCase):
a = Tensor.empty(4, 1)
b = a.expand(4, 4).reciprocal()
check_schedule(b, 1)
self.assertEqual(b.lazydata.base.buffer.size, 16)
self.assertEqual(b.lazydata.st, ShapeTracker.from_shape((4, 4)))
self.assertEqual(b.uop.base.buffer.size, 16)
self.assertEqual(b.uop.st, ShapeTracker.from_shape((4, 4)))
def test_reorder_expand_alt(self):
x = Tensor.empty(4, 1)
@@ -2610,95 +2610,95 @@ class TestUOpBecome(unittest.TestCase):
def test_become_existing_buffer(self):
a = Tensor.empty(4, 4)
b = a*1
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling merges all MovementOps into a single VIEW
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.uop, {}) # scheduling merges all MovementOps into a single VIEW
self.assertIs(a.uop.base.buffer, b.uop.base.buffer)
def test_become_buf_with_mops(self):
a = Tensor.empty(2, 4, 2)
noop = a.shrink(((1, 2), (0, 4), (0, 2))).reshape(4, 2)*1+0
# before realizing, this tensor is base
assert noop.lazydata is noop.lazydata.base
assert noop.uop is noop.uop.base
noop.realize()
# it becomes a realized view after realize
assert noop.lazydata is not noop.lazydata.base
assert noop.lazydata.base.op is Ops.BUFFER
assert noop.uop is not noop.uop.base
assert noop.uop.base.op is Ops.BUFFER
late_add = noop+2
late_add.realize()
def test_become_const_in_base(self):
a = Tensor.empty(4)
b = a*0
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER)
def test_become_const_in_view(self):
# if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged.
add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
b = add.shrink(((0, 1), (0, 0)))
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata, {})
assert UPat(Ops.CONST, arg=0).match(b.uop, {})
self.assertEqual(b.shape, (1, 0))
# the base is untouched.
assert UPat(Ops.ADD).match(add.lazydata, {})
assert UPat(Ops.ADD).match(add.uop, {})
def test_become_const_from_const(self):
const_add = Tensor(1)+Tensor(2)
assert UPat(Ops.ADD).match(const_add.lazydata, {})
assert UPat(Ops.ADD).match(const_add.uop, {})
check_schedule(const_add, 0)
assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {})
# tensors can become another realized tensor source
def test_become_existing_buf_simple(self):
a = Tensor.empty(4, 4)
b = a+0
check_schedule(b, 0)
assert b.lazydata.base.op is Ops.BUFFER
self.assertIs(a.lazydata, b.lazydata)
assert b.uop.base.op is Ops.BUFFER
self.assertIs(a.uop, b.uop)
# they can also chain other movement ops on top of the tensor source
def test_become_existing_buf_view(self):
a = Tensor.empty(4, 4)
b = a.permute((1, 0))+0
check_schedule(b, 0)
self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).st)
self.assertEqual(b.uop.st, a.uop.permute((1, 0)).st)
def test_become_existing_buf_view_alt(self):
a = Tensor.empty(4, 4)
b = a.permute((1, 0)).reshape((8, 2))+0
check_schedule(b, 0)
self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)
self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
# they can also have other base parents that got simplified; in that case we just backtrack to the chained mops
def test_become_existing_buf_complex(self):
a = Tensor.empty(4, 4)
b = (a.permute((1, 0))+0).reshape((8, 2))+0
check_schedule(b, 0)
self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)
assert b.lazydata.base.op is Ops.BUFFER
self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
assert b.uop.base.op is Ops.BUFFER
def test_become_multiple_choices(self):
a = Tensor.empty(16)
b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
c = (a.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
check_schedule([b, c], 0)
assert all_same([x.lazydata.base.realized for x in [a,b,c]])
assert all_same([x.uop.base.realized for x in [a,b,c]])
# these movement ops result in the same ShapeTracker
assert b.lazydata.st == c.lazydata.st
assert b.lazydata is c.lazydata
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})
assert b.uop.st == c.uop.st
assert b.uop is c.uop
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.uop, {})
def test_setitem_becomes_subbuffer(self):
a = Tensor.full((4,), 2.).contiguous().realize()
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
b.realize()
assert a.lazydata.is_realized
assert a.lazydata.buffer._base is None
assert a.uop.is_realized
assert a.uop.buffer._base is None
# b is a subbuffer of a
assert b.lazydata.op is Ops.BUFFER_VIEW
assert b.lazydata.src[0] is a.lazydata
assert b.uop.op is Ops.BUFFER_VIEW
assert b.uop.src[0] is a.uop
def test_setitem_offset(self):
a = Tensor.full((16,), 0.).contiguous().realize()


@@ -44,7 +44,7 @@ class TestTimeLinearizer(unittest.TestCase):
# Ensure that the kernel count is not incremented by time_linearizer when clearing l2
def test_kernel_count(self):
ast = Tensor.zeros(16).contiguous().kernelize().lazydata.src[1].arg.ast
ast = Tensor.zeros(16).contiguous().kernelize().uop.src[1].arg.ast
lin = Kernel(ast)
bufs = bufs_from_lin(lin)


@@ -50,7 +50,7 @@ class TestSetitem(unittest.TestCase):
def test_setitem_into_noncontiguous(self):
t = Tensor.ones(4)
self.assertFalse(t.lazydata.st.contiguous)
self.assertFalse(t.uop.st.contiguous)
with self.assertRaises(RuntimeError): t[1] = 5
@unittest.skip("TODO: flaky")


@@ -49,11 +49,11 @@ class TestSymbolic(unittest.TestCase):
j = Variable("j", 1, 5).bind(3)
k = Variable("k", 1, 5).bind(3)
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
st = t.uop.st
self.assert_tuple_equal(st.shape, (i+j+k, 4))
assert st.real_strides() == (4, 1)
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
st = t.uop.st
self.assert_tuple_equal(st.shape, (2*i+3, 3))
assert st.real_strides() == (3, 1)
@@ -62,7 +62,7 @@ class TestSymbolic(unittest.TestCase):
j = Variable("j", 1, 5).bind(4)
k = Variable("k", 1, 5).bind(4)
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
st = t.uop.st
self.assert_tuple_equal(st.shape, (3, i+j+k))
self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))
@@ -113,7 +113,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(3)
t = Tensor.rand(3, 4).reshape(bv, 4)
unbound_st, var_val = t.lazydata.st.unbind()
unbound_st, var_val = t.uop.st.unbind()
assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),))
assert var_val == {v: 3}
@@ -121,7 +121,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(2)
t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
unbound_st, var_val = t.lazydata.st.unbind()
unbound_st, var_val = t.uop.st.unbind()
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
assert var_val == {v: 2}
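Unbinding splits a symbolic ShapeTracker from the concrete values bound into it, which is what lets one compiled kernel serve many shapes. A sketch mirroring the tests (the `Variable` import path is an assumption):

```py
from tinygrad import Tensor
from tinygrad.ops import Variable   # assumed import path

v = Variable("v", 1, 100)
t = Tensor.rand(3, 4).reshape(v.bind(3), 4)   # symbolic first dim, bound to 3
unbound_st, var_vals = t.uop.st.unbind()
assert var_vals == {v: 3}   # the binding is split out of the ShapeTracker
```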
@@ -180,8 +180,8 @@ class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).reshape(3, vi)
assert t.shape == (3, vi)
assert not t.lazydata.st.contiguous
assert len(t.lazydata.st.views) == 1
assert not t.uop.st.contiguous
assert len(t.uop.st.views) == 1
def test_reshape_not_allowed(self):
vi = Variable("i", 1, 5).bind(4)
@@ -195,12 +195,12 @@ class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
def test_reshape_from_padded(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
st = t.lazydata.st
st = t.uop.st
assert len(st.views) == 1
view = st.views[0]
assert view.shape == (4, 3, 2)
t = t.reshape(vi, 3, 2)
st2 = t.lazydata.st
st2 = t.uop.st
assert len(st2.views) == 1
view2 = st2.views[0]
# check only shape changed. strides, offset, mask, contiguous remained the same
@@ -237,7 +237,7 @@ class TestSymbolicPad(unittest.TestCase):
v = Variable("v", 1, 100).bind(5)
t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9)
assert t.shape == (9,)
st = t.lazydata.st
st = t.uop.st
print(st)
if __name__ == '__main__':


@@ -565,17 +565,17 @@ class TestZeroShapeTensor(unittest.TestCase):
t = Tensor.empty(3, 2, 0)
assert t.shape == (3, 2, 0)
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1
assert t.lazydata.st.real_strides() == (0, 0, 0)
assert t.uop.st.real_strides() == (0, 0, 0)
t = Tensor.empty(3, 0, 2)
assert t.shape == (3, 0, 2)
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1
assert t.lazydata.st.real_strides() == (0, 0, 0)
assert t.uop.st.real_strides() == (0, 0, 0)
t = Tensor.empty(0, 0, 0)
assert t.shape == (0, 0, 0)
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1
assert t.lazydata.st.real_strides() == (0, 0, 0)
assert t.uop.st.real_strides() == (0, 0, 0)
def test_rand(self):
t = Tensor.rand(3, 2, 0)
@@ -690,24 +690,24 @@ class TestZeroShapeTensor(unittest.TestCase):
a = Tensor.rand(16, 16).realize()
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
a = Tensor.rand(16, 16).mul(5.0).add(5.0).realize()
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
def test_clone_with_shrink(self):
a = Tensor.rand(16, 16)
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
def test_clone_with_shrink_realized(self):
a = Tensor.rand(16, 16).realize()
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)
def test_clone_with_grad(self):
a = Tensor.rand(16, 16, requires_grad=True)
@@ -780,7 +780,7 @@ class TestTensorMetadata(unittest.TestCase):
@unittest.skip("why would this be true?")
def test_exclude_noop_metadata(self):
a = Tensor.rand(4, 4)*1
self.assertEqual(a.lazydata.metadata[0].name, "__mul__")
self.assertEqual(a.uop.metadata[0].name, "__mul__")
k = a.schedule()[-1]
self.assertEqual([m.name for m in k.metadata], ["rand"])
@@ -797,7 +797,7 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
self.assertEqual(out.lazydata.metadata[0].name, "matmul")
self.assertEqual(out.uop.metadata[0].name, "matmul")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "matmul")
@@ -805,7 +805,7 @@ class TestTensorMetadata(unittest.TestCase):
def test_relu(self):
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
self.assertEqual(out.lazydata.metadata[0].name, "relu")
self.assertEqual(out.uop.metadata[0].name, "relu")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "relu")
@@ -814,9 +814,9 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.lazydata.metadata[0].name, "__mul__")
self.assertEqual(out.lazydata.src[0].metadata[0].name, "relu")
self.assertEqual(out.lazydata.src[1].metadata[0].name, "sigmoid")
self.assertEqual(out.uop.metadata[0].name, "__mul__")
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@@ -825,12 +825,12 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.lazydata.metadata[0].name, "sum")
self.assertEqual(out.uop.metadata[0].name, "sum")
out.backward()
self.assertEqual(x.grad.lazydata.metadata[0].name, "relu")
self.assertTrue(x.grad.lazydata.metadata[0].backward)
self.assertEqual(y.grad.lazydata.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.lazydata.metadata[0].backward)
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
self.assertTrue(x.grad.uop.metadata[0].backward)
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.uop.metadata[0].backward)
si = Tensor.schedule(out, x.grad, y.grad)[-1]
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
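
These metadata assertions read through the renamed attribute; op names are recorded on the UOp and flow through to the scheduled kernel. A sketch mirroring `test_relu` above:

```py
from tinygrad import Tensor

x = Tensor.rand(3, requires_grad=True)
out = x.relu()
print(out.uop.metadata[0].name)        # "relu", recorded on the UOp itself
si = out.schedule()[-1]
print([m.name for m in si.metadata])   # scheduling carries it onto the kernel
```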

View File

@@ -9,7 +9,7 @@ class TestTensorUOp(unittest.TestCase):
def test_fromcpu_shape_tracker(self):
def helper(a: np.ndarray):
print(a.shape, a.strides, a.flags.c_contiguous)
b = Tensor(a).lazydata
b = Tensor(a).uop
#assert b.st.contiguous == a.flags.c_contiguous
assert b.st.shape == a.shape
np.testing.assert_equal(a, Tensor(b).numpy())
@@ -60,11 +60,11 @@ class TestTensorUOp(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), np.concatenate((a.numpy(), b.numpy()), axis=1))
def test_const_dtype(self):
lb: UOp = Tensor([1], dtype=dtypes.int).lazydata
lb: UOp = Tensor([1], dtype=dtypes.int).uop
assert lb.const_like(1).base.arg == 1
assert type(lb.const_like(1).base.arg) is int
lb: UOp = Tensor([1], dtype=dtypes.float).lazydata
lb: UOp = Tensor([1], dtype=dtypes.float).uop
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float
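
`const_like` derives the Python type of the constant's arg from the Tensor's dtype. A sketch of the renamed access:

```py
from tinygrad import Tensor, dtypes

lb = Tensor([1], dtype=dtypes.int).uop   # was .lazydata
print(type(lb.const_like(1).base.arg))   # int; with dtypes.float it would be float
```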

View File

@@ -476,7 +476,7 @@ class TestUOpStr(unittest.TestCase):
assert str(eval(str(device))) == str(device)
def test_reduceop_arg(self):
sum_uop = Tensor.empty(32, 32).sum().lazydata
sum_uop = Tensor.empty(32, 32).sum().uop
assert str(eval(str(sum_uop))) == str(sum_uop)
@unittest.skip("uop no longer has order like this")
@@ -549,9 +549,9 @@ class TestShapeSpec(unittest.TestCase):
# ** CONST is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
# ** CONST is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
def test_expanded_const(self):
a = Tensor(1).lazydata
a = Tensor(1).uop
self.assertEqual(a.st, ShapeTracker.from_shape(()))
a = Tensor.ones((4, 4)).lazydata
a = Tensor.ones((4, 4)).uop
self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4)))
def test_padded_const(self):
@@ -569,12 +569,12 @@ class TestShapeSpec(unittest.TestCase):
# NOTE: CONST ShapeTracker comes from its source
def test_scalar_const(self):
a = Tensor(0).lazydata
a = Tensor(0).uop
self.assertEqual(a.st, ShapeTracker.from_shape(()))
def test_scalar_var(self):
vv = UOp.variable("a", 1, 4).bind(2)
t = Tensor(vv).lazydata
t = Tensor(vv).uop
self.assertEqual(t.st, ShapeTracker.from_shape(()))
# ** ASSIGN is ASSIGN(VIEW(BUFFER), new_val)
@@ -583,7 +583,7 @@ class TestShapeSpec(unittest.TestCase):
buffer = Tensor.arange(4).realize()
a = buffer.assign(Tensor.zeros((4,), dtype=dtypes.int))
assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat()))
assert assign_pattern.match(a.lazydata, {})
assert assign_pattern.match(a.uop, {})
a.realize()
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
@@ -597,7 +597,7 @@ class TestShapeSpec(unittest.TestCase):
buffer = Tensor.ones((4,)).contiguous().realize()
a = buffer.reshape((2, 2)).assign(Tensor.zeros((2, 2)))
assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))), UPat()))
assert assign_pattern.match(a.lazydata, {})
assert assign_pattern.match(a.uop, {})
a.realize()
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
@@ -606,13 +606,13 @@ class TestShapeSpec(unittest.TestCase):
a = Tensor.ones((4,)).contiguous().realize()
assign = a.shrink(((1, 2),)).assign(Tensor.zeros((1,)))
# the ASSIGN UOp has size=1
self.assertEqual(assign.lazydata.size, 1)
self.assertEqual(assign.uop.size, 1)
# the ASSIGN views the buffer with a shrunk st
self.assertEqual(assign.lazydata.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
self.assertEqual(assign.uop.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
# the underlying BUFFER has a size=4
self.assertEqual(assign.lazydata.buf_uop.size, 4)
self.assertEqual(assign.uop.buf_uop.size, 4)
# NOTE: output shape is different from the BUFFER shape
self.assertNotEqual(assign.lazydata.shape, a.lazydata.shape)
self.assertNotEqual(assign.uop.shape, a.uop.shape)
assign.realize()
self.assertEqual(a.tolist(), [1, 0, 1, 1])
@@ -622,13 +622,13 @@ class TestShapeSpec(unittest.TestCase):
def test_ops_st(self):
# view / mop
a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).lazydata
a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).uop
self.assertEqual(a.st, ShapeTracker.from_shape((4, 2, 1)).permute((1, 2, 0)))
# alu / reduce
alu = a*2
self.assertEqual(alu.st, ShapeTracker.from_shape((2, 1, 4)))
r = Tensor.empty(4, 4).sum(axis=1)
self.assertEqual(r.lazydata.st, ShapeTracker.from_shape((4,)))
self.assertEqual(r.uop.st, ShapeTracker.from_shape((4,)))
def test_st_wmma_none(self):
A = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('a', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 1)))

View File

@@ -6,7 +6,7 @@ def time_tensor_numpy(out:Tensor):
times = []
for _ in range(5):
st = time.perf_counter()
out.lazydata.base.realized.as_buffer(allow_zero_copy=True)
out.uop.base.realized.as_buffer(allow_zero_copy=True)
et = time.perf_counter() - st
times.append(et)
return min(times)

View File

@@ -113,16 +113,16 @@ class TestTensorGradient(unittest.TestCase):
class TestRealizeMeansRealize(unittest.TestCase):
def test_randn_realizes(self):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
assert x.lazydata is not x.lazydata.base
assert x.lazydata.is_realized
assert x.uop is not x.uop.base
assert x.uop.is_realized
#@unittest.expectedFailure
# update: passing after delete_forced_realize
def test_uniform_realizes(self):
x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
print(x.lazydata)
assert x.lazydata is not x.lazydata.base
assert x.lazydata.is_realized
print(x.uop)
assert x.uop is not x.uop.base
assert x.uop.is_realized
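
After `realize()`, the Tensor's UOp is a movement-op view whose base is the realized BUFFER, which is exactly what these asserts check. A sketch:

```py
from tinygrad import Tensor

x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
assert x.uop is not x.uop.base   # the UOp is a view on top of the base
assert x.uop.is_realized         # and the underlying BUFFER now exists
```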
# NOTE: even though it doesn't realize, this seems fine
def test_uniform_gradient(self):

View File

@@ -836,29 +836,29 @@ class TestConsecutive(unittest.TestCase):
self.ones = Tensor.ones(2, 4)
def test_unmodified(self):
assert self.t.lazydata.st.consecutive
assert self.t.reshape(4, 2).lazydata.st.consecutive
assert self.t.reshape(1, 8).lazydata.st.consecutive
assert self.t.uop.st.consecutive
assert self.t.reshape(4, 2).uop.st.consecutive
assert self.t.reshape(1, 8).uop.st.consecutive
def test_sliced(self):
assert self.t[0].lazydata.st.consecutive
assert self.t[0, 1:2].lazydata.st.consecutive
assert self.t[1].lazydata.st.consecutive
assert not self.t[:, 0].lazydata.st.consecutive
assert not self.t[:, 1].lazydata.st.consecutive
assert self.t[0].uop.st.consecutive
assert self.t[0, 1:2].uop.st.consecutive
assert self.t[1].uop.st.consecutive
assert not self.t[:, 0].uop.st.consecutive
assert not self.t[:, 1].uop.st.consecutive
def test_padded(self):
assert not self.t.pad(((1, 1), None)).lazydata.st.consecutive
assert not self.t.pad((None, (1, 1))).lazydata.st.consecutive
assert not self.t.pad(((1, 1), None)).uop.st.consecutive
assert not self.t.pad((None, (1, 1))).uop.st.consecutive
def test_const(self):
assert self.const.lazydata.st.consecutive
assert self.const.uop.st.consecutive
def test_ones(self):
assert not self.ones.lazydata.st.consecutive
assert not self.ones[0, :].lazydata.st.consecutive
assert not self.ones.uop.st.consecutive
assert not self.ones[0, :].uop.st.consecutive
# consecutive if sliced into size 1
assert self.ones[0, 0].lazydata.st.consecutive
assert self.ones[0, 0].uop.st.consecutive
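
`consecutive` means the view reads its buffer elements in order with no repeats, so a stride-0 expanded constant fails the check while a size-1 slice passes. A sketch mirroring `test_ones`:

```py
from tinygrad import Tensor

ones = Tensor.ones(2, 4)               # an expanded CONST: stride-0, not consecutive
assert not ones.uop.st.consecutive
assert ones[0, 0].uop.st.consecutive   # a size-1 slice trivially reads in order
```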
class TestRender(unittest.TestCase):
def test_render(self):

View File

@@ -8,21 +8,21 @@ realized_pattern = UPat(Ops.BUFFER)
buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat)
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)
class TestTensorMutates(unittest.TestCase):
def test_mutate_add(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
ret = a+b
pa = a.lazydata
pb = b.lazydata
pr = ret.lazydata
pa = a.uop
pb = b.uop
pr = ret.uop
ret.schedule()
self.assertIsNot(pa, a.lazydata)
self.assertIsNot(pb, b.lazydata)
self.assertIsNot(pr, ret.lazydata)
for t in [a,b,ret]: is_pattern_uop(t.lazydata.base, realized_pattern)
self.assertIsNot(pa, a.uop)
self.assertIsNot(pb, b.uop)
self.assertIsNot(pr, ret.uop)
for t in [a,b,ret]: is_pattern_uop(t.uop.base, realized_pattern)
def test_reshape_is_same_parent(self):
a = Tensor([1,2,3])
@@ -30,11 +30,11 @@ class TestTensorMutates(unittest.TestCase):
c = a+b
d = (a+b).reshape(3,1)
d.realize()
is_pattern_uop(d.lazydata.base, realized_pattern)
is_pattern_uop(c.lazydata.base, realized_pattern)
is_pattern_uop(d.uop.base, realized_pattern)
is_pattern_uop(c.uop.base, realized_pattern)
# NOTE: we keep movement ops on top of the buffer view
is_pattern_uop(c.lazydata, UPat(Ops.BUFFER))
is_pattern_uop(d.lazydata, UPat(Ops.VIEW, src=(realized_pattern,)))
is_pattern_uop(c.uop, UPat(Ops.BUFFER))
is_pattern_uop(d.uop, UPat(Ops.VIEW, src=(realized_pattern,)))
def test_reshape_is_same_child(self):
a = Tensor([1,2,3])
@@ -42,41 +42,41 @@ class TestTensorMutates(unittest.TestCase):
c = a+b
d = (a+b).reshape(3,1)
c.realize()
is_pattern_uop(c.lazydata.base, realized_pattern)
is_pattern_uop(d.lazydata.base, realized_pattern)
is_pattern_uop(c.uop.base, realized_pattern)
is_pattern_uop(d.uop.base, realized_pattern)
class TestTensorUopRepresentation(unittest.TestCase):
def test_realized(self):
a = Tensor([1.,2,3]).realize()
print(a.lazydata)
is_pattern_uop(a.lazydata.base, realized_pattern)
print(a.uop)
is_pattern_uop(a.uop.base, realized_pattern)
def test_add_realized(self):
a = Tensor([1.,2,3]).realize()
b = Tensor([4.,5,6]).realize()
c = a+b
print(c.lazydata)
print(c.uop)
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
def test_const_pattern(self):
a = Tensor(1)
print(a.lazydata)
print(a.uop)
is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src
is_pattern(a, UPat.cvar("x")) # even cvar works!
def test_consts_do_not_realize(self):
a = Tensor(1)
print(a.lazydata)
pre_realize = a.lazydata
print(a.uop)
pre_realize = a.uop
a.realize()
assert a.lazydata is pre_realize
assert a.uop is pre_realize
def test_viewed_consts_do_not_realize(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
print(a.uop)
a.realize()
is_pattern(a, const_pattern)
self.assertEqual(a.lazydata.shape, (10, 10))
self.assertEqual(a.uop.shape, (10, 10))
# currently, CONSTs have a "fake" BUFFER. this should be fixed
# current:
@@ -93,8 +93,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
# UOp(Ops.DEVICE, dtypes.void, arg="METAL", src=()),)),)),))
def test_consts_dont_have_buffers(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
buffers_in_parents = [x.op for x in a.lazydata.toposort() if x.op is Ops.BUFFER]
print(a.uop)
buffers_in_parents = [x.op for x in a.uop.toposort() if x.op is Ops.BUFFER]
self.assertEqual(len(buffers_in_parents), 0)
# currently, COPY has an extra BUFFER on the output
@@ -112,7 +112,7 @@ class TestTensorUopRepresentation(unittest.TestCase):
def test_copyin(self):
a = Tensor([1.,2,3]).realize()
c = a.to("TEST") # NOTE: this isn't checked
print(c.lazydata)
print(c.uop)
is_pattern(c, UPat(Ops.COPY, src=(realized_pattern, UPat(Ops.DEVICE))))
def test_empty_buf(self):
@@ -121,7 +121,7 @@ class TestTensorUopRepresentation(unittest.TestCase):
vi = UOp.variable("i", 1, 3).bind(1)
a = Tensor.empty(3, vi)
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
self.assertEqual(a.lazydata.base.buffer.size, 9)
self.assertEqual(a.uop.base.buffer.size, 9)
if __name__ == '__main__':
unittest.main()

View File

@@ -176,7 +176,7 @@ class CapturedJit(Generic[ReturnType]):
self.__post_init__() # reset the graph state
def replan_buffers_memory_layout(self):
blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
for old, new in asgn.items():
@@ -210,9 +210,9 @@ class CapturedJit(Generic[ReturnType]):
def _prepare_jit_inputs(args, kwargs):
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
# TODO: this multi unpack stuff is not well tested.
lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
for lb in lbs if lb.base.realized is not None])
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
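
`_prepare_jit_inputs` realizes any still-lazy input Tensors before capture, now checked via `x.uop.is_realized`. A sketch of the observable effect, assuming `TinyJit` is imported from the top-level package; the `double` helper is illustrative:

```py
from tinygrad import Tensor, TinyJit

@TinyJit
def double(x: Tensor) -> Tensor: return (x * 2).realize()

a = Tensor.rand(4)           # still lazy: a.uop.is_realized is False
double(a)                    # inputs are realized before JIT capture
assert a.uop.is_realized
```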

View File

@@ -155,7 +155,7 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
if isinstance(v.device, tuple):
if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
else: v.replace(state_dict[k].shard(v.device, v.uop.axis))
else: v.replace(state_dict[k].to(v.device))
if realize: v.realize()
if consume: del state_dict[k]

View File

@@ -20,7 +20,7 @@ from tinygrad.engine.grouper import get_kernelize_map
all_tensors: set[weakref.ref[Tensor]] = set()
def _find_all_tensors_for_uops(all_uops: set[UOp]) -> list[Tensor]:
return [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]
return [t for tref in all_tensors if (t:=tref()) is not None and t.uop in all_uops]
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
# get all children of keys in applied_map
@@ -36,13 +36,13 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> Non
# NOTE: this uses all_tensors, but it's fast
if len(fixed_tensors := _find_all_tensors_for_uops(all_uops)):
# potentially rewrite all the discovered Tensors
sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
sink = UOp.sink(*[t.uop for t in fixed_tensors])
new_sink = sink.substitute(applied_map, name=name)
# set the relevant lazydata to the realized UOps
# set the relevant uop to the realized UOps
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
if s is ns: continue
t.lazydata = ns
t.uop = ns
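
When a graph rewrite fires, every live Tensor whose `uop` appears in the rewritten sink gets the attribute swapped in place. A sketch of the observable effect, assuming `kernelize` triggers such a rewrite as documented:

```py
from tinygrad import Tensor

a, b = Tensor([1, 2, 3]), Tensor([4, 5, 6])
ret = a + b
before = ret.uop
ret.kernelize()               # the rewrite lands via _apply_map_to_tensors
assert ret.uop is not before  # the Tensor now points at the rewritten UOp
```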
# **** Tensor helper functions ****
@@ -119,7 +119,7 @@ class Tensor(MathTrait):
np.set_printoptions(precision=4)
```
"""
__slots__ = "lazydata", "requires_grad", "grad"
__slots__ = "uop", "requires_grad", "grad"
training: ClassVar[bool] = False
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
@@ -150,7 +150,7 @@ class Tensor(MathTrait):
if dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).lazydata
if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).uop
else: data = _frompy(data, dtype)
elif str(type(data)) == "<class 'numpy.ndarray'>":
import numpy as np
@@ -165,19 +165,19 @@ class Tensor(MathTrait):
if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
# data might be on a different device
if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
if isinstance(device, str): self.uop:UOp = data if data.device == device else data.copy_to_device(device)
# if device is a tuple, we should have/construct a MultiLazyBuffer
elif isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
elif isinstance(data.device, str): self.uop = Tensor(data).shard(device).uop
else:
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
self.lazydata = data
self.uop = data
# add to all_tensors after construction succeeds
all_tensors.add(weakref.ref(self))
def __del__(self): all_tensors.discard(weakref.ref(self))
def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
new_uop: UOp = fxn(*[t.uop for t in (self,)+x], **kwargs)
if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
needs_input_grad = [t.requires_grad for t in (self,)+x]
return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
@@ -196,9 +196,9 @@ class Tensor(MathTrait):
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
def __repr__(self):
ld = self.lazydata
ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.uop if self.grad is not None else None)!r}>"
# Python has a non moving GC, so this should be okay
def __hash__(self): return id(self)
@@ -210,13 +210,13 @@ class Tensor(MathTrait):
return self.shape[0]
@property
def device(self) -> str|tuple[str, ...]: return self.lazydata.device
def device(self) -> str|tuple[str, ...]: return self.uop.device
@property
def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
def shape(self) -> tuple[sint, ...]: return self.uop.shape
@property
def dtype(self) -> DType: return self.lazydata.dtype
def dtype(self) -> DType: return self.uop.dtype
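
`device`, `shape`, and `dtype` are now thin delegates to the renamed `uop` attribute; a sketch of the equivalence:

```py
from tinygrad import Tensor

t = Tensor.empty(2, 4)
assert t.shape == t.uop.shape     # the properties just read the UOp spec
assert t.dtype == t.uop.dtype
assert t.device == t.uop.device
```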
# ***** data handlers ****
@@ -226,7 +226,7 @@ class Tensor(MathTrait):
NOTE: Kernelize can be called multiple times on a Tensor
"""
big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
# verify Tensors match the spec
if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
@@ -243,7 +243,7 @@ class Tensor(MathTrait):
"""
st = time.perf_counter()
self.kernelize(*lst)
sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
sink = UOp.sink(*[x.uop for x in (self,)+lst])
# remove all ASSIGNs, after scheduling, the tensors are just buffers
remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
@@ -272,31 +272,31 @@ class Tensor(MathTrait):
"""
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
assert self.shape == x.shape or allow_shape_mismatch, f"replace shape mismatch {self.shape} != {x.shape}"
self.lazydata = x.lazydata
self.uop = x.uop
return self
def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
cast(Buffer, self.contiguous().realize().lazydata.base.buffer).ensure_allocated().copyin(x._data())
cast(Buffer, self.contiguous().realize().uop.base.buffer).ensure_allocated().copyin(x._data())
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
if self.lazydata is x.lazydata: return self # a self assign is a NOOP
if self.uop is x.uop: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
self.lazydata = self.lazydata.assign(x.lazydata)
self.uop = self.uop.assign(x.uop)
return self
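
`assign` therefore just wraps the current `uop` in an ASSIGN node pointing at the new value, and nothing runs until realize. A sketch:

```py
from tinygrad import Tensor

a = Tensor.ones(4).contiguous().realize()
a.assign(Tensor.zeros(4))     # a.uop becomes ASSIGN(buffer, zeros); still lazy
a.realize()                   # the in-place write happens here
print(a.tolist())             # [0.0, 0.0, 0.0, 0.0]
```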
def detach(self) -> Tensor:
"""
Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
"""
return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)
return Tensor(self.uop.detach(), device=self.device, requires_grad=False)
def _buffer(self) -> Buffer: return cast(Buffer, self.cast(self.dtype.base).contiguous().to("CPU").realize().lazydata.base.buffer)
def _buffer(self) -> Buffer: return cast(Buffer, self.cast(self.dtype.base).contiguous().to("CPU").realize().uop.base.buffer)
def _data(self) -> memoryview: return self._buffer().as_buffer()
def data(self) -> memoryview:
@@ -372,7 +372,7 @@ class Tensor(MathTrait):
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
if device == self.device: return self
if not isinstance(device, str): return self.shard(device)
ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
if self.grad is not None: ret.grad = self.grad.to(device)
return ret
@@ -390,12 +390,12 @@ class Tensor(MathTrait):
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 4)
print(t.shard((t.device, t.device), axis=1).lazydata)
print(t.shard((t.device, t.device), axis=1).uop)
```
"""
assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
devices = tuple(Device.canonicalize(x) for x in devices)
mlb = self.lazydata.shard(devices, self._resolve_dim(axis)) if axis is not None else self.lazydata.copy_to_device(devices)
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
@@ -444,7 +444,7 @@ class Tensor(MathTrait):
"""
r = Tensor.empty(*shape, **kwargs)
assert isinstance(r.device, str)
cast(Buffer, r.lazydata.buffer).allocate(external_ptr=ptr)
cast(Buffer, r.uop.buffer).allocate(external_ptr=ptr)
return r
@staticmethod
@@ -719,7 +719,7 @@ class Tensor(MathTrait):
dtype = kwargs.pop("dtype", self.dtype)
if isinstance(self.device, tuple):
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.lazydata.axis)
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.uop.axis)
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
# ***** rng hlops *****
@@ -923,13 +923,13 @@ class Tensor(MathTrait):
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient")
if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
target_uops = [x.lazydata for x in targets]
grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
target_uops = [x.uop for x in targets]
grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
ret = []
for x in target_uops:
if (y:=grads.get(x)) is None:
if materialize_grads: y = x.const_like(0)
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}")
ret.append(y)
# create returned Tensors
return [Tensor(u, device=t.device) for t,u in zip(targets, ret)]
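
`gradient` walks the UOp graph from `self.uop` back to each target's `uop` and returns fresh Tensors. A sketch, assuming the return unpacks as a list in target order:

```py
from tinygrad import Tensor

x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
dx, dy = (x * y).sum().gradient(x, y)   # a pure graph walk, no .backward() state
print(dx.shape, dy.shape)               # (3,) (3,)
```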
@@ -944,9 +944,9 @@ class Tensor(MathTrait):
print(t.grad.numpy())
```
"""
all_uops = self.lazydata.toposort()
all_uops = self.uop.toposort()
tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
t.lazydata in all_uops and t.requires_grad]
t.uop in all_uops and t.requires_grad]
# clear contexts
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
@@ -1253,14 +1253,14 @@ class Tensor(MathTrait):
self.realize()._getitem(indices).assign(v)
return
# NOTE: check that setitem target is valid first
if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
if not unwrap(self.uop.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
res = self.realize()._getitem(indices, v)
# if shapes match and data is not shared it's a copy and we assign to self
if res.shape == self.shape and res.lazydata is not self.lazydata:
if res.shape == self.shape and res.uop is not self.uop:
self.assign(res).realize()
else: # no copy, basic setitem
v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
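
The contiguity check on the setitem target now reads `self.uop.st`. A sketch of the basic path, assuming scalar assignment realizes the target and writes in place as above:

```py
from tinygrad import Tensor

t = Tensor.zeros(4).contiguous()   # target must be contiguous per the check above
t[1] = 5.0                         # value is cast, broadcast, and assigned
print(t.tolist())                  # [0.0, 5.0, 0.0, 0.0]
```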