mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-06 21:53:53 -05:00
rename lazydata to uop (#10698)
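In short: every `Tensor.lazydata` access becomes `Tensor.uop`. Only the attribute name changes; the UOp graph behind it is the same. A minimal before/after sketch (tensor values are illustrative):

```py
from tinygrad import Tensor

t = Tensor([1, 2, 3, 4])
# before this commit: print(t.lazydata)
# after this commit, the same UOp specification is reached via .uop
print(t.uop)
```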
@@ -36,7 +36,7 @@ optim.schedule_step() # this will step the optimizer without running realize
 # 3. Create a schedule.

 # The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
-# l1.lazydata and l2.lazydata define a computation graph
+# l1.uop and l2.uop define a computation graph

 from tinygrad.engine.schedule import ScheduleItem
 schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)
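Each ScheduleItem in that list is one kernel to run. As a hedged sketch of the step after this (assuming `run_schedule` from `tinygrad.engine.realize`, which this diff does not show):

```py
# assumption: run_schedule is importable from tinygrad.engine.realize
from tinygrad.engine.realize import run_schedule
run_schedule(schedule)  # executes every ScheduleItem; l1 and l2 are now realized
```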
@@ -34,7 +34,7 @@ print(out) # <Tensor <UOp METAL (1,) int (<Ops.ASSIGN: 66>, None)> on METAL with
 The multiply Tensor stays the same because it is fused. The output Tensor's UOp becomes a new ASSIGN UOp:

 ```py
-print(out.lazydata)
+print(out.uop)
 ```

 The first source is the output BUFFER:
@@ -72,7 +72,7 @@ Once a Tensor is kernelized, all children will LOAD its BUFFER, instead of fusing
 ```py
 child = out+2
 child.kernelize()
-print(child.lazydata.src[1].arg.ast)
+print(child.uop.src[1].arg.ast)
 ```

 ```
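Tying the two hunks above together, a sketch of inspecting a kernelized Tensor (the src layout follows the ASSIGN description later in this commit: src[0] is the BUFFER assigned to, src[1] carries the Kernel):

```py
out.kernelize()
print(out.uop.src[0])          # the output BUFFER
print(out.uop.src[1].arg.ast)  # the AST of the Kernel that fills it
```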
docs/ramp.py (30 changed lines)
@@ -39,8 +39,8 @@ assert t.shape == (4,)
 print(t)
 # <Tensor <UOp METAL (4,) int (<Ops.COPY: 7>, None)> on METAL with grad None>

-# the ".lazydata" property on Tensor contains the specification of how to compute it
-print(t.lazydata)
+# the ".uop" property on Tensor contains the specification of how to compute it
+print(t.uop)
 """
 UOp(Ops.COPY, dtypes.int, arg=None, src=(
   UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
@@ -57,21 +57,21 @@ UOp(Ops.COPY, dtypes.int, arg=None, src=(

 # if we want to "realize" a tensor, we can with the "realize" method
 t.realize()
-# now when we look at the lazydata, it's changed
-print(t.lazydata)
+# now when we look at the uop, it's changed
+print(t.uop)
 """
 UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
   UOp(Ops.UNIQUE, dtypes.void, arg=1, src=()),
   UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),))
 """
-# the copy was actually run, and now the "lazydata" of the Tensor is just a BUFFER
+# the copy was actually run, and now the "uop" of the Tensor is just a BUFFER
 # if you run this script with DEBUG=2 in the environment, you can see the copy happen
 # *** METAL 1 copy 16, METAL <- PYTHON ...

 # now let's do some compute
-# we look at the lazydata to see the specification of the compute
+# we look at the uop to see the specification of the compute
 t_times_2 = t * 2
-print(t_times_2.lazydata)
+print(t_times_2.uop)
 """
 UOp(Ops.MUL, dtypes.int, arg=None, src=(
   UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
@@ -90,24 +90,24 @@ UOp(Ops.MUL, dtypes.int, arg=None, src=(
 assert t_times_2.tolist() == [2, 4, 6, 8]

 # UOps are both immutable and globally unique
-# if i multiply the Tensor by 4 twice, these result Tensors will have the same lazydata specification
+# if i multiply the Tensor by 4 twice, these result Tensors will have the same uop specification
 t_times_4_try_1 = t * 4
 t_times_4_try_2 = t * 4
-assert t_times_4_try_1.lazydata is t_times_4_try_2.lazydata
+assert t_times_4_try_1.uop is t_times_4_try_2.uop
 # the specification isn't just the same, it's the exact same Python object
 assert t_times_4_try_1 is not t_times_4_try_2
 # the Tensor is a different Python object

 # if we realize `t_times_4_try_1` ...
 t_times_4_try_1.realize()
-print(t_times_4_try_2.lazydata)
+print(t_times_4_try_2.uop)
 """
 UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
   UOp(Ops.UNIQUE, dtypes.void, arg=4, src=()),
   UOp(Ops.DEVICE, dtypes.void, arg='METAL', src=()),))
 """
 # ... `t_times_4_try_2` also becomes the same BUFFER
-assert t_times_4_try_1.lazydata is t_times_4_try_2.lazydata
+assert t_times_4_try_1.uop is t_times_4_try_2.uop
 # so this print doesn't require any computation, just a copy back to the CPU so we can print it
 print("** only the copy start")
 print(t_times_4_try_2.tolist()) # [4, 8, 12, 16]
@@ -120,7 +120,7 @@ t_float = Tensor([3.0])
 t_log = t_float.log()
 t_log_grad, = t_log.sum().gradient(t_float)
 # due to how log is implemented, this gradient contains a lot of UOps
-print(t_log_grad.lazydata)
+print(t_log_grad.uop)
 # ...not shown here...
 # but if you run with DEBUG=4 (CPU=1 used here for simpler code), you can see the generated code
 """
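However large that graph is, the value it computes is small: d/dx log(x) = 1/x, so at x = 3.0 the gradient is about 0.3333. A one-line check (an addition, not in the original script):

```py
assert abs(t_log_grad.item() - 1/3) < 1e-5  # d(log x)/dx = 1/x, evaluated at x=3.0
```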
@@ -144,7 +144,7 @@ t = Tensor([1,2,3,4])
 # NOTE: the APIs here are subject to change

 t_plus_3_plus_4 = t + 3 + 4
-print(t_plus_3_plus_4.lazydata)
+print(t_plus_3_plus_4.uop)
 """
 UOp(Ops.ADD, dtypes.int, arg=None, src=(
   UOp(Ops.ADD, dtypes.int, arg=None, src=(
@@ -166,7 +166,7 @@ UOp(Ops.ADD, dtypes.int, arg=None, src=(
 # but by the time we are actually running the code, it's adding 7
 # `kernelize` will simplify and group the operations in the graph into kernels
 t_plus_3_plus_4.kernelize()
-print(t_plus_3_plus_4.lazydata)
+print(t_plus_3_plus_4.uop)
 """
 UOp(Ops.ASSIGN, dtypes.int, arg=None, src=(
   x0:=UOp(Ops.BUFFER, dtypes.int, arg=4, src=(
@@ -181,7 +181,7 @@ UOp(Ops.ASSIGN, dtypes.int, arg=None, src=(
 # ASSIGN has two srcs, src[0] is the BUFFER that's assigned to, and src[1] is the thing to assign
 # src[1] is the GPU Kernel that's going to be run
 # we can get the ast of the Kernel as follows
-kernel_ast = t_plus_3_plus_4.lazydata.src[1].arg.ast
+kernel_ast = t_plus_3_plus_4.uop.src[1].arg.ast

 # almost everything in tinygrad functions as a rewrite of the UOps
 # the codegen rewrites the ast to a simplified form ready for "rendering"
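To close the loop with the earlier part of this file: realizing the kernelized Tensor runs that Kernel, after which its uop collapses to a BUFFER just like the copy example above. A short check (an addition; the values follow from t = [1,2,3,4]):

```py
t_plus_3_plus_4.realize()  # runs the kernel; the uop is now a BUFFER
assert t_plus_3_plus_4.tolist() == [8, 9, 10, 11]
```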
@@ -240,7 +240,6 @@ class LLaMa:
 #elif k.endswith('.weight'): v.shard_(device, axis=-1)
 #elif 'norm.' in k: v.shard_(device, axis=-1)
 else: v.shard_(device, axis=None)
-#print(k, v.shape, v.lazydata.axis)

 # replace weights in model
 load_state_dict(model, weights, strict=False, consume=True)
@@ -446,7 +445,7 @@ After you are done speaking, output [EOS]. You are not Chad.
 print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
 device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
 llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize, device=device)
-param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(llama.model))
+param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(llama.model))

 outputted = pre_prompt if chatbot else args.prompt
 start_pos, toks = 0, [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted)
@@ -284,7 +284,7 @@ if __name__ == "__main__":

 device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
 model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device)
-param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))
+param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(model))

 if not args.no_api and not args.benchmark:
   from bottle import Bottle, request, response, HTTPResponse, abort, static_file
@@ -16,7 +16,7 @@ if __name__ == "__main__":
 #model.load_pretrained()
 for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained

-#early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)])
+#early_sched = create_schedule([x.uop for x in nn.state.get_parameters(model)])
 #print(f"built model {len(early_sched)}")

 #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
@@ -56,7 +56,7 @@ if __name__ == "__main__":
 state_dict.update({'X': X, 'Y': Y, 'loss': loss})
 grad_state_dict = {}
 for k,v in state_dict.items():
-  if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}")
+  if v.uop.base.buffer not in used_buffers: print(f"UNUSED: {k}")
   if v.grad is not None: grad_state_dict['grad_'+k] = v.grad
 state_dict.update(grad_state_dict)
 state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr})
@@ -65,7 +65,7 @@ if __name__ == "__main__":
 nm = inverse_state_dict[p]
 state_dict["adam_m_"+nm] = m
 state_dict["adam_v_"+nm] = v
-named_buffers = {v.lazydata.base.buffer:k.replace(".", "_") for k,v in state_dict.items()}
+named_buffers = {v.uop.base.buffer:k.replace(".", "_") for k,v in state_dict.items()}

 c_code = ["#include <stdlib.h>", "#include <tgmath.h>", "#include <stdbool.h>"]
 if TIMING: c_code += ["#include <stdio.h>", "#include <time.h>"]
@@ -71,7 +71,7 @@ def loader_process(q_in, q_out, X:Tensor, seed):
 #storage_tensor._copyin(img_tensor.numpy())

 # faster
-X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
+X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

 # ideal
 #X[idx].assign(img.tobytes()) # NOTE: this is slow!
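Spelled out, the fast path above has three steps (a restating sketch of the hunk, not new API): `.contiguous().realize()` forces a real device buffer, `.uop.base.realized` is that Buffer, and the zero-copy memoryview lets the bytes be written directly into it.

```py
buf = X[idx].contiguous().realize().uop.base.realized  # the realized device Buffer
mv = buf.as_buffer(force_zero_copy=True)               # writable zero-copy memoryview
mv[:] = img.tobytes()                                  # write the image bytes in place
```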
@@ -262,8 +262,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens
 x = random_brightness_augmentation(x)
 x = gaussian_noise(x)

-X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
-Y[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()
+X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
+Y[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()

 queue_out.put(idx)
 queue_out.put(None)
@@ -377,12 +377,12 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
 clipped_match_idxs = np.clip(match_idxs, 0, None)
 clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs]

-boxes[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
-labels[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
-matches[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
-anchors[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()
+boxes[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
+labels[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
+matches[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
+anchors[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()

-imgs[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
+imgs[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

 queue_out.put(idx)
 queue_out.put(None)
@@ -19,8 +19,8 @@ if __name__ == "__main__":

 inputs = run_onnx.get_empty_input_data("npy")
 out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu')
-root = out.lazydata
-targets = [x.lazydata for x in inputs.values()]
+root = out.uop
+targets = [x.uop for x in inputs.values()]
 print(targets)

 # TODO: abstract this from gradient?
@@ -42,7 +42,7 @@ if __name__ == "__main__":

 print("**** real ****")
 GlobalCounters.reset()
-out.lazydata = root.substitute(kernelized).substitute(becomes_map)
+out.uop = root.substitute(kernelized).substitute(becomes_map)
 out.kernelize()

 # realize
@@ -66,7 +66,7 @@ if __name__ == "__main__":
 model_path = Path(args.weights) if args.weights else download_weights(model_info["total_num_weights"])
 transformer = load_model(model_path, model_info["model_params"])
 tokenizer = AutoTokenizer.from_pretrained(model_info["tokenizer"])
-param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(transformer))
+param_bytes = sum(x.uop.size * x.dtype.itemsize for x in get_parameters(transformer))

 outputted = args.prompt
 start_pos, toks = 0, tokenizer(outputted)["input_ids"]
@@ -13,8 +13,8 @@ def prepare_browser_chunks(model):
 chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
 metadata = {}
 # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be
-t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
-empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]
+t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
+empty_t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]

 split_t_infos = []
 for size, name, dtype in t_infos:
@@ -48,7 +48,7 @@ def prepare_browser_chunks(model):
 weight_metadata = metadata.get(name, default)
 weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
 metadata[name] = weight_metadata
-data = bytes(state_dict[name].lazydata.base.realized.as_buffer())
+data = bytes(state_dict[name].uop.base.realized.as_buffer())
 data = data if not offsets else data[offsets[0]:offsets[1]]
 writer.write(data)
 cursor += size
@@ -114,7 +114,7 @@ if __name__ == "__main__":
 run, special_names = jit_model(step, *step.input)
 functions, statements, bufs, _ = compile_net(run, special_names)
 state = get_state_dict(model)
-weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
+weights = {id(x.uop.base.realized): name for name, x in state.items()}
 kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
 kernel_names = ', '.join([name for (name, _, _, _) in statements])
 input_names = [name for _,name in special_names.items() if "input" in name]
@@ -48,13 +48,13 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:

 # hack to put the inputs back
 for (j,i),idx in run.input_replace.items():
-  realized_input = args[idx].lazydata.base.realized
+  realized_input = args[idx].uop.base.realized
   run.jit_cache[j].bufs[i] = realized_input
   special_names[id(realized_input)] = f'input{idx}'

 # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
 for i, output in enumerate(the_output):
-  special_names[id(output.lazydata.base.realized)] = f'output{i}'
+  special_names[id(output.uop.base.realized)] = f'output{i}'
 return run, special_names

 def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
@@ -242,7 +242,7 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
 with Context(JIT=2): run,special_names = jit_model(model, *inputs)
 functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
 state = get_state_dict(model)
-weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
+weight_names = {id(x.uop.base.realized): name for name, x in state.items()}
 input_names = [name for _,name in special_names.items() if "input" in name]
 output_names = [name for _,name in special_names.items() if "output" in name]
@@ -47,7 +47,7 @@ if __name__ == "__main__":

 a = Tensor([0.,1.,2.], device="KFD").realize()
 b = a + 7
-b.lazydata.buffer.allocate()
+b.uop.buffer.allocate()
 si = b.schedule()[-1]
 runner = dev.get_runner(*si.ast)
 prg: AMDProgram = runner.clprg
@@ -69,8 +69,8 @@ if __name__ == "__main__":

 #scratch = dev._gpu_alloc(0x10000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
 ka = to_mv(dev.kernargs_ptr, 0x10).cast("Q")
-ka[0] = b.lazydata.buffer._buf.va_addr
-ka[1] = a.lazydata.buffer._buf.va_addr
+ka[0] = b.uop.buffer._buf.va_addr
+ka[1] = a.uop.buffer._buf.va_addr

 compute_read_pointer = to_mv(compute_queue.read_pointer_address, 8).cast("Q")
 compute_write_pointer = to_mv(compute_queue.write_pointer_address, 8).cast("Q")
@@ -23,7 +23,7 @@ def explicit_shard_W_axis_1(X, W):
 x = x.reshape(N, 1, N).expand(N, N, N)
 w = w.T.reshape(1, N, N).expand(N, N, N)
 m = x*w
-assert m.lazydata.st.views[0].mask is not None
+assert m.uop.st.views[0].mask is not None
 ret = m.sum(2)
 return ret
 #Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])]
@@ -89,7 +89,7 @@ def get_output(s:str, n_threads:int=1):
 "s_waitcnt 0",
 "global_store_b32 v0, v1, s[0:1]",
 "s_nop 0", "s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm"])
-test = Tensor.zeros((n_threads,), dtype=dtypes.uint32).contiguous().realize().lazydata.buffer
+test = Tensor.zeros((n_threads,), dtype=dtypes.uint32).contiguous().realize().uop.buffer
 prg = get_prg(code, 32, 32)
 prg(test._buf, global_size=(1, 1, 1), local_size=(n_threads, 1, 1), wait=True)
 return test.numpy()
@@ -65,12 +65,12 @@ for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "p

 # in place operations with views
 def realize_with_views(self: Tensor, views: Tensor):
-  if not self.lazydata.st.contiguous: self.replace(self.contiguous())
+  if not self.uop.st.contiguous: self.replace(self.contiguous())
   self.replace(self.clone().realize())
   for v in views:
-    if v.lazydata.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view
+    if v.uop.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view
     ret = self
-    st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
+    st = ShapeTracker(self.uop.st.views + v.uop.st.views) # TODO: is this right?
     for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
     v.replace(ret)
 def maybe_realize_storage(self: Tensor) -> bool:
@@ -178,7 +178,7 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
 # multiple as_strided do not compound
 base = canonical_base(tensor)
 # TODO: this is heavyweight
-st = ShapeTracker(base.lazydata.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
+st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
 ret = base
 if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
 if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
@@ -315,7 +315,7 @@ def _copy_from(src: torch.Tensor, dest, non_blocking=False):
 to_device = _from_torch_device(dest.device)
 src,dest = unwrap(src),unwrap(dest)
 # TODO we need to properly match dest shape and strides, not blindly assign
-if dest.lazydata.st.contiguous or dest.lazydata.is_realized: src = src.contiguous() # this only solves some cases
+if dest.uop.st.contiguous or dest.uop.is_realized: src = src.contiguous() # this only solves some cases
 dest.assign(src.cast(cast_dtype).to(to_device))
 if realize: Tensor.realize(dest)
 elif src.is_tiny and dest.is_cpu:
@@ -493,7 +493,7 @@ def wrap_out(f):
 assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
 assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
 assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
-if out.lazydata.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
+if out.uop.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
 return out.assign(assigned)
 return _wrap_out
@@ -116,7 +116,7 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceIndex
 // TODO: we have to get the dtype and the shape from the tinygrad Tensor
 std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();

-py::list views = py_obj.attr("lazydata").attr("st").attr("views");
+py::list views = py_obj.attr("uop").attr("st").attr("views");
 std::vector<int64_t> strides = views[views.size() - 1].attr("strides").cast<std::vector<int64_t>>();
 int64_t storage_offset = 0;
 for (auto& v: views) {
@@ -113,7 +113,7 @@ class DispatchLog(TorchDispatchMode):
 _ = tiny_x.cpu().numpy()
 if torch.is_tensor(tiny_x) and tiny_x.device.type == "tiny":
   tt = tiny_torch.unwrap(tiny_x)
-  try: out_addr = tt.lazydata.buffer._buf.value
+  try: out_addr = tt.uop.buffer._buf.value
   except Exception: pass
 tiny_events = hook_cuda.collect_events(clear=True)
 print_events(tiny_events, colored("tiny", "magenta"), out_addr)
@@ -13,6 +13,6 @@ if __name__ == "__main__":
 (Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v
 ]
 for t0, t1 in tensors:
-  print(f"{t0.shape=}, {t0.lazydata.st.real_strides()=}, {t1.shape=}, {t1.lazydata.st.real_strides()=}")
+  print(f"{t0.shape=}, {t0.uop.st.real_strides()=}, {t1.shape=}, {t1.uop.st.real_strides()=}")
   for _ in range(5):
     t0.dot(t1, dtype=acc_dtype).realize()
test/external/external_multi_gpu.py (4 changed lines)
@@ -21,8 +21,8 @@ if __name__ == "__main__":
 with Timing("CPU creation: ", on_exit=lambda x: f", {(sz*4*2)/x:.2f} GB/sec"):
   c0 = (Tensor.ones(sz, device="CPU")/2).realize()
   c1 = (Tensor.ones(sz, device="CPU")/4).realize()
-print(c0.lazydata.base.realized)
-print(c1.lazydata.base.realized)
+print(c0.uop.base.realized)
+print(c1.uop.base.realized)

 with Timing("CPU -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
   a0 = c0.to(d0).realize()
test/external/external_test_amd.py (10 changed lines)
@@ -9,18 +9,18 @@ class TestAMD(unittest.TestCase):
 TestAMD.d0: AMDDevice = Device["AMD"]
 TestAMD.a = Tensor([0.,1.], device="AMD").realize()
 TestAMD.b = self.a + 1
-si = create_schedule([self.b.lazydata])[-1]
+si = create_schedule([self.b.uop])[-1]
 TestAMD.d0_runner = TestAMD.d0.get_runner(*si.ast)
-TestAMD.b.lazydata.buffer.allocate()
+TestAMD.b.uop.buffer.allocate()

 def test_amd_ring_64bit_doorbell(self):
   TestAMD.d0.pm4_write_pointer[0] = TestAMD.d0.pm4_write_pointer[0] + (2 << 32) - TestAMD.d0.pm4_ring.size // 4
   for _ in range(2000):
-    TestAMD.d0_runner.clprg(TestAMD.b.lazydata.buffer._buf, TestAMD.a.lazydata.buffer._buf,
+    TestAMD.d0_runner.clprg(TestAMD.b.uop.buffer._buf, TestAMD.a.uop.buffer._buf,
                             global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
-    TestAMD.d0_runner.clprg(TestAMD.a.lazydata.buffer._buf, TestAMD.b.lazydata.buffer._buf,
+    TestAMD.d0_runner.clprg(TestAMD.a.uop.buffer._buf, TestAMD.b.uop.buffer._buf,
                             global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
-  val = TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0]
+  val = TestAMD.a.uop.buffer.as_buffer().cast("f")[0]
   assert val == 4000.0, f"got val {val}"

 if __name__ == "__main__":
test/external/external_test_hcq.py (52 changed lines)
@@ -22,10 +22,10 @@ class TestHCQ(unittest.TestCase):
 TestHCQ.b = self.a + 1
 si = self.b.schedule()[-1]
 TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
-TestHCQ.b.lazydata.buffer.allocate()
+TestHCQ.b.uop.buffer.allocate()
 # wow that's a lot of abstraction layers
-TestHCQ.addr = struct.pack("QQ", TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr)
-TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr)
+TestHCQ.addr = struct.pack("QQ", TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr)
+TestHCQ.addr2 = struct.pack("QQ", TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr)
 TestHCQ.kernargs_off = TestHCQ.runner._prg.kernargs_offset
 TestHCQ.kernargs_size = TestHCQ.runner._prg.kernargs_alloc_size
 ctypes.memmove(TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_off, TestHCQ.addr, len(TestHCQ.addr))
@@ -45,8 +45,8 @@ class TestHCQ(unittest.TestCase):

 def setUp(self):
   TestHCQ.d0.synchronize()
-  TestHCQ.a.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
-  TestHCQ.b.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
+  TestHCQ.a.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
+  TestHCQ.b.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
   TestHCQ.d0.synchronize() # wait for copyins to complete

 def test_run_1000_times_one_submit(self):
@@ -65,7 +65,7 @@ class TestHCQ(unittest.TestCase):
 q.submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
 assert val == 2000.0, f"got val {val}"

 def test_run_1000_times(self):
@@ -81,7 +81,7 @@ class TestHCQ(unittest.TestCase):
 TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
 assert val == 2000.0, f"got val {val}"

 def test_run_to_3(self):
@@ -95,7 +95,7 @@ class TestHCQ(unittest.TestCase):
 q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 3.0, f"got val {val}"

 def test_update_exec(self):
@@ -106,9 +106,9 @@ class TestHCQ(unittest.TestCase):
 q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
 assert val == 0.0, f"got val {val}, should not be updated"

 @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
@@ -126,7 +126,7 @@ class TestHCQ(unittest.TestCase):
 TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
 assert val == 2000.0, f"got val {val}"

 @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
@@ -141,9 +141,9 @@ class TestHCQ(unittest.TestCase):
 q.submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
 assert val == 0.0, f"got val {val}, should not be updated"

 @unittest.skipIf(CI, "Can't handle async update on CPU")
@@ -174,7 +174,7 @@ class TestHCQ(unittest.TestCase):
 q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"

 def test_submit_empty_queues(self):
@@ -206,13 +206,13 @@ class TestHCQ(unittest.TestCase):
 q.submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"

 def test_copy_1000_times(self):
   q = TestHCQ.copy_queue()
-  q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
-  q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
+  q.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8)
+  q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8)
   for _ in range(1000):
     q.submit(TestHCQ.d0)
   TestHCQ.copy_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
@@ -221,24 +221,24 @@ class TestHCQ(unittest.TestCase):
 # confirm the signal didn't exceed the put value
 with self.assertRaises(RuntimeError):
   TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
 assert val == 0.0, f"got val {val}"

 def test_copy(self):
   q = TestHCQ.copy_queue()
-  q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
+  q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8)
   q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
   q.submit(TestHCQ.d0)
   TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
   TestHCQ.d0.timeline_value += 1
-  val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+  val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
   assert val == 1.0, f"got val {val}"

 @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
 def test_bind_copy(self):
   q = TestHCQ.copy_queue()
-  q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
-  q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
+  q.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8)
+  q.copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8)
   q.bind(TestHCQ.d0)
   for _ in range(1000):
     q.submit(TestHCQ.d0)
@@ -248,7 +248,7 @@ class TestHCQ(unittest.TestCase):
 # confirm the signal didn't exceed the put value
 with self.assertRaises(RuntimeError):
   TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
 assert val == 0.0, f"got val {val}"

 def test_copy_bandwidth(self):
@@ -281,14 +281,14 @@ class TestHCQ(unittest.TestCase):
 q.exec(TestHCQ.runner._prg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) # b = [1, 2]
 q.signal(sig:=TestHCQ.d0._alloc_signal(value=0), value=1)
 qc.wait(sig, value=1)
-qc.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
+qc.copy(TestHCQ.a.uop.buffer._buf.va_addr, TestHCQ.b.uop.buffer._buf.va_addr, 8)
 qc.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 qc.submit(TestHCQ.d0)
 time.sleep(0.02) # give it time for the wait to fail
 q.submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"

 def test_cross_device_signal(self):
@@ -319,7 +319,7 @@ class TestHCQ(unittest.TestCase):
 q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
 TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"

 if __name__ == "__main__":
test/external/external_test_hip_compile.py (2 changed lines)
@@ -10,7 +10,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
 def test_hip_compile(self):
   a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
   out = a + b
-  lin = Kernel(create_schedule([out.lazydata])[-1].ast[0])
+  lin = Kernel(create_schedule([out.uop])[-1].ast[0])
   lin.linearize()

   reference = """
test/external/external_test_nv.py (10 changed lines)
@@ -21,8 +21,8 @@ class TestNV(unittest.TestCase):
 TestNV.b = self.a + 1
 si = self.b.schedule()[-1]
 TestNV.d0_runner = get_runner(TestNV.d0.device, si.ast)
-TestNV.b.lazydata.buffer.allocate()
-TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
+TestNV.b.uop.buffer.allocate()
+TestNV.addr = struct.pack("QQ", TestNV.b.uop.buffer._buf.va_addr, TestNV.a.uop.buffer._buf.va_addr)

 def test_oor_kernels(self):
   ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
@@ -44,8 +44,8 @@ class TestNV(unittest.TestCase):
 TestNV.along = Tensor([105615], device="NV").realize()
 ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
 temp_runner = get_runner(TestNV.d0.device, (ast,))
-temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
-val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
+temp_runner([TestNV.b.uop.buffer, TestNV.along.uop.buffer], var_vals={})
+val = TestNV.b.uop.buffer.as_buffer().cast("f")[0]
 assert abs(val - 0.80647) < 0.001, f"got val {val}"

 def test_kernargs_no_oob_access(self):
@@ -59,7 +59,7 @@ class TestNV(unittest.TestCase):
 q.signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value).submit(TestNV.d0)
 TestNV.d0._wait_signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value)
 TestNV.d0.timeline_value += 1
-val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestNV.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"

 if __name__ == "__main__":
test/external/fuzz_graph.py (2 changed lines)
@@ -28,7 +28,7 @@ def alloc_rawbuffer(device, fill=False):
 if fill:
   with Context(DEBUG=0):
     data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
-    rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer())
+    rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer())
 return rawbuf

 def gen_kernel_ji(device, deps):
test/external/fuzz_linearizer.py (2 changed lines)
@@ -73,7 +73,7 @@ def get_fuzz_rawbufs(lin):
   data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
 else:
   data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
-rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().lazydata.base.realized.as_buffer())
+rawbuf.copyin(Tensor(data, device=lin.opts.device).realize().uop.base.realized.as_buffer())
 return rawbufs

 def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, force_device=None):
@@ -21,18 +21,18 @@ def consec(shape, start=1):

 # creates strided tensor with base set to reference tensor's base, equivalent to torch.set_()
 def set_(reference: Tensor, shape, strides, offset):
-  raise NotImplementedError("need to implement without calling lazydata.view")
-  if reference.lazydata.base.realized is None: reference.realize()
-  assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
-  strided = Tensor(reference.lazydata.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
-  assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
+  raise NotImplementedError("need to implement without calling uop.view")
+  if reference.uop.base.realized is None: reference.realize()
+  assert reference.uop.base.realized, "base has to be realized before setting it to strided's base"
+  strided = Tensor(reference.uop.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
+  assert strided.uop.st.real_strides() == strides, "real_strides should equal strides for strided"
   return strided

 def clone(original:Tensor): return original.clone()
 def copy_(src:Tensor, other:Tensor) -> Tensor: return src.clone()
 # this is fine for tested usecases since as geohotstan understands,
 # data_ptr is used to compare if operations needed between tensors is the same
-def data_ptr(tensor:Tensor): return tensor.lazydata
+def data_ptr(tensor:Tensor): return tensor.uop

 # https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html
 def index_put_(tensor:Tensor, indices, values, accumulate) -> Tensor:
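Because UOps are globally unique (shown in docs/ramp.py above), two tensors built from the same expression share one UOp object, so this `data_ptr` behaves like a shared storage pointer. A hypothetical usage sketch, not from the diff:

```py
a = Tensor([1., 2.])
b, c = a * 4, a * 4
assert data_ptr(b) is data_ptr(c)  # the exact same UOp object
```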
@@ -971,9 +971,9 @@ class TestIndexing(unittest.TestCase):
 numpy_testing_assert_equal_helper((2, 0, 4), z.shape)
 # this isn't technically necessary, but matches NumPy stride calculations.
 # NOTE: this is empty and shouldn't have strides
-#numpy_testing_assert_equal_helper((60, 20, 5), z.lazydata.st.real_strides())
+#numpy_testing_assert_equal_helper((60, 20, 5), z.uop.st.real_strides())
 # NOTE tinygrad's int slicing implementation makes this not contiguous
-# self.assertTrue(z.lazydata.st.contiguous)
+# self.assertTrue(z.uop.st.contiguous)

 @unittest.skip("bool indexing not supported")
 def test_index_getitem_copy_bools_slices(self):
@@ -20,7 +20,7 @@ class TestArange(unittest.TestCase):
 p = k.to_program()
 print(p.name)
 #print(p.src)
-ExecItem(CompiledRunner(p), [tt.lazydata.buffer]).run()
+ExecItem(CompiledRunner(p), [tt.uop.buffer]).run()
 np.testing.assert_equal(tt.numpy(), np.arange(N))
 return p.estimates.ops
@@ -13,11 +13,11 @@ class TestAssign(unittest.TestCase):
 b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
 a.realize()
 b.realize()
-ba1 = a.lazydata.base.realized
-bb1 = b.lazydata.base.realized
+ba1 = a.uop.base.realized
+bb1 = b.uop.base.realized
 a += b
 a.realize()
-ba2 = a.lazydata.base.realized
+ba2 = a.uop.base.realized
 assert ba1 == ba2 and ba1 != bb1
 np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
@@ -259,13 +259,13 @@ class TestAssign(unittest.TestCase):
 b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
 a.realize()
 b.realize()
-ba1 = a.lazydata.base.realized
-bb1 = b.lazydata.base.realized
+ba1 = a.uop.base.realized
+bb1 = b.uop.base.realized
 with self.assertRaises((RuntimeError, AssertionError)):
   a = a.permute(1,0)
   a += b
   a.realize()
-ba2 = a.lazydata.base.realized
+ba2 = a.uop.base.realized
 assert ba1 != ba2 and ba1 != bb1
 np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
@@ -275,12 +275,12 @@ class TestAssign(unittest.TestCase):
 a.realize()
 b.realize()
 #GlobalCounters.cache = []
-ba1 = a.lazydata.base.realized # noqa: F841
-bb1 = b.lazydata.base.realized # noqa: F841
+ba1 = a.uop.base.realized # noqa: F841
+bb1 = b.uop.base.realized # noqa: F841
 with self.assertRaisesRegex(RuntimeError, "contiguous"):
   a.assign(a.permute(1,0) + b) # this should not work!
 a.realize()
-ba2 = a.lazydata.base.realized # noqa: F841
+ba2 = a.uop.base.realized # noqa: F841
 # NOTE: don't test that it's assigned
 #assert ba1 == ba2 and ba1 != bb1
 np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
@@ -383,10 +383,10 @@ class TestAssign(unittest.TestCase):
 def test_cast_assignment(self):
   a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
   a.realize()
-  oba1 = a.lazydata.base.output_buffer
+  oba1 = a.uop.base.output_buffer
   a.assign(a.cast(dtypes.int32).realize())
   a.realize()
-  oba2 = a.lazydata.base.output_buffer
+  oba2 = a.uop.base.output_buffer
   assert oba1 is None and oba2 is None
   np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
@@ -221,7 +221,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
 # contiguous folded const can still schedule
 a = Tensor.empty(1, 0).sum().contiguous()
 _check_ast_count(2, a+2)
-self.assertIs(a.lazydata.base.op, Ops.BUFFER)
+self.assertIs(a.uop.base.op, Ops.BUFFER)
 np.testing.assert_equal((Tensor.empty(1, 0).sum().contiguous()+2).numpy(), 2)
 # otherwise we just fuse it
 _check_ast_count(1, (Tensor.empty(1, 0).sum()+2).contiguous())
@@ -32,7 +32,7 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]:
 def _test_to_np(a:Tensor, np_dtype, target):
   if DEBUG >= 2: print(a)
   na = a.numpy()
-  if DEBUG >= 2: print(na, na.dtype, a.lazydata.base.realized)
+  if DEBUG >= 2: print(na, na.dtype, a.uop.base.realized)
   try:
     assert na.dtype == np_dtype
     np.testing.assert_allclose(na, target)
@@ -73,7 +73,7 @@ class TestGC(unittest.TestCase):
 x = Tensor.ones(4,4).contiguous().realize()+1
 self.assertEqual(bufs_allocated()-init, 1)
 # try commenting this part out, it's green!
-x.lazydata.toposort()
+x.uop.toposort()
 del x
 if bufs_allocated()-init != 0:
   print(inspect.getclosurevars(UOp.toposort().fget))
@@ -84,11 +84,11 @@ class TestGC(unittest.TestCase):
 a = Tensor.empty(10)
 self.assertEqual(bufs_allocated()-init, 0)
 a.realize()
-real_buf = a.lazydata.buffer
+real_buf = a.uop.buffer
 # after the Tensor UOp is deleted there shouldn't be any references on the Buffer
 self.assertEqual(real_buf.lb_refcount, 1)
 self.assertEqual(bufs_allocated()-init, 1)
-del a.lazydata
+del a.uop
 self.assertEqual(real_buf.lb_refcount, 0)
 self.assertEqual(bufs_allocated()-init, 1) # keep the buffer alive
 del real_buf
@@ -98,10 +98,10 @@ class TestGC(unittest.TestCase):
 init = bufs_allocated()
 a = Tensor.full((4,), 1.).contiguous()
 a.realize()
-real_buf = a.lazydata.buffer
+real_buf = a.uop.buffer
 self.assertEqual(real_buf.lb_refcount, 1)
 a.assign(Tensor.full((4,), 2.))
-self.assertIs(a.lazydata.src[0].buffer, real_buf)
+self.assertIs(a.uop.src[0].buffer, real_buf)
 # NOTE: this is still 1, we don't count the ASSIGN
 self.assertEqual(real_buf.lb_refcount, 1)
 a.realize()
@@ -36,7 +36,7 @@ def helper_alloc_rawbuffer(device, fill=False):
 if fill:
   with Context(DEBUG=0):
     data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
-    rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer())
+    rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer())
 return rawbuf

 def helper_create_offset_rawbuffer(base, offset=0):
@@ -20,15 +20,15 @@ class TestHCQ(unittest.TestCase):
 si = self.b.schedule()[-1]

 TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
-TestHCQ.b.lazydata.buffer.allocate()
+TestHCQ.b.uop.buffer.allocate()

-TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.lazydata.buffer._buf, TestHCQ.a.lazydata.buffer._buf])
-TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.lazydata.buffer._buf, TestHCQ.b.lazydata.buffer._buf])
+TestHCQ.kernargs_ba_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf])
+TestHCQ.kernargs_ab_ptr = TestHCQ.runner._prg.fill_kernargs([TestHCQ.a.uop.buffer._buf, TestHCQ.b.uop.buffer._buf])

 def setUp(self):
   TestHCQ.d0.synchronize()
-  TestHCQ.a.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
-  TestHCQ.b.lazydata.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
+  TestHCQ.a.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
+  TestHCQ.b.uop.buffer.copyin(memoryview(bytearray(struct.pack("ff", 0, 0))))
   TestHCQ.d0.synchronize() # wait for copyins to complete

 # Test signals
@@ -117,7 +117,7 @@ class TestHCQ(unittest.TestCase):
 TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1

-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"

 def test_exec_2_kernels_100_times(self):
@@ -133,7 +133,7 @@ class TestHCQ(unittest.TestCase):
 q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
 TestHCQ.d0.timeline_value += 1

-val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
 assert val == 200.0, f"got val {val}"

 def test_exec_update(self):
@@ -148,9 +148,9 @@ class TestHCQ(unittest.TestCase):
 TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1

-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
 assert val == 1.0, f"got val {val}"
-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
 assert val == 0.0, f"got val {val}, should not be updated"

 def test_exec_update_fuzz(self):
@@ -192,13 +192,13 @@ class TestHCQ(unittest.TestCase):
 if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")

 TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
-  .copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8) \
+  .copy(TestHCQ.b.uop.buffer._buf.va_addr, TestHCQ.a.uop.buffer._buf.va_addr, 8) \
   .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)

 TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1

-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
 assert val == 1.0, f"got val {val}"

 def test_copy_long(self):
@@ -252,12 +252,12 @@ class TestHCQ(unittest.TestCase):
 .copy(virt_dest_addr, virt_src_addr, 8) \
 .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

-q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr})
+q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.uop.buffer._buf.va_addr})

 TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
 TestHCQ.d0.timeline_value += 1

-val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
+val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
 assert val == 1.0, f"got val {val}"

 def test_update_copy_long(self):
@@ -13,7 +13,7 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
 class TestImageCopy(unittest.TestCase):
   def test_image_copyout_1x1(self, img_type=dtypes.imagef):
     it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
-    buf = it.lazydata.buffer
+    buf = it.uop.buffer
     out = buf.as_buffer()
     np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4))
@@ -27,18 +27,18 @@ class TestImageCopy(unittest.TestCase):

 def test_image_copyout_2x3(self):
   it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize()
-  buf = it.lazydata.buffer
+  buf = it.uop.buffer
   out = buf.as_buffer()
   np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4))

 def test_image_roundtrip(self):
   sz = (4,2,4)
   it = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
-  buf = it.lazydata.buffer
+  buf = it.uop.buffer
   out = buf.as_buffer()

   it2 = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
-  buf2 = it2.lazydata.buffer
+  buf2 = it2.uop.buffer
   buf2.copyin(out)

   assert (it == it2).sum().item() == prod(sz)
@@ -49,7 +49,7 @@ class TestImageDType(unittest.TestCase):
 data = Tensor.randn(9*27*4).realize()
 tst = data.numpy()
 it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-assert isinstance(it.lazydata.base.realized.dtype, ImageDType)
+assert isinstance(it.uop.base.realized.dtype, ImageDType)
 np.testing.assert_equal(tst, it.numpy())

 @unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType
@@ -58,14 +58,14 @@ class TestImageDType(unittest.TestCase):
 tst = data.numpy()
 it = data.cast(dtypes.imagef((9,27,4))).realize()
 # the underlying UOp is identical
-self.assertIs(it.lazydata.base.realized, data.lazydata.base.realized)
+self.assertIs(it.uop.base.realized, data.uop.base.realized)
 np.testing.assert_equal(tst, it.numpy())

 def test_image_and_back_wrong_shape(self):
   data = Tensor.randn(9*27*4).realize()
   tst = data.numpy()
   it = data.cast(dtypes.imagef((9,12,4))).realize()
-  assert not isinstance(it.lazydata.base.realized.dtype, ImageDType)
+  assert not isinstance(it.uop.base.realized.dtype, ImageDType)
   np.testing.assert_equal(tst, it.numpy())

 def test_shrink_load_float(self):
@@ -77,7 +77,7 @@ class TestImageDType(unittest.TestCase):
 # NOTE: contiguous is needed otherwise this folds
 it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize()
 out = (it*2).realize()
-assert isinstance(out.lazydata.base.realized.dtype, ImageDType)
+assert isinstance(out.uop.base.realized.dtype, ImageDType)

 def test_sum(self):
   it = Tensor.rand(8).cast(dtypes.imagef((1,2,4))).realize()
@@ -98,26 +98,26 @@ class TestImageDType(unittest.TestCase):
 def test_lru_alloc(self):
   data = Tensor.randn(9*27*4).realize()
   it = data.cast(dtypes.imagef((9,27,4))).realize()
-  b1 = it.lazydata.base.realized._buf
+  b1 = it.uop.base.realized._buf
   del it
   it = data.cast(dtypes.imagef((9,27,4))).realize()
-  assert it.lazydata.base.realized._buf == b1
+  assert it.uop.base.realized._buf == b1

 def test_no_lru_alloc(self):
   data = Tensor.randn(9*27*4).realize()
   it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-  b1 = it.lazydata.base.realized._buf
+  b1 = it.uop.base.realized._buf
   del it
   it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize()
-  assert it.lazydata.base.realized._buf != b1
+  assert it.uop.base.realized._buf != b1

 def test_no_lru_alloc_dtype(self):
   data = Tensor.randn(9*27*4).realize()
   it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-  b1 = it.lazydata.base.realized._buf
+  b1 = it.uop.base.realized._buf
   del it
   it = data.cast(dtypes.imageh((9,27,4))).realize()
-  assert it.lazydata.base.realized._buf != b1
+  assert it.uop.base.realized._buf != b1

 # issue caused by: don't realize image to image casts. this is part of a larger problem
 #@unittest.expectedFailure
@@ -137,8 +137,8 @@ class TestImageDType(unittest.TestCase):
|
||||
print(lst)
|
||||
assert not np.any(np.isnan(lst))
|
||||
# NOTE: the w1 grad must realize to a seperate kernel
|
||||
assert w1.grad.lazydata.is_realized, f"never realized {w1.grad}"
|
||||
self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32)
|
||||
assert w1.grad.uop.is_realized, f"never realized {w1.grad}"
|
||||
self.assertEqual(w1.grad.uop.base.buffer.dtype, dtypes.float32)
|
||||
self.assertEqual(len(sched), 10)
|
||||
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
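These assertions all follow one pattern: after realize, `uop.base.realized` points at a device Buffer, before that it is None. A standalone sketch of the check (not from the diff):

```py
from tinygrad import Tensor

a = Tensor.randn(4, 4)
assert a.uop.base.realized is None      # nothing has run yet
a.realize()
assert a.uop.base.realized is not None  # the base now points at a device Buffer
```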
@@ -75,7 +75,7 @@ class TestLinearizer(unittest.TestCase):
lowered = [x[1] for x in lower_schedule(c.schedule())]
for ei in lowered: ei.run()
rawbufs = lowered[-1].bufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

@@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(32, dtype=dtypes.float).realize()
st_x = x.lazydata.st
st_x = x.uop.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((1, 32)).expand((32, 32))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,)))
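All the multireduce tests below lean on the same ShapeTracker trick: reshape in a unit axis, then expand it (zero stride) so a second reduce can range over the first reduce's output. A sketch of just that step, assuming the post-rename `.uop.st` accessor:

```py
from tinygrad import Tensor, dtypes

x = Tensor.randn(32, dtype=dtypes.float).realize()
st = x.uop.st                               # ShapeTracker of the realized buffer, shape (32,)
st2 = st.reshape((1, 32)).expand((32, 32))  # zero-stride broadcast: each row re-reads x
print(st2.shape)                            # (32, 32)
```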
@@ -172,7 +172,7 @@ class TestLinearizer(unittest.TestCase):
def test_mid_dim_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
st_x = x.lazydata.st
st_x = x.uop.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
@@ -232,12 +232,12 @@ class TestLinearizer(unittest.TestCase):
x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x0.uop.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x1.uop.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,)))
third_x = UOp(Ops.LOAD, dtypes.float, (g3.view(x2.lazydata.st.reshape((27, 32, 1, 1, 5))),))
third_x = UOp(Ops.LOAD, dtypes.float, (g3.view(x2.uop.st.reshape((27, 32, 1, 1, 5))),))
mul = (third_x*second_reduce)
third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 1, 5))), third_reduce))
@@ -253,7 +253,7 @@ class TestLinearizer(unittest.TestCase):
def test_double_reduce_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize()
st = x.lazydata.st
st = x.uop.st
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5)))
@@ -302,9 +302,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((27, 15, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -329,11 +329,11 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2.view(x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2.view(x_p.uop.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 32, 1))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.uop.st.reshape((4, 32, 1))),))
diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((4, 1, 1))), second_reduce))
@@ -361,9 +361,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.uop.st.reshape((27, 15, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -383,9 +383,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.uop.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.uop.st.reshape((27, 15, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -402,9 +402,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 3, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -418,9 +418,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((27, 3, 1, 5))),))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
@@ -437,9 +437,9 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.uop.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.uop.st.reshape((27, 12, 1, 5)).to_uop()))
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5)))
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
@@ -453,10 +453,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 35, 1))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 35, 1))),))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
@@ -471,10 +471,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((15, 25, 1, 35))),))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,)))
variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
@@ -491,10 +491,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.uop.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.uop.st.reshape((15, 25, 35, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
@@ -514,12 +514,12 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))),))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
# store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
# verify_lazyop(store)
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 32, 1))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((3, 27, 32, 1))),))
squares = (second_x+neg_mean)*(second_x+neg_mean)
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
@@ -535,9 +535,9 @@ class TestLinearizer(unittest.TestCase):
def test_softmax_multireduce(self):
x = Tensor.rand(4, 32).realize()
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32))),))
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((4, 1, 32,)).expand((4, 32, 32))),))
max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,)))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 32, 1,))),))
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((4, 32, 1,))),))
centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
exp_x = centered_x.alu(Ops.EXP2)
sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,)))
@@ -553,7 +553,7 @@ class TestLinearizer(unittest.TestCase):
dataset = Tensor.rand(16384, 256).realize()
idxs = Tensor([0,3,5,6]).realize()
with Context(FUSE_ARANGE=1):
sink = dataset[idxs].contiguous().kernelize().lazydata.base.src[1].arg.ast
sink = dataset[idxs].contiguous().kernelize().uop.base.src[1].arg.ast
real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index])
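The kernelize idiom above can be reproduced standalone; a sketch with smaller shapes, assuming `Context` is importable from tinygrad.helpers and the post-rename accessors:

```py
from tinygrad import Tensor
from tinygrad.helpers import Context

dataset = Tensor.rand(64, 16).realize()
idxs = Tensor([0, 3, 5, 6]).realize()
with Context(FUSE_ARANGE=1):
  t = dataset[idxs].contiguous().kernelize()
ast = t.uop.base.src[1].arg.ast  # the KERNEL's AST hangs off the ASSIGN's second source
print(ast.op)                    # Ops.SINK
```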
@@ -656,16 +656,16 @@ class TestLinearizer(unittest.TestCase):
]

g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)

x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, N, 1))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1))
@@ -683,16 +683,16 @@ class TestLinearizer(unittest.TestCase):
]

g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((1, N, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1))
sink = UOp(Ops.SINK, src=(store,))
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)

x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),))
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, 1, N)).expand((N,N,N))),))
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.uop.st.reshape((N, N, 1))),))
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,)))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,)))
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1))
@@ -711,8 +711,8 @@ class TestLinearizer(unittest.TestCase):
opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],]

wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501
ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))
ld1 = x.lazydata.st.reshape((N, N, 1))
ld0 = x.uop.st.reshape((N, 1, N)).expand((N,N,N))
ld1 = x.uop.st.reshape((N, N, 1))
ast = UOp(Ops.SINK, src=(
UOp(Ops.STORE, src=(
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
@@ -742,8 +742,8 @@ class TestLinearizer(unittest.TestCase):
ast_const(dtypes.float, 1.0, (N, 1, 1)),)),)),))
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])

ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N))
ld1 = x.lazydata.st.reshape((N, 1, N))
ld0 = x.uop.st.reshape((1, N, N)).expand((N,N,N))
ld1 = x.uop.st.reshape((N, 1, N))
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
ast = UOp(Ops.SINK, src=(
UOp(Ops.STORE, src=(
@@ -776,8 +776,8 @@ class TestLinearizer(unittest.TestCase):
# pad reduce axis
helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output])

ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N))
ld1 = x.lazydata.st.reshape((N,N,1,1))
ld0 = x.uop.st.reshape((1,1,N,N)).expand((N,N,N,N))
ld1 = x.uop.st.reshape((N,N,1,1))
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
ast = UOp(Ops.SINK, src=(
UOp(Ops.STORE, src=(
@@ -1794,7 +1794,7 @@ class TestHandCodedOpts(unittest.TestCase):

def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
assert isinstance(ast, UOp), "ast must be UOp"
inbufs = [x.lazydata.base.buffer for x in inputs]
inbufs = [x.uop.base.buffer for x in inputs]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() \
for out in ast.src]
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
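The helper's buffer gathering is the same pattern user code would use; a sketch (assuming `Buffer` is importable from tinygrad.device, with `Device` at the top level):

```py
from tinygrad import Tensor, Device
from tinygrad.device import Buffer

x = Tensor.rand(4).realize()
inbuf = x.uop.base.buffer                               # Buffer backing the realized input
outbuf = Buffer(Device.DEFAULT, 4, x.dtype).allocate()  # fresh output of matching dtype
print(inbuf.size, outbuf.size)                          # 4 4
```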
@@ -7,7 +7,7 @@ class TestMaskedShapeTracker(unittest.TestCase):
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.lazydata.st.views[0].mask is not None
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]

@@ -16,7 +16,7 @@ class TestMaskedShapeTracker(unittest.TestCase):
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.lazydata.st.views[0].mask is not None
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]

@@ -24,7 +24,7 @@ class TestMaskedShapeTracker(unittest.TestCase):
a = Tensor([1,1]).pad(((0,2),))
b = Tensor([1,1]).pad(((0,2),))
c = a+b
#assert c.lazydata.st.views[0].mask is not None
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [2.0, 2.0, 0.0, 0.0]

@@ -29,7 +29,7 @@ N = 128
def _test_allreduce(t:Tensor):
aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
ts = t.shard(devices_4, 0).realize()
b = Tensor(UOp.allreduce(ts.lazydata, Ops.ADD, ts.device))
b = Tensor(UOp.allreduce(ts.uop, Ops.ADD, ts.device))
b.realize()
return aa, b
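A rough sketch of what `_test_allreduce` drives, with hypothetical device names and assuming `UOp`/`Ops` are importable from tinygrad.ops at this point in the tree:

```py
from tinygrad import Tensor
from tinygrad.ops import Ops, UOp  # assumption: import path for UOp/Ops

ds = ("CPU", "CPU:1", "CPU:2", "CPU:3")  # hypothetical four-device tuple
ts = Tensor.ones(256, 4).contiguous().realize().shard(ds, 0).realize()
summed = Tensor(UOp.allreduce(ts.uop, Ops.ADD, ts.device))  # wrap the raw UOp back in a Tensor
summed.realize()
```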
@@ -50,7 +50,7 @@ class TestMultiTensor(unittest.TestCase):
def test_shard(self):
X = Tensor.ones(256).contiguous().realize()
X.shard_(devices_2, 0)
for lb in X.lazydata.src:
for lb in X.uop.src:
assert lb.shape == (128,)
(X + X).realize()

@@ -61,12 +61,12 @@ class TestMultiTensor(unittest.TestCase):

def test_tensor_from_multi(self):
X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
Y = Tensor(X.lazydata)
Y = Tensor(X.uop)
self.assertEqual(Y.device, Device.DEFAULT)
np.testing.assert_equal(X.numpy(), Y.numpy())

with self.assertRaises(AssertionError):
_ = Tensor(X.lazydata, dtype=dtypes.float)
_ = Tensor(X.uop, dtype=dtypes.float)

def test_sharded_arange(self):
sharded_arange = Tensor.arange(1000).shard(devices_2, 0)
@@ -247,9 +247,9 @@ class TestMultiTensor(unittest.TestCase):
shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
with Context(RING=0):
a = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device))
a = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
with Context(RING=2):
b = Tensor(UOp.allreduce(t.lazydata, Ops.ADD, t.device))
b = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
diff = a - b
mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
@@ -590,21 +590,21 @@ class TestMultiTensor(unittest.TestCase):
t4 = t2.reshape((26, 105,))

for t in [t0, t1, t2, t3, t4]:
assert t.lazydata.axis == 1
assert t.uop.axis == 1
np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())

# test shape-one axis
t5 = t4.reshape((26, 1, 105))
assert t5.lazydata.axis == 2
assert t5.uop.axis == 2
np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())

# test split and rejoin to the right and reshape to the left
t5 = t0.reshape((2, 13, 3, 5, 7))
t6 = t0.reshape((13, 2, 3, 7, 5))
t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
assert t5.lazydata.axis == 2
assert t6.lazydata.axis == 2
assert t7.lazydata.axis == 3
assert t5.uop.axis == 2
assert t6.uop.axis == 2
assert t7.uop.axis == 3
np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
@@ -616,7 +616,7 @@ class TestMultiTensor(unittest.TestCase):
@unittest.skip("no longer supports uneven shard")
def test_reshape_on_axis_uneven(self):
def reshape_helper(t0, t, t_axis):
assert t.lazydata.axis == t_axis
assert t.uop.axis == t_axis
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())

t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])
@@ -687,24 +687,24 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
self.assertEqual(t.uop.axis, t2.uop.axis)

def test_rand_like_from_alu(self):
a = Tensor.ones(4, 4).shard(devices_4, axis=0)
aa = a + a
self.assertEqual(aa.device, devices_4)
self.assertEqual(aa.lazydata.axis, 0)
self.assertEqual(aa.uop.axis, 0)
raa = aa.rand_like()
self.assertEqual(raa.device, devices_4)
self.assertEqual(raa.lazydata.axis, 0)
self.assertEqual(raa.uop.axis, 0)

b = Tensor.empty(4, 4).shard(devices_4, axis=None)
ab = a + b
self.assertEqual(ab.device, devices_4)
self.assertEqual(ab.lazydata.axis, 0)
self.assertEqual(ab.uop.axis, 0)
rab = ab.rand_like()
self.assertEqual(rab.device, devices_4)
self.assertEqual(rab.lazydata.axis, 0)
self.assertEqual(rab.uop.axis, 0)

@unittest.skip("no longer supports uneven shard")
def test_rand_like_uneven_shard(self):
@@ -713,8 +713,8 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src))
self.assertEqual(t.uop.axis, t2.uop.axis)
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.uop.src, t2.uop.src))

def test_rand_like_none_shard(self):
t = Tensor.empty((16, 16)).shard(devices_2)
@@ -722,7 +722,7 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
self.assertEqual(t.uop.axis, t2.uop.axis)

def test_rand_like_arg_dtype(self):
t = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1)
@@ -771,7 +771,7 @@ class TestMultiTensor(unittest.TestCase):
devices = (d0, d1, d2, d3)
t = Tensor.zeros(16, 16).contiguous()
t.shard_(devices, axis=0).realize()
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src])
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src])

@unittest.skip("this is unreliable on OSX")
def test_clone(self):
@@ -928,7 +928,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
to_add.append((Tensor.ones(2, 8) * i).shard(devices))

added:list[Tensor] = []
for bound, a in zip(x.lazydata.bounds, to_add):
for bound, a in zip(x.uop.bounds, to_add):
added.append(x[bound[0]:bound[1]] + a)

output = added[0].cat(*added[1:])
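How the sharded axis tracks through reshape, as the assertions above check; a sketch with two hypothetical CPU shards:

```py
from tinygrad import Tensor

devices = ("CPU", "CPU:1")  # hypothetical two-shard setup
t0 = Tensor.rand(26, 16, 7).shard(devices, axis=1)
t1 = t0.reshape((26, 1, 16, 7))
print(t0.uop.axis, t1.uop.axis)  # 1 2: a unit axis inserted on the left shifts the shard right
```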
@@ -1043,7 +1043,7 @@ class TestBatchNorm(unittest.TestCase):
bns.append(bn)

bn_ts = []
for bound, bn in zip(x.lazydata.bounds, bns):
for bound, bn in zip(x.uop.bounds, bns):
bni = bn(x[bound[0]:bound[1]])
bn_ts.append(bni)

@@ -596,9 +596,9 @@ class TestNN(unittest.TestCase):

# sharded model shards the state_dict
self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, 3)
self.assertEqual(layer.weight.uop.axis, 3)
self.assertEqual(layer.bias.device, devices)
self.assertEqual(layer.bias.lazydata.axis, None)
self.assertEqual(layer.bias.uop.axis, None)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())

@@ -634,9 +634,9 @@ class TestNN(unittest.TestCase):
load_state_dict(layer, state_dict)

self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, 3)
self.assertEqual(layer.weight.uop.axis, 3)
self.assertEqual(layer.bias.device, devices)
self.assertEqual(layer.bias.lazydata.axis, None)
self.assertEqual(layer.bias.uop.axis, None)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())

@@ -658,9 +658,9 @@ class TestNN(unittest.TestCase):

# NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this?
self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, None)
self.assertEqual(layer.weight.uop.axis, None)
self.assertEqual(layer.bias.device, devices5)
self.assertEqual(layer.bias.lazydata.axis, 0)
self.assertEqual(layer.bias.uop.axis, 0)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())

@@ -53,7 +53,7 @@ class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor_alt2(self):
print("** init")
t = Tensor.rand(10, 10).to("CPU").realize()
tensor_uop = t.lazydata
tensor_uop = t.uop
assert tensor_uop.is_realized, f"expected {tensor_uop} to be realized"
t_values = t.numpy()
# pickle
@@ -63,13 +63,13 @@ class TestPickle(unittest.TestCase):
del tensor_uop
print("** post pickle")
t2:Tensor = pickle.loads(st)
assert t2.lazydata.is_realized, f"expected {t2.lazydata} to be realized"
assert t2.uop.is_realized, f"expected {t2.uop} to be realized"
np.testing.assert_equal(t_values, t2.numpy())

# NOTE: currently Buffer exists on the uop, not tensor
def test_pickle_buffer_uop(self):
t = Tensor.arange(4).realize()
a = t.lazydata
a = t.uop
assert a.op is Ops.BUFFER
self.assertIsNotNone(buffer:=a.realized)
s = pickle.dumps(a)
@@ -98,12 +98,12 @@ class TestPickle(unittest.TestCase):
def test_pickle_buffer_view(self):
t = Tensor.arange(10, device="CPU").contiguous().realize()
vt = t[3:5].contiguous().realize()
assert hasattr(vt.lazydata.buffer, 'base')
assert hasattr(vt.uop.buffer, 'base')
ref_value = vt.tolist()
st = pickle.dumps(vt)
del t, vt
vt2 = pickle.loads(st)
assert hasattr(vt2.lazydata.buffer, 'base')
assert hasattr(vt2.uop.buffer, 'base')
assert ref_value == vt2.tolist()
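The pickle tests reduce to one invariant: a realized Tensor round-trips with its BUFFER UOp intact. A sketch on the default device:

```py
import pickle
from tinygrad import Tensor

t = Tensor.arange(4).realize()
t2 = pickle.loads(pickle.dumps(t))
assert t2.uop.is_realized           # the BUFFER UOp survives the round trip
assert t2.tolist() == [0, 1, 2, 3]
```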
def test_pickle_numpy(self):

@@ -36,12 +36,12 @@ class TestProfiler(unittest.TestCase):
si = self.b.schedule()[-1]

TestProfiler.runner = get_runner(TestProfiler.d0.device, si.ast)
TestProfiler.b.lazydata.buffer.allocate()
TestProfiler.b.uop.buffer.allocate()

def test_profile_kernel_run(self):
runner_name = TestProfiler.runner._prg.name
with helper_collect_profile(TestProfiler.d0) as profile:
TestProfiler.runner([TestProfiler.b.lazydata.buffer, TestProfiler.a.lazydata.buffer], var_vals={})
TestProfiler.runner([TestProfiler.b.uop.buffer, TestProfiler.a.uop.buffer], var_vals={})

profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
@@ -66,7 +66,7 @@ class TestProfiler(unittest.TestCase):

with helper_collect_profile(TestProfiler.d0) as profile:
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
TestProfiler.runner([buf1, TestProfiler.a.lazydata.buffer], var_vals={})
TestProfiler.runner([buf1, TestProfiler.a.uop.buffer], var_vals={})
buf1.copyout(memoryview(bytearray(buf1.nbytes)))

profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)

@@ -23,7 +23,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
uops = dedup(flatten(_recursive_add(st) for st in stores))
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [cast(UOp,x.lazydata).base.buffer for x in inputs]
inbufs = [cast(UOp,x.uop).base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render(uops)
ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
ei.exec(outbufs+inbufs)

@@ -25,7 +25,7 @@ class TestRewriteTrackedChildren(unittest.TestCase):
a = Tensor(2)
b = Tensor(3)
c = a + b
sink = c.lazydata.sink()
sink = c.uop.sink()
sink = graph_rewrite(sink, rewrite, track_children=True)

def test_simple_child(self):
@@ -35,8 +35,8 @@ class TestRewriteTrackedChildren(unittest.TestCase):
a = Tensor(2)
b = Tensor(3)
c = a + b
sink = c.lazydata
view_w_child = a.lazydata.src[0]
sink = c.uop
view_w_child = a.uop.src[0]
print([x().arg for x in view_w_child.children])
print([x.arg for x in sink.get_children_map()[view_w_child]])
self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((2,3)))
@@ -57,7 +57,7 @@ class TestRewriteTrackedChildren(unittest.TestCase):
extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
a = Tensor.empty(3, 3)
r = (a+0).sum()
graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True)
graph_rewrite(r.uop, merge_views+sym+extra, track_children=True)

if __name__ == '__main__':
unittest.main()
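The entry point for these rewrites is now the Tensor's `.uop`; a minimal sketch of wrapping it in a SINK for `graph_rewrite`:

```py
from tinygrad import Tensor

c = Tensor(2) + Tensor(3)
sink = c.uop.sink()  # wrap the expression UOp in a SINK, the root for graph_rewrite
print(sink.op)       # Ops.SINK
```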
@@ -188,7 +188,7 @@ class TestSchedule(unittest.TestCase):
# pre-realize shared seed
Tensor._device_rng_counters[x.device].realize()
# run custom kernelized kernel
sched_sink = graph_rewrite(x.lazydata, create_kernels, ctx={u:None for u in x.lazydata.toposort() if u.op is Ops.COPY}, bottom_up=True)
sched_sink = graph_rewrite(x.uop, create_kernels, ctx={u:None for u in x.uop.toposort() if u.op is Ops.COPY}, bottom_up=True)
y = Tensor(graph_rewrite(sched_sink, create_ast, bottom_up=True))
run_schedule(check_schedule(y, 1))
# compare against reference
@@ -198,15 +198,15 @@ class TestSchedule(unittest.TestCase):
def test_empty_is_not_realized(self):
a = Tensor.empty(10)
child = a+2
assert not a.lazydata.is_realized
assert not a.uop.is_realized
child.realize()
assert a.lazydata.is_realized
assert a.uop.is_realized
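The lazy-allocation behavior checked here, as a standalone sketch mirroring the test:

```py
from tinygrad import Tensor

a = Tensor.empty(10)
child = a + 2
assert not a.uop.is_realized  # nothing has run yet
child.realize()               # realizing the child allocates and realizes a
assert a.uop.is_realized
```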
# NOTE: because empty does not have an ExecItem, if realize is called on a childless empty it never gets allocated.
def test_childless_empty_never_allocates(self):
a = Tensor.empty(10)
a.realize()
assert not a.lazydata.is_realized
assert not a.uop.is_realized

def test_simplify_padded_const(self):
a = Tensor.empty(1022).cummax(axis=0)
@@ -404,7 +404,7 @@ class TestSchedule(unittest.TestCase):
sched = check_schedule([a, b], 1)
run_schedule(sched)
# a and b share the same underlying device memory
self.assertIs(a.lazydata.realized, b.lazydata.realized)
self.assertIs(a.uop.realized, b.uop.realized)

def test_clone_doesnt_dedup(self):
src = Tensor.ones(4).contiguous().realize()
@@ -413,7 +413,7 @@ class TestSchedule(unittest.TestCase):
sched = check_schedule([a, b], 2, filter_sink=False)
run_schedule(sched)
# a and b are assigned to different device Buffers
self.assertIsNot(a.lazydata.realized, b.lazydata.realized)
self.assertIsNot(a.uop.realized, b.uop.realized)

# EMPTY is assigned to a unique device Buffer

@@ -422,7 +422,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty((4,))
# NOTE: empty does not have any schedule
check_schedule([a, b], 0, filter_sink=False)
self.assertIsNot(a.lazydata.buffer, b.lazydata.buffer)
self.assertIsNot(a.uop.buffer, b.uop.buffer)

def test_dedup_outputs(self):
a = Tensor.full((4, 4), 1.).contiguous().realize()
@@ -712,13 +712,13 @@ class TestSchedule(unittest.TestCase):
prev_a = (a+1).contiguous()
a.assign(Tensor([2]))
a.kernelize(prev_a)
assert prev_a.lazydata in a.lazydata.src, "contiguous usage must run before assign"
assert prev_a.uop in a.uop.src, "contiguous usage must run before assign"
self.assertEqual((prev_a+a*3).item(), 1+2*3)

def test_multioutput_ast(self):
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata
c = Tensor.arange(4).realize().lazydata
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
c = Tensor.arange(4).realize().uop
kernel = UOp(Ops.KERNEL, src=(a, b, c.base), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
assert all(s.op is Ops.BUFFER for s in kernel.src), f"views are not allowed here {kernel}"
kernel = graph_rewrite(kernel, create_ast)
@@ -1623,7 +1623,7 @@ class TestSchedule(unittest.TestCase):

@unittest.skip("disabling subbuffer manually isn't supported anymore")
def test_bitcast_disable_subbufer(self):
x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().uop)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
b = x.cast(dtypes.int32, True, allow_buffer_view=False)
b = a.alu(Ops.ADD, b)
@@ -1660,11 +1660,11 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(GlobalCounters.mem_used-base, 1024)

def test_const_schedule(self):
constv = Tensor.empty(2, 2).lazydata.const_like(10)
constv = Tensor.empty(2, 2).uop.const_like(10)
check_schedule(constv, 0)

def test_const_schedule_contig(self):
constv = Tensor.empty(2, 2).lazydata.const_like(10).contiguous()
constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
check_schedule(constv, 1)
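A sketch of the same CONST scheduling rule, assuming `Tensor` accepts a raw UOp the way the multi-device tests use it:

```py
from tinygrad import Tensor

constv = Tensor.empty(2, 2).uop.const_like(10)
assert len(Tensor(constv).schedule()) == 0         # a bare CONST schedules nothing
constv_contig = Tensor.empty(2, 2).uop.const_like(10).contiguous()
assert len(Tensor(constv_contig).schedule()) == 1  # materializing it takes one kernel
```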
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
@@ -1676,8 +1676,8 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
self.assertIsInstance(out.dtype, ImageDType)
self.assertIsNotNone(out.lazydata.base.realized)
self.assertIsInstance(out.lazydata.base.realized.dtype, ImageDType)
self.assertIsNotNone(out.uop.base.realized)
self.assertIsInstance(out.uop.base.realized.dtype, ImageDType)

def _test_fusion(self, shapes, f, cnt):
with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes]
@@ -1707,9 +1707,9 @@ class TestSchedule(unittest.TestCase):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
casted_view.realize()
self.assertEqual(casted_view.lazydata.base.realized.size, 4)
self.assertEqual(casted_view.uop.base.realized.size, 4)
realized_view = casted_view.contiguous().realize()
self.assertEqual(realized_view.lazydata.base.realized.size, 8)
self.assertEqual(realized_view.uop.base.realized.size, 8)
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])

# NOTE: we only reorder CAST if it's an EXPAND
@@ -1717,16 +1717,16 @@ class TestSchedule(unittest.TestCase):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
casted_view.realize()
self.assertEqual(casted_view.lazydata.base.realized.size, 2)
self.assertEqual(casted_view.uop.base.realized.size, 2)
realized_view = casted_view.contiguous().realize()
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
self.assertEqual(realized_view.uop.base.realized.size, 2)
self.assertListEqual(realized_view.tolist(), [[0, 1]])

def test_cast_const_view(self):
a = Tensor.ones((4, 4), dtype=dtypes.float32)
casted_view = a.cast(dtypes.int32)
run_schedule(check_schedule(casted_view, 0))
self.assertIsNone(casted_view.lazydata.base.realized)
self.assertIsNone(casted_view.uop.base.realized)
realized_const_view = casted_view.contiguous()
run_schedule(check_schedule(realized_const_view, 1))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
@@ -1998,7 +1998,7 @@ class TestIndexing(unittest.TestCase):
def test_recursive_swizzle(self):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1)))
new_uop = swizzle_rewrite(a.uop.reshape((4, 1)))
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)

@@ -2012,13 +2012,13 @@ class TestIndexing(unittest.TestCase):
a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (1, 32))
self.assertEqual(a.uop.shape, (1, 32))

def test_no_reshape_reduceop(self):
a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (32,))
self.assertEqual(a.uop.shape, (32,))

@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
@@ -2161,7 +2161,7 @@ class TestView(unittest.TestCase):
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(store_val(sched[-1]).op, Ops.LOAD)
self.assertEqual(store_val(sched[-1]).st_arg, b.lazydata.st)
self.assertEqual(store_val(sched[-1]).st_arg, b.uop.st)
run_schedule(sched)
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])

@@ -2175,9 +2175,9 @@ class TestView(unittest.TestCase):
late_mul = a*bv
check_schedule(late_mul, 0)
# the arange doesn't realize
self.assertIsNone(b.lazydata.base.realized)
self.assertIsNone(b.uop.base.realized)
# mul doesn't realize
self.assertIsNone(late_mul.lazydata.base.realized)
self.assertIsNone(late_mul.uop.base.realized)
self.assertEqual(late_mul.tolist(), [0, 0])

# SINK has two branches:
@@ -2192,13 +2192,13 @@ class TestView(unittest.TestCase):
other_child = b+2
s = check_schedule([late_mul, other_child], 2)
# the arange becomes a BUFFER
self.assertIs(b.lazydata.base.op, Ops.BUFFER)
self.assertIs(b.uop.base.op, Ops.BUFFER)
# mul still collapses
self.assertIs(late_mul.lazydata.base.op, Ops.CONST)
self.assertIs(late_mul.uop.base.op, Ops.CONST)
run_schedule(s)
self.assertEqual(other_child.tolist(), [2, 3, 4])

def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, merge_views+symbolic_simple)
def tensor_rewrite(t) -> UOp: return graph_rewrite(t.uop.base, merge_views+symbolic_simple)

class TestSimplifier(unittest.TestCase):
def test_sink_childless_const(self):
x = Tensor(0)
@@ -2222,8 +2222,8 @@ class TestSimplifier(unittest.TestCase):
a = Tensor.empty(4, 4, dtype=dtypes.int)
sink = tensor_rewrite(a*0)
assert UPat(Ops.CONST, arg=0).match(sink, {})
self.assertIs(tensor_rewrite(a*1).base, a.lazydata.base)
self.assertIs(tensor_rewrite(a+0).base, a.lazydata.base)
self.assertIs(tensor_rewrite(a*1).base, a.uop.base)
self.assertIs(tensor_rewrite(a+0).base, a.uop.base)

def test_cast_folding(self):
a = Tensor(1.0).cast(dtypes.int)
@@ -2257,14 +2257,14 @@ class TestConst(unittest.TestCase):

def test_tensor_const(self):
a = Tensor(1)
print(a.lazydata)
self.assertTrue(tensor_const_pm.rewrite(a.lazydata))
print(a.uop)
self.assertTrue(tensor_const_pm.rewrite(a.uop))

def test_tensor_variable(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)
print(a.lazydata)
self.assertTrue(tensor_const_pm.rewrite(a.lazydata))
print(a.uop)
self.assertTrue(tensor_const_pm.rewrite(a.uop))

def test_const_schedule(self):
a = Tensor.ones((4, 4))
@@ -2311,7 +2311,7 @@ class TestConst(unittest.TestCase):
sched = add.schedule()
self.assertEqual(len(sched), 0)
# b+0 and b share the same underlying device memory
self.assertIs(add.lazydata.buffer, b.lazydata.buffer)
self.assertIs(add.uop.buffer, b.uop.buffer)
self.assertListEqual(add.tolist(), [2, 2, 2, 2])

def test_src_masked_const_folding(self):
@@ -2324,7 +2324,7 @@ class TestConst(unittest.TestCase):
self.assertEqual(len(sched), 1)
run_schedule(sched)
# add gets assigned to a new buffer
self.assertIsNot(add.lazydata.base.realized, b.lazydata.base.realized)
self.assertIsNot(add.uop.base.realized, b.uop.base.realized)
self.assertListEqual(add.tolist(), [4, 2, 2, 2, 2, 4])

# ** part 3: Tensor variable bindings
@@ -2359,15 +2359,15 @@ class TestCopyFolding(unittest.TestCase):
self.assertListEqual(b.tolist(), [0, 0, 0])

def test_alu_after_copy(self):
a = Tensor.ones((4,)).to("CPU").lazydata
b = Tensor.empty(4, device="CPU").lazydata
a = Tensor.ones((4,)).to("CPU").uop
b = Tensor.empty(4, device="CPU").uop
add = a+b
add = schedule_graph_rewrite(add)
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"

@unittest.skip("this is just clone now")
def test_copy_to_same_device(self):
a = Tensor.empty(4).lazydata
a = Tensor.empty(4).uop
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
@@ -2377,7 +2377,7 @@ class TestCopyFolding(unittest.TestCase):

@unittest.skip("this is just clone now")
def test_copy_to_same_device_alt(self):
a = Tensor.empty(4, 4).lazydata
a = Tensor.empty(4, 4).uop
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
@@ -2394,8 +2394,8 @@ class TestCopyFolding(unittest.TestCase):
b = view.clone()
# NOTE: this was sort of a bug making this 2
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 2)
self.assertEqual(b.lazydata.size, 2)
self.assertEqual(b.uop.base.buffer.size, 2)
self.assertEqual(b.uop.size, 2)
self.assertListEqual(b.tolist(), [0, 1])

def test_expanded_copy(self):
@@ -2403,8 +2403,8 @@ class TestCopyFolding(unittest.TestCase):
view = a.reshape(2, 1).expand(2, 2)
b = view.clone()
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 4)
self.assertEqual(b.lazydata.size, 4)
self.assertEqual(b.uop.base.buffer.size, 4)
self.assertEqual(b.uop.size, 4)
self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])

def test_permuted_copy(self):
@@ -2414,7 +2414,7 @@ class TestCopyFolding(unittest.TestCase):
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

def test_permute_on_disk(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
b = a.reshape(2, 2).permute(1, 0).to("CPU")
b.realize()
@@ -2430,7 +2430,7 @@ class TestCopyFolding(unittest.TestCase):
# TODO: this is wrong because of the permute
@unittest.expectedFailure
def test_permute_after_shrink_on_disk(self):
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
b.realize()
@@ -2442,13 +2442,13 @@ class TestTensorUOpSpec(unittest.TestCase):
unsafe_push_views = PatternMatcher([
(UPat.cvar("root").view(name="view"), lambda root,view: root.replace(src=tuple(x.view(view.st) for x in root.src))),
])
a.lazydata = graph_rewrite(a.lazydata.sink(), merge_views+merge_views+unsafe_push_views)
a.uop = graph_rewrite(a.uop.sink(), merge_views+merge_views+unsafe_push_views)
with self.assertRaisesRegex(RuntimeError, "UOp verification failed"):
a.schedule()

def test_expanded_const_ok(self):
a = Tensor.ones((4, 4))
t = graph_rewrite(a.lazydata.sink(), merge_views+merge_views)
t = graph_rewrite(a.uop.sink(), merge_views+merge_views)
create_schedule_with_vars(t)

# NOTE: changing symbolic CONST VIEWs is not allowed
@@ -2456,69 +2456,69 @@ class TestTensorUOpSpec(unittest.TestCase):
def test_symbolic_shape_ok(self):
a = Tensor.ones(4)
vi = UOp.variable("i", 1, 10).bind(4)
a.lazydata = graph_rewrite(a.reshape(vi).sum().lazydata, merge_views+merge_views)
a.uop = graph_rewrite(a.reshape(vi).sum().uop, merge_views+merge_views)
a.schedule()
class TestBufferUOp(unittest.TestCase):
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
def test_buffer_has_buffer(self):
buf = Tensor.empty(10)
self.assertIsNotNone(buf.lazydata.buffer)
self.assertEqual(buf.lazydata.st, ShapeTracker.from_shape((10,)))
self.assertIsNotNone(buf.uop.buffer)
self.assertEqual(buf.uop.st, ShapeTracker.from_shape((10,)))
# the device Buffer remains unallocated until we run the schedule
self.assertFalse(buf.lazydata.buffer.is_allocated())
self.assertFalse(buf.uop.buffer.is_allocated())
add = buf+1
sched = add.schedule()
self.assertFalse(buf.lazydata.buffer.is_allocated())
self.assertFalse(buf.uop.buffer.is_allocated())
run_schedule(sched)
self.assertTrue(buf.lazydata.buffer.is_allocated())
self.assertTrue(buf.uop.buffer.is_allocated())
|
||||
|
||||
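# [illustrative sketch, not part of this commit] The lazy-allocation behavior tested
# above, as a standalone script. Assumes a stock tinygrad install; the calls used
# here all appear in the test itself.
from tinygrad import Tensor
t = Tensor.empty(10)                    # creates a BUFFER UOp, but no device memory yet
assert not t.uop.buffer.is_allocated()
(t + 1).realize()                       # running the schedule allocates the input Buffer
assert t.uop.buffer.is_allocated()
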
def test_buffer_has_unique_buffer(self):
buf = Tensor.empty(10)
buf1 = buf.lazydata.buffer
buf2 = buf.lazydata.buffer
buf1 = buf.uop.buffer
buf2 = buf.uop.buffer
self.assertIs(buf1, buf2)

# we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
def test_buffer_view_allowed(self):
add = Tensor.empty(1, 1)+Tensor.empty(1, 1)
add.realize()
self.assertIsNotNone(add.lazydata.buffer)
self.assertEqual(add.lazydata.shape, (1, 1))
self.assertIsNotNone(add.uop.buffer)
self.assertEqual(add.uop.shape, (1, 1))

def test_buffer_view_not_allowed(self):
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
merged = graph_rewrite(permuted_view.lazydata, merge_views)
merged = graph_rewrite(permuted_view.uop, merge_views)
with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
merged.buffer # cannot access Buffer of a non contiguous VIEW

def test_buffer_only_after_realize(self):
a = Tensor([1])+Tensor([2])
# accessing realized will return None
self.assertIsNone(a.lazydata.realized)
self.assertIsNone(a.uop.realized)
# accessing Buffer will assert
with self.assertRaisesRegex(AssertionError, "must be BUFFER"):
a.lazydata.buffer # there is no BUFFER on an unrealized ADD
a.uop.buffer # there is no BUFFER on an unrealized ADD
# Buffer only exists once we realize it
a.realize()
self.assertIsNotNone(a.lazydata.buffer)
self.assertIsNotNone(a.uop.buffer)

def test_const_does_not_realize(self):
a = Tensor(1)+Tensor(2)
run_schedule(check_schedule(a, 0))
self.assertIsNone(a.lazydata.base.realized)
self.assertIsNone(a.uop.base.realized)

def test_var_does_not_realize(self):
a = Tensor(UOp.variable("a", 0, 10).bind(1))
run_schedule(check_schedule(a, 0))
self.assertIsNone(a.lazydata.base.realized)
self.assertIsNone(a.uop.base.realized)

def test_view_does_not_realize(self):
a = Tensor.randn(1, 4).expand(4, 4)
a.realize()
self.assertEqual(a.lazydata.base.realized.size, 4)
self.assertEqual(a.uop.base.realized.size, 4)
a2 = a.contiguous().realize()
self.assertEqual(a2.lazydata.base.realized.size, 16)
self.assertEqual(a2.uop.base.realized.size, 16)

class TestContiguous(unittest.TestCase):
def test_contiguous_buffer(self):
@@ -2550,13 +2550,13 @@ class TestContiguous(unittest.TestCase):
a = Tensor.empty(4)
b = a.expand((4, 4))
check_schedule(b, 0)
self.assertEqual(b.lazydata.base.buffer.size, 4)
self.assertEqual(b.uop.base.buffer.size, 4)

def test_contiguous_view_realizes(self):
a = Tensor.empty(4)
b = a.expand((4, 4)).contiguous()
check_schedule(b, 1)
self.assertEqual(b.lazydata.base.buffer.size, 16)
self.assertEqual(b.uop.base.buffer.size, 16)

class TestUOpBecome(unittest.TestCase):
# the simplest case, if we create a new BUFFER for this tensor UOp
@@ -2566,21 +2566,21 @@ class TestUOpBecome(unittest.TestCase):
add = a+b
check_schedule(add, 1)
# NOTE: realized base is always a flat buffer
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
assert UPat(Ops.BUFFER).match(add.uop.base, {})
# the Tensor UOp can optionally stack a VIEW on top of the BUFFER, in this case to preserve the (4, 4) shape of the tensor
assert add.lazydata is not add.lazydata.base
self.assertEqual(add.lazydata.size, 16)
self.assertEqual(add.lazydata.shape, (4, 4))
assert add.uop is not add.uop.base
self.assertEqual(add.uop.size, 16)
self.assertEqual(add.uop.shape, (4, 4))

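# [illustrative sketch, not part of this commit] How these UPat assertions work.
# The import path below is an assumption based on this commit's tree layout; older
# layouts exported Ops/UPat from tinygrad.ops.
from tinygrad import Tensor
from tinygrad.uop.ops import Ops, UPat
t = (Tensor.empty(4) + Tensor.empty(4)).realize()
assert UPat(Ops.BUFFER).match(t.uop.base, {})     # a realized base is a flat BUFFER
assert not UPat(Ops.CONST).match(t.uop.base, {})  # match is falsy when the pattern misses
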
def test_new_buffer_view(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = (a+b).reshape(8, 2)
check_schedule(add, 1)
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
assert UPat(Ops.BUFFER).match(add.uop.base, {})
# the shape is preserved in the becomes_map.
self.assertEqual(add.lazydata.shape, (8, 2))
assert add.lazydata is not add.lazydata.base
self.assertEqual(add.uop.shape, (8, 2))
assert add.uop is not add.uop.base

def test_new_flat_buffer(self):
a = Tensor.empty(4,)
@@ -2588,7 +2588,7 @@ class TestUOpBecome(unittest.TestCase):
add = a+b
check_schedule(add, 1)
# BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER
assert UPat(Ops.BUFFER).match(add.lazydata, {})
assert UPat(Ops.BUFFER).match(add.uop, {})

# sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer

@@ -2597,8 +2597,8 @@ class TestUOpBecome(unittest.TestCase):
a = Tensor.empty(4, 1)
b = a.expand(4, 4).reciprocal()
check_schedule(b, 1)
self.assertEqual(b.lazydata.base.buffer.size, 16)
self.assertEqual(b.lazydata.st, ShapeTracker.from_shape((4, 4)))
self.assertEqual(b.uop.base.buffer.size, 16)
self.assertEqual(b.uop.st, ShapeTracker.from_shape((4, 4)))

def test_reorder_expand_alt(self):
x = Tensor.empty(4, 1)
@@ -2610,95 +2610,95 @@ class TestUOpBecome(unittest.TestCase):
def test_become_existing_buffer(self):
a = Tensor.empty(4, 4)
b = a*1
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling merges all MovementOps into a single VIEW
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.uop, {}) # scheduling merges all MovementOps into a single VIEW
self.assertIs(a.uop.base.buffer, b.uop.base.buffer)

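# [illustrative sketch, not part of this commit] The same "becoming" behavior using
# only public Tensor APIs (check_schedule above is a test helper):
from tinygrad import Tensor
a = Tensor.empty(4, 4)
b = a * 1                  # a no-op mul
before = b.uop
b.schedule()               # zero kernels; scheduling rewrites b to view a's BUFFER
assert b.uop is not before
assert a.uop.base.buffer is b.uop.base.buffer
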
def test_become_buf_with_mops(self):
a = Tensor.empty(2, 4, 2)
noop = a.shrink(((1, 2), (0, 4), (0, 2))).reshape(4, 2)*1+0
# before realizing, this tensor is base
assert noop.lazydata is noop.lazydata.base
assert noop.uop is noop.uop.base
noop.realize()
# it becomes a realized view after realize
assert noop.lazydata is not noop.lazydata.base
assert noop.lazydata.base.op is Ops.BUFFER
assert noop.uop is not noop.uop.base
assert noop.uop.base.op is Ops.BUFFER
late_add = noop+2
late_add.realize()

def test_become_const_in_base(self):
a = Tensor.empty(4)
b = a*0
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER)

def test_become_const_in_view(self):
# if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged.
add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
b = add.shrink(((0, 1), (0, 0)))
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata, {})
assert UPat(Ops.CONST, arg=0).match(b.uop, {})
self.assertEqual(b.shape, (1, 0))
# the base is untouched.
assert UPat(Ops.ADD).match(add.lazydata, {})
assert UPat(Ops.ADD).match(add.uop, {})

def test_become_const_from_const(self):
const_add = Tensor(1)+Tensor(2)
assert UPat(Ops.ADD).match(const_add.lazydata, {})
assert UPat(Ops.ADD).match(const_add.uop, {})
check_schedule(const_add, 0)
assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {})

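# [illustrative sketch, not part of this commit] The constant-folding path end to end:
from tinygrad import Tensor
c = Tensor(1) + Tensor(2)
assert len(c.schedule()) == 0   # 1+2 folds to CONST(3), so nothing is scheduled
assert c.item() == 3            # reading it back still requires no kernel for the add
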
# tensors can become another realized tensor source
def test_become_existing_buf_simple(self):
a = Tensor.empty(4, 4)
b = a+0
check_schedule(b, 0)
assert b.lazydata.base.op is Ops.BUFFER
self.assertIs(a.lazydata, b.lazydata)
assert b.uop.base.op is Ops.BUFFER
self.assertIs(a.uop, b.uop)

# they can also chain other movement ops on top of the tensor source
def test_become_existing_buf_view(self):
a = Tensor.empty(4, 4)
b = a.permute((1, 0))+0
check_schedule(b, 0)
self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).st)
self.assertEqual(b.uop.st, a.uop.permute((1, 0)).st)

def test_become_existing_buf_view_alt(self):
a = Tensor.empty(4, 4)
b = a.permute((1, 0)).reshape((8, 2))+0
check_schedule(b, 0)
self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)
self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)

# they can also have other base parents that were simplified; in that case we just backtrack to the chained mops
def test_become_existing_buf_complex(self):
a = Tensor.empty(4, 4)
b = (a.permute((1, 0))+0).reshape((8, 2))+0
check_schedule(b, 0)
self.assertEqual(b.lazydata.st, a.lazydata.permute((1, 0)).reshape((8, 2)).st)
assert b.lazydata.base.op is Ops.BUFFER
self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
assert b.uop.base.op is Ops.BUFFER

def test_become_multiple_choices(self):
a = Tensor.empty(16)
b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
c = (a.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
check_schedule([b, c], 0)
assert all_same([x.lazydata.base.realized for x in [a,b,c]])
assert all_same([x.uop.base.realized for x in [a,b,c]])
# these movement ops result in the same ShapeTracker
assert b.lazydata.st == c.lazydata.st
assert b.lazydata is c.lazydata
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})
assert b.uop.st == c.uop.st
assert b.uop is c.uop
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.uop, {})

def test_setitem_becomes_subbuffer(self):
a = Tensor.full((4,), 2.).contiguous().realize()
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
b.realize()
assert a.lazydata.is_realized
assert a.lazydata.buffer._base is None
assert a.uop.is_realized
assert a.uop.buffer._base is None
# b is a subbuffer of a
assert b.lazydata.op is Ops.BUFFER_VIEW
assert b.lazydata.src[0] is a.lazydata
assert b.uop.op is Ops.BUFFER_VIEW
assert b.uop.src[0] is a.uop

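# [illustrative sketch, not part of this commit] The subbuffer relationship asserted
# above; the Ops import path is assumed from this commit's layout.
from tinygrad import Tensor
from tinygrad.uop.ops import Ops
a = Tensor.full((4,), 2.).contiguous().realize()
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
b.realize()
assert b.uop.op is Ops.BUFFER_VIEW          # b aliases the first half of a's Buffer
assert a.tolist() == [1.0, 1.0, 2.0, 2.0]   # so the assign is visible through a
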
def test_setitem_offset(self):
a = Tensor.full((16,), 0.).contiguous().realize()

@@ -44,7 +44,7 @@ class TestTimeLinearizer(unittest.TestCase):

# Ensure that the kernel count is not incremented by time_linearizer when clearing l2
def test_kernel_count(self):
ast = Tensor.zeros(16).contiguous().kernelize().lazydata.src[1].arg.ast
ast = Tensor.zeros(16).contiguous().kernelize().uop.src[1].arg.ast
lin = Kernel(ast)
bufs = bufs_from_lin(lin)


@@ -50,7 +50,7 @@ class TestSetitem(unittest.TestCase):

def test_setitem_into_noncontiguous(self):
t = Tensor.ones(4)
self.assertFalse(t.lazydata.st.contiguous)
self.assertFalse(t.uop.st.contiguous)
with self.assertRaises(RuntimeError): t[1] = 5

@unittest.skip("TODO: flaky")

@@ -49,11 +49,11 @@ class TestSymbolic(unittest.TestCase):
j = Variable("j", 1, 5).bind(3)
k = Variable("k", 1, 5).bind(3)
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
st = t.uop.st
self.assert_tuple_equal(st.shape, (i+j+k, 4))
assert st.real_strides() == (4, 1)
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
st = t.uop.st
self.assert_tuple_equal(st.shape, (2*i+3, 3))
assert st.real_strides() == (3, 1)

@@ -62,7 +62,7 @@ class TestSymbolic(unittest.TestCase):
j = Variable("j", 1, 5).bind(4)
k = Variable("k", 1, 5).bind(4)
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
st = t.uop.st
self.assert_tuple_equal(st.shape, (3, i+j+k))
self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))

@@ -113,7 +113,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(3)
t = Tensor.rand(3, 4).reshape(bv, 4)
unbound_st, var_val = t.lazydata.st.unbind()
unbound_st, var_val = t.uop.st.unbind()
assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),))
assert var_val == {v: 3}

@@ -121,7 +121,7 @@ class TestShapeTrackerUnbind(unittest.TestCase):
v = Variable("v", 1, 100)
bv = Variable("v", 1, 100).bind(2)
t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
unbound_st, var_val = t.lazydata.st.unbind()
unbound_st, var_val = t.uop.st.unbind()
assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
assert var_val == {v: 2}

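# [illustrative sketch, not part of this commit] bind/unbind in one place; Variable
# is UOp.variable, as elsewhere in this diff, and the import path is assumed.
from tinygrad import Tensor
from tinygrad.uop.ops import UOp
v = UOp.variable("v", 1, 100)
t = Tensor.rand(3, 4).reshape(v.bind(3), 4)  # bind fixes a concrete value in range
st, var_val = t.uop.st.unbind()              # strip the binding from the ShapeTracker
assert var_val == {v: 3}                     # ...and recover it as a dict
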
@@ -180,8 +180,8 @@ class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).reshape(3, vi)
assert t.shape == (3, vi)
assert not t.lazydata.st.contiguous
assert len(t.lazydata.st.views) == 1
assert not t.uop.st.contiguous
assert len(t.uop.st.views) == 1

def test_reshape_not_allowed(self):
vi = Variable("i", 1, 5).bind(4)
@@ -195,12 +195,12 @@ class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
def test_reshape_from_padded(self):
vi = Variable("i", 1, 5).bind(4)
t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3)))
st = t.lazydata.st
st = t.uop.st
assert len(st.views) == 1
view = st.views[0]
assert view.shape == (4, 3, 2)
t = t.reshape(vi, 3, 2)
st2 = t.lazydata.st
st2 = t.uop.st
assert len(st2.views) == 1
view2 = st2.views[0]
# check only shape changed. strides, offset, mask, contiguous remained the same
@@ -237,7 +237,7 @@ class TestSymbolicPad(unittest.TestCase):
v = Variable("v", 1, 100).bind(5)
t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9)
assert t.shape == (9,)
st = t.lazydata.st
st = t.uop.st
print(st)

if __name__ == '__main__':

@@ -565,17 +565,17 @@ class TestZeroShapeTensor(unittest.TestCase):
t = Tensor.empty(3, 2, 0)
assert t.shape == (3, 2, 0)
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1
assert t.lazydata.st.real_strides() == (0, 0, 0)
assert t.uop.st.real_strides() == (0, 0, 0)

t = Tensor.empty(3, 0, 2)
assert t.shape == (3, 0, 2)
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1
assert t.lazydata.st.real_strides() == (0, 0, 0)
assert t.uop.st.real_strides() == (0, 0, 0)

t = Tensor.empty(0, 0, 0)
assert t.shape == (0, 0, 0)
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1
assert t.lazydata.st.real_strides() == (0, 0, 0)
assert t.uop.st.real_strides() == (0, 0, 0)

def test_rand(self):
t = Tensor.rand(3, 2, 0)
@@ -690,24 +690,24 @@ class TestZeroShapeTensor(unittest.TestCase):
a = Tensor.rand(16, 16).realize()
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)

a = Tensor.rand(16, 16).mul(5.0).add(5.0).realize()
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)

def test_clone_with_shrink(self):
a = Tensor.rand(16, 16)
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)

def test_clone_with_shrink_realized(self):
a = Tensor.rand(16, 16).realize()
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
self.assertIsNot(a.uop.base.buffer, b.uop.base.buffer)

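# [illustrative sketch, not part of this commit] clone() always lands in fresh storage:
from tinygrad import Tensor
a = Tensor.rand(4, 4).realize()
b = a.clone().realize()
assert a.uop.base.buffer is not b.uop.base.buffer  # distinct Buffers, equal contents
assert a.tolist() == b.tolist()
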
def test_clone_with_grad(self):
a = Tensor.rand(16, 16, requires_grad=True)
@@ -780,7 +780,7 @@ class TestTensorMetadata(unittest.TestCase):
@unittest.skip("why would this be true?")
def test_exclude_noop_metadata(self):
a = Tensor.rand(4, 4)*1
self.assertEqual(a.lazydata.metadata[0].name, "__mul__")
self.assertEqual(a.uop.metadata[0].name, "__mul__")
k = a.schedule()[-1]
self.assertEqual([m.name for m in k.metadata], ["rand"])

@@ -797,7 +797,7 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
self.assertEqual(out.lazydata.metadata[0].name, "matmul")
self.assertEqual(out.uop.metadata[0].name, "matmul")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "matmul")
@@ -805,7 +805,7 @@ class TestTensorMetadata(unittest.TestCase):
def test_relu(self):
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
self.assertEqual(out.lazydata.metadata[0].name, "relu")
self.assertEqual(out.uop.metadata[0].name, "relu")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "relu")
@@ -814,9 +814,9 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.lazydata.metadata[0].name, "__mul__")
self.assertEqual(out.lazydata.src[0].metadata[0].name, "relu")
self.assertEqual(out.lazydata.src[1].metadata[0].name, "sigmoid")
self.assertEqual(out.uop.metadata[0].name, "__mul__")
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@@ -825,12 +825,12 @@ class TestTensorMetadata(unittest.TestCase):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.lazydata.metadata[0].name, "sum")
self.assertEqual(out.uop.metadata[0].name, "sum")
out.backward()
self.assertEqual(x.grad.lazydata.metadata[0].name, "relu")
self.assertTrue(x.grad.lazydata.metadata[0].backward)
self.assertEqual(y.grad.lazydata.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.lazydata.metadata[0].backward)
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
self.assertTrue(x.grad.uop.metadata[0].backward)
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.uop.metadata[0].backward)
si = Tensor.schedule(out, x.grad, y.grad)[-1]
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})

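# [illustrative sketch, not part of this commit] The metadata plumbing these tests
# rely on: the name is recorded when the Tensor op is applied, then carried onto the
# ScheduleItem.
from tinygrad import Tensor
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
assert out.uop.metadata[0].name == "relu"
assert out.schedule()[-1].metadata[0].name == "relu"
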
@@ -9,7 +9,7 @@ class TestTensorUOp(unittest.TestCase):
def test_fromcpu_shape_tracker(self):
def helper(a: np.ndarray):
print(a.shape, a.strides, a.flags.c_contiguous)
b = Tensor(a).lazydata
b = Tensor(a).uop
#assert b.st.contiguous == a.flags.c_contiguous
assert b.st.shape == a.shape
np.testing.assert_equal(a, Tensor(b).numpy())
@@ -60,11 +60,11 @@ class TestTensorUOp(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), np.concatenate((a.numpy(), b.numpy()), axis=1))

def test_const_dtype(self):
lb: UOp = Tensor([1], dtype=dtypes.int).lazydata
lb: UOp = Tensor([1], dtype=dtypes.int).uop
assert lb.const_like(1).base.arg == 1
assert type(lb.const_like(1).base.arg) is int

lb: UOp = Tensor([1], dtype=dtypes.float).lazydata
lb: UOp = Tensor([1], dtype=dtypes.float).uop
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float


@@ -476,7 +476,7 @@ class TestUOpStr(unittest.TestCase):
assert str(eval(str(device))) == str(device)

def test_reduceop_arg(self):
sum_uop = Tensor.empty(32, 32).sum().lazydata
sum_uop = Tensor.empty(32, 32).sum().uop
assert str(eval(str(sum_uop))) == str(sum_uop)

@unittest.skip("uop no longer has order like this")
@@ -549,9 +549,9 @@ class TestShapeSpec(unittest.TestCase):
# ** CONST is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND

def test_expanded_const(self):
a = Tensor(1).lazydata
a = Tensor(1).uop
self.assertEqual(a.st, ShapeTracker.from_shape(()))
a = Tensor.ones((4, 4)).lazydata
a = Tensor.ones((4, 4)).uop
self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4)))

def test_padded_const(self):
@@ -569,12 +569,12 @@ class TestShapeSpec(unittest.TestCase):

# NOTE: CONST ShapeTracker comes from its source
def test_scalar_const(self):
a = Tensor(0).lazydata
a = Tensor(0).uop
self.assertEqual(a.st, ShapeTracker.from_shape(()))

def test_scalar_var(self):
vv = UOp.variable("a", 1, 4).bind(2)
t = Tensor(vv).lazydata
t = Tensor(vv).uop
self.assertEqual(t.st, ShapeTracker.from_shape(()))

# ** ASSIGN is ASSIGN(VIEW(BUFFER), new_val)
@@ -583,7 +583,7 @@ class TestShapeSpec(unittest.TestCase):
buffer = Tensor.arange(4).realize()
a = buffer.assign(Tensor.zeros((4,), dtype=dtypes.int))
assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat()))
assert assign_pattern.match(a.lazydata, {})
assert assign_pattern.match(a.uop, {})
a.realize()
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])

@@ -597,7 +597,7 @@ class TestShapeSpec(unittest.TestCase):
buffer = Tensor.ones((4,)).contiguous().realize()
a = buffer.reshape((2, 2)).assign(Tensor.zeros((2, 2)))
assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))), UPat()))
assert assign_pattern.match(a.lazydata, {})
assert assign_pattern.match(a.uop, {})
a.realize()
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])

@@ -606,13 +606,13 @@ class TestShapeSpec(unittest.TestCase):
a = Tensor.ones((4,)).contiguous().realize()
assign = a.shrink(((1, 2),)).assign(Tensor.zeros((1,)))
# the ASSIGN UOp has size=1
self.assertEqual(assign.lazydata.size, 1)
self.assertEqual(assign.uop.size, 1)
# the ASSIGN views the buffer with a shrunk st
self.assertEqual(assign.lazydata.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
self.assertEqual(assign.uop.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
# the underlying BUFFER has a size=4
self.assertEqual(assign.lazydata.buf_uop.size, 4)
self.assertEqual(assign.uop.buf_uop.size, 4)
# NOTE: output shape is different from the BUFFER shape
self.assertNotEqual(assign.lazydata.shape, a.lazydata.shape)
self.assertNotEqual(assign.uop.shape, a.uop.shape)
assign.realize()
self.assertEqual(a.tolist(), [1, 0, 1, 1])

@@ -622,13 +622,13 @@ class TestShapeSpec(unittest.TestCase):

def test_ops_st(self):
# view / mop
a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).lazydata
a = Tensor.empty(4, 2, 1).permute((1, 2, 0)).uop
self.assertEqual(a.st, ShapeTracker.from_shape((4, 2, 1)).permute((1, 2, 0)))
# alu / reduce
alu = a*2
self.assertEqual(alu.st, ShapeTracker.from_shape((2, 1, 4)))
r = Tensor.empty(4, 4).sum(axis=1)
self.assertEqual(r.lazydata.st, ShapeTracker.from_shape((4,)))
self.assertEqual(r.uop.st, ShapeTracker.from_shape((4,)))

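# [illustrative sketch, not part of this commit] A movement op on a Tensor is mirrored
# by the same op on its ShapeTracker:
from tinygrad import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
u = Tensor.empty(4, 2).permute((1, 0)).uop
assert u.st == ShapeTracker.from_shape((4, 2)).permute((1, 0))
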
def test_st_wmma_none(self):
A = UOp(Ops.DEFINE_VAR, dtypes.float.vec(16), arg=('a', UOp.const(dtypes.float, 0), UOp.const(dtypes.float, 1)))

@@ -6,7 +6,7 @@ def time_tensor_numpy(out:Tensor):
times = []
for _ in range(5):
st = time.perf_counter()
out.lazydata.base.realized.as_buffer(allow_zero_copy=True)
out.uop.base.realized.as_buffer(allow_zero_copy=True)
et = time.perf_counter() - st
times.append(et)
return min(times)

@@ -113,16 +113,16 @@ class TestTensorGradient(unittest.TestCase):
class TestRealizeMeansRealize(unittest.TestCase):
def test_randn_realizes(self):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
assert x.lazydata is not x.lazydata.base
assert x.lazydata.is_realized
assert x.uop is not x.uop.base
assert x.uop.is_realized

#@unittest.expectedFailure
# update: passing after delete_forced_realize
def test_uniform_realizes(self):
x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
print(x.lazydata)
assert x.lazydata is not x.lazydata.base
assert x.lazydata.is_realized
print(x.uop)
assert x.uop is not x.uop.base
assert x.uop.is_realized

# NOTE: even though it doesn't realize, this seems fine
def test_uniform_gradient(self):

@@ -836,29 +836,29 @@ class TestConsecutive(unittest.TestCase):
self.ones = Tensor.ones(2, 4)

def test_unmodified(self):
assert self.t.lazydata.st.consecutive
assert self.t.reshape(4, 2).lazydata.st.consecutive
assert self.t.reshape(1, 8).lazydata.st.consecutive
assert self.t.uop.st.consecutive
assert self.t.reshape(4, 2).uop.st.consecutive
assert self.t.reshape(1, 8).uop.st.consecutive

def test_sliced(self):
assert self.t[0].lazydata.st.consecutive
assert self.t[0, 1:2].lazydata.st.consecutive
assert self.t[1].lazydata.st.consecutive
assert not self.t[:, 0].lazydata.st.consecutive
assert not self.t[:, 1].lazydata.st.consecutive
assert self.t[0].uop.st.consecutive
assert self.t[0, 1:2].uop.st.consecutive
assert self.t[1].uop.st.consecutive
assert not self.t[:, 0].uop.st.consecutive
assert not self.t[:, 1].uop.st.consecutive

def test_padded(self):
assert not self.t.pad(((1, 1), None)).lazydata.st.consecutive
assert not self.t.pad((None, (1, 1))).lazydata.st.consecutive
assert not self.t.pad(((1, 1), None)).uop.st.consecutive
assert not self.t.pad((None, (1, 1))).uop.st.consecutive

def test_const(self):
assert self.const.lazydata.st.consecutive
assert self.const.uop.st.consecutive

def test_ones(self):
assert not self.ones.lazydata.st.consecutive
assert not self.ones[0, :].lazydata.st.consecutive
assert not self.ones.uop.st.consecutive
assert not self.ones[0, :].uop.st.consecutive
# consecutive if sliced into size 1
assert self.ones[0, 0].lazydata.st.consecutive
assert self.ones[0, 0].uop.st.consecutive

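# [illustrative sketch, not part of this commit] What "consecutive" means, assuming a
# realized arange like the fixture above: the view must read one unbroken, in-order
# run of its base buffer.
from tinygrad import Tensor
t = Tensor.arange(8).reshape(2, 4)
assert t[1].uop.st.consecutive         # elements 4..7, contiguous in the base
assert not t[:, 0].uop.st.consecutive  # elements 0 and 4, a strided gather
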
class TestRender(unittest.TestCase):
def test_render(self):

@@ -8,21 +8,21 @@ realized_pattern = UPat(Ops.BUFFER)
buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat)
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)

class TestTensorMutates(unittest.TestCase):
def test_mutate_add(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
ret = a+b
pa = a.lazydata
pb = b.lazydata
pr = ret.lazydata
pa = a.uop
pb = b.uop
pr = ret.uop
ret.schedule()
self.assertIsNot(pa, a.lazydata)
self.assertIsNot(pb, b.lazydata)
self.assertIsNot(pr, ret.lazydata)
for t in [a,b,ret]: is_pattern_uop(t.lazydata.base, realized_pattern)
self.assertIsNot(pa, a.uop)
self.assertIsNot(pb, b.uop)
self.assertIsNot(pr, ret.uop)
for t in [a,b,ret]: is_pattern_uop(t.uop.base, realized_pattern)

def test_reshape_is_same_parent(self):
a = Tensor([1,2,3])
@@ -30,11 +30,11 @@ class TestTensorMutates(unittest.TestCase):
c = a+b
d = (a+b).reshape(3,1)
d.realize()
is_pattern_uop(d.lazydata.base, realized_pattern)
is_pattern_uop(c.lazydata.base, realized_pattern)
is_pattern_uop(d.uop.base, realized_pattern)
is_pattern_uop(c.uop.base, realized_pattern)
# NOTE: we keep movement ops on top of the buffer view
is_pattern_uop(c.lazydata, UPat(Ops.BUFFER))
is_pattern_uop(d.lazydata, UPat(Ops.VIEW, src=(realized_pattern,)))
is_pattern_uop(c.uop, UPat(Ops.BUFFER))
is_pattern_uop(d.uop, UPat(Ops.VIEW, src=(realized_pattern,)))

def test_reshape_is_same_child(self):
a = Tensor([1,2,3])
@@ -42,41 +42,41 @@ class TestTensorMutates(unittest.TestCase):
c = a+b
d = (a+b).reshape(3,1)
c.realize()
is_pattern_uop(c.lazydata.base, realized_pattern)
is_pattern_uop(d.lazydata.base, realized_pattern)
is_pattern_uop(c.uop.base, realized_pattern)
is_pattern_uop(d.uop.base, realized_pattern)

class TestTensorUopRepresentation(unittest.TestCase):
def test_realized(self):
a = Tensor([1.,2,3]).realize()
print(a.lazydata)
is_pattern_uop(a.lazydata.base, realized_pattern)
print(a.uop)
is_pattern_uop(a.uop.base, realized_pattern)

def test_add_realized(self):
a = Tensor([1.,2,3]).realize()
b = Tensor([4.,5,6]).realize()
c = a+b
print(c.lazydata)
print(c.uop)
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))

def test_const_pattern(self):
a = Tensor(1)
print(a.lazydata)
print(a.uop)
is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src
is_pattern(a, UPat.cvar("x")) # even cvar works!

def test_consts_do_not_realize(self):
a = Tensor(1)
print(a.lazydata)
pre_realize = a.lazydata
print(a.uop)
pre_realize = a.uop
a.realize()
assert a.lazydata is pre_realize
assert a.uop is pre_realize

def test_viewed_consts_do_not_realize(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
print(a.uop)
a.realize()
is_pattern(a, const_pattern)
self.assertEqual(a.lazydata.shape, (10, 10))
self.assertEqual(a.uop.shape, (10, 10))

# currently, CONSTs have a "fake" BUFFER. this should be fixed
# current:
@@ -93,8 +93,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
# UOp(Ops.DEVICE, dtypes.void, arg="METAL", src=()),)),)),))
def test_consts_dont_have_buffers(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
buffers_in_parents = [x.op for x in a.lazydata.toposort() if x.op is Ops.BUFFER]
print(a.uop)
buffers_in_parents = [x.op for x in a.uop.toposort() if x.op is Ops.BUFFER]
self.assertEqual(len(buffers_in_parents), 0)

# currently, COPY has an extra BUFFER on the output
@@ -112,7 +112,7 @@ class TestTensorUopRepresentation(unittest.TestCase):
def test_copyin(self):
a = Tensor([1.,2,3]).realize()
c = a.to("TEST") # NOTE: this isn't checked
print(c.lazydata)
print(c.uop)
is_pattern(c, UPat(Ops.COPY, src=(realized_pattern, UPat(Ops.DEVICE))))

def test_empty_buf(self):
@@ -121,7 +121,7 @@ class TestTensorUopRepresentation(unittest.TestCase):
vi = UOp.variable("i", 1, 3).bind(1)
a = Tensor.empty(3, vi)
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
self.assertEqual(a.lazydata.base.buffer.size, 9)
self.assertEqual(a.uop.base.buffer.size, 9)

if __name__ == '__main__':
unittest.main()

@@ -176,7 +176,7 @@ class CapturedJit(Generic[ReturnType]):
self.__post_init__() # reset the graph state

def replan_buffers_memory_layout(self):
blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
for old, new in asgn.items():
@@ -210,9 +210,9 @@ class CapturedJit(Generic[ReturnType]):
def _prepare_jit_inputs(args, kwargs):
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
# TODO: this multi unpack stuff is not well tested.
lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
for lb in lbs if lb.base.realized is not None])
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"

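# [illustrative usage sketch, not part of this commit] How the input handling above is
# reached in practice, assuming the public TinyJit API:
from tinygrad import Tensor, TinyJit
@TinyJit
def double(x: Tensor) -> Tensor: return (x * 2).realize()
for _ in range(3):
  out = double(Tensor.rand(4).realize())  # unrealized inputs would be realized here first
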
@@ -155,7 +155,7 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
if isinstance(v.device, tuple):
if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
else: v.replace(state_dict[k].shard(v.device, v.uop.axis))
else: v.replace(state_dict[k].to(v.device))
if realize: v.realize()
if consume: del state_dict[k]

@@ -20,7 +20,7 @@ from tinygrad.engine.grouper import get_kernelize_map

all_tensors: set[weakref.ref[Tensor]] = set()
def _find_all_tensors_for_uops(all_uops: set[UOp]) -> list[Tensor]:
return [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]
return [t for tref in all_tensors if (t:=tref()) is not None and t.uop in all_uops]

def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
# get all children of keys in applied_map
@@ -36,13 +36,13 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> Non
# NOTE: this uses all_tensors, but it's fast
if len(fixed_tensors := _find_all_tensors_for_uops(all_uops)):
# potentially rewrite all the discovered Tensors
sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
sink = UOp.sink(*[t.uop for t in fixed_tensors])
new_sink = sink.substitute(applied_map, name=name)

# set the relevant lazydata to the realized UOps
# set the relevant uop to the realized UOps
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
if s is ns: continue
t.lazydata = ns
t.uop = ns

# **** Tensor helper functions ****

@@ -119,7 +119,7 @@ class Tensor(MathTrait):
np.set_printoptions(precision=4)
```
"""
__slots__ = "lazydata", "requires_grad", "grad"
__slots__ = "uop", "requires_grad", "grad"
training: ClassVar[bool] = False

def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
@@ -150,7 +150,7 @@ class Tensor(MathTrait):
if dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).lazydata
if dtype in [dtypes.bfloat16, *dtypes.fp8s]: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).uop
else: data = _frompy(data, dtype)
elif str(type(data)) == "<class 'numpy.ndarray'>":
import numpy as np
@@ -165,19 +165,19 @@ class Tensor(MathTrait):
if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

# data might be on a different device
if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
if isinstance(device, str): self.uop:UOp = data if data.device == device else data.copy_to_device(device)
# if device is a tuple, we should have/construct a MultiLazyBuffer
elif isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
elif isinstance(data.device, str): self.uop = Tensor(data).shard(device).uop
else:
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
self.lazydata = data
self.uop = data

# add to all_tensors after construction succeeds
all_tensors.add(weakref.ref(self))
def __del__(self): all_tensors.discard(weakref.ref(self))

def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
new_uop: UOp = fxn(*[t.uop for t in (self,)+x], **kwargs)
if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
needs_input_grad = [t.requires_grad for t in (self,)+x]
return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
@@ -196,9 +196,9 @@ class Tensor(MathTrait):
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

def __repr__(self):
ld = self.lazydata
ld = self.uop
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.uop if self.grad is not None else None)!r}>"

# Python has a non-moving GC, so this should be okay
def __hash__(self): return id(self)
@@ -210,13 +210,13 @@ class Tensor(MathTrait):
return self.shape[0]

@property
def device(self) -> str|tuple[str, ...]: return self.lazydata.device
def device(self) -> str|tuple[str, ...]: return self.uop.device

@property
def shape(self) -> tuple[sint, ...]: return self.lazydata.shape
def shape(self) -> tuple[sint, ...]: return self.uop.shape

@property
def dtype(self) -> DType: return self.lazydata.dtype
def dtype(self) -> DType: return self.uop.dtype

# ***** data handlers ****

@@ -226,7 +226,7 @@ class Tensor(MathTrait):

NOTE: Kernelize can be called multiple times on a Tensor
"""
big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])

# verify Tensors match the spec
if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
@@ -243,7 +243,7 @@ class Tensor(MathTrait):
"""
st = time.perf_counter()
self.kernelize(*lst)
sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
sink = UOp.sink(*[x.uop for x in (self,)+lst])

# remove all ASSIGNs, after scheduling, the tensors are just buffers
remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
@@ -272,31 +272,31 @@ class Tensor(MathTrait):
"""
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
assert self.shape == x.shape or allow_shape_mismatch, f"replace shape mismatch {self.shape} != {x.shape}"
self.lazydata = x.lazydata
self.uop = x.uop
return self

def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
cast(Buffer, self.contiguous().realize().lazydata.base.buffer).ensure_allocated().copyin(x._data())
cast(Buffer, self.contiguous().realize().uop.base.buffer).ensure_allocated().copyin(x._data())
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
if self.lazydata is x.lazydata: return self # a self assign is a NOOP
if self.uop is x.uop: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
self.lazydata = self.lazydata.assign(x.lazydata)
self.uop = self.uop.assign(x.uop)
return self

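# [illustrative usage sketch, not part of this commit] assign writes new values into
# the Tensor's existing Buffer instead of binding a new one:
from tinygrad import Tensor
a = Tensor.full((4,), 2.).contiguous().realize()
a.assign(Tensor.full((4,), 0.)).realize()
assert a.tolist() == [0.0, 0.0, 0.0, 0.0]
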
def detach(self) -> Tensor:
"""
Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
"""
return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)
return Tensor(self.uop.detach(), device=self.device, requires_grad=False)

def _buffer(self) -> Buffer: return cast(Buffer, self.cast(self.dtype.base).contiguous().to("CPU").realize().lazydata.base.buffer)
def _buffer(self) -> Buffer: return cast(Buffer, self.cast(self.dtype.base).contiguous().to("CPU").realize().uop.base.buffer)
def _data(self) -> memoryview: return self._buffer().as_buffer()

def data(self) -> memoryview:
@@ -372,7 +372,7 @@ class Tensor(MathTrait):
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
if device == self.device: return self
if not isinstance(device, str): return self.shard(device)
ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
if self.grad is not None: ret.grad = self.grad.to(device)
return ret

@@ -390,12 +390,12 @@ class Tensor(MathTrait):

```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 4)
print(t.shard((t.device, t.device), axis=1).lazydata)
print(t.shard((t.device, t.device), axis=1).uop)
```
"""
assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
devices = tuple(Device.canonicalize(x) for x in devices)
mlb = self.lazydata.shard(devices, self._resolve_dim(axis)) if axis is not None else self.lazydata.copy_to_device(devices)
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
@@ -444,7 +444,7 @@ class Tensor(MathTrait):
"""
r = Tensor.empty(*shape, **kwargs)
assert isinstance(r.device, str)
cast(Buffer, r.lazydata.buffer).allocate(external_ptr=ptr)
cast(Buffer, r.uop.buffer).allocate(external_ptr=ptr)
return r

@staticmethod
@@ -719,7 +719,7 @@ class Tensor(MathTrait):
dtype = kwargs.pop("dtype", self.dtype)
if isinstance(self.device, tuple):
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.lazydata.axis)
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device, self.uop.axis)
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)

# ***** rng hlops *****
@@ -923,13 +923,13 @@ class Tensor(MathTrait):
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
if not (self.is_floating_point() and all(t.is_floating_point() for t in targets)): raise RuntimeError("only float Tensors have gradient")
if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
target_uops = [x.lazydata for x in targets]
grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
target_uops = [x.uop for x in targets]
grads = compute_gradient(self.uop, gradient.uop, set(target_uops))
ret = []
for x in target_uops:
if (y:=grads.get(x)) is None:
if materialize_grads: y = x.const_like(0)
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.uop}")
ret.append(y)
# create returned Tensors
return [Tensor(u, device=t.device) for t,u in zip(targets, ret)]
@@ -944,9 +944,9 @@ class Tensor(MathTrait):
print(t.grad.numpy())
```
"""
all_uops = self.lazydata.toposort()
all_uops = self.uop.toposort()
tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
t.lazydata in all_uops and t.requires_grad]
t.uop in all_uops and t.requires_grad]
# clear contexts
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
@@ -1253,14 +1253,14 @@ class Tensor(MathTrait):
self.realize()._getitem(indices).assign(v)
return
# NOTE: check that setitem target is valid first
if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
if not unwrap(self.uop.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")

res = self.realize()._getitem(indices, v)
# if shapes match and data is not shared it's a copy and we assign to self
if res.shape == self.shape and res.lazydata is not self.lazydata:
if res.shape == self.shape and res.uop is not self.uop:
self.assign(res).realize()
else: # no copy, basic setitem
v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()