Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2025-03-13 09:27:15 +00:00
27 changed files with 553 additions and 171 deletions

View File

@@ -447,7 +447,7 @@ jobs:
testmoreamdbenchmark:
name: tinybox red Training Benchmark
runs-on: [self-hosted, Linux, tinybox]
timeout-minutes: 20
timeout-minutes: 30
defaults:
run:
shell: bash -o pipefail {0}

View File

@@ -180,6 +180,8 @@ jobs:
run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20 || true
- name: Test in-place operations on views
run: PYTHONPATH=. TORCH_DEBUG=1 python3 extra/torch_backend/test_inplace.py
- name: Test multi-gpu
run: PYTHONPATH=. LLVM=1 GPUS=4 TORCH_DEBUG=1 python3 extra/torch_backend/test_multigpu.py
torchbackendmore:
name: Torch Backend Tests More
@@ -328,8 +330,8 @@ jobs:
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
- name: Run unit tests
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
- name: Repo line count < 11500 lines
run: MAX_LINE_COUNT=11500 python sz.py
- name: Repo line count < 12000 lines
run: MAX_LINE_COUNT=12000 python sz.py
fuzzing:
name: Fuzzing

4
.gitignore vendored
View File

@@ -34,6 +34,8 @@ extra/datasets/open-images-v6-mlperf
extra/datasets/kits/
extra/datasets/COCO/
extra/datasets/audio*
extra/huggingface_onnx/models/*
extra/huggingface_onnx/*.yaml
extra/weights
venv
examples/**/net.*[js,json]
@@ -56,5 +58,5 @@ weights
comgr_*
*.pkl
site/
master_schedule.py
profile_stats
*.log

View File

@@ -925,7 +925,7 @@ def train_bert():
# ** hyperparameters **
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00011 * math.sqrt(BS/66))
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96))
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS)
warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
@@ -936,7 +936,7 @@ def train_bert():
save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0)
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**11 if dtypes.default_float == dtypes.float16 else 1.0)
decay = config["DECAY"] = getenv("DECAY", 0.01)
epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)

View File

@@ -2,9 +2,9 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -2,9 +2,9 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -3,9 +3,9 @@
export PYTHONPATH="."
export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -2,9 +2,9 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -2,9 +2,9 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -3,9 +3,9 @@
export PYTHONPATH="."
export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_red"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -0,0 +1,85 @@
import yaml, time, requests, argparse
from pathlib import Path
from huggingface_hub import list_models, HfApi
from tinygrad.helpers import tqdm
HUGGINGFACE_URL = "https://huggingface.co"
SKIPPED_FILES = [
"fp16", "int8", "uint8", "quantized", # numerical accuracy issues
"avx2", "arm64", "avx512", "avx512_vnni", # numerical accuracy issues
"q4", "q4f16", "bnb4", # unimplemented quantization
"model_O4", # requires non cpu ort runner and MemcpyFromHost op
"merged", # TODO implement attribute with graph type and Loop op
]
SKIPPED_REPO_PATHS = [
# Invalid model-index
"AdamCodd/vit-base-nsfw-detector",
# TODO: implement attribute with graph type and Loop op
"minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
"HuggingFaceTB/SmolLM2-360M-Instruct",
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
"HuggingFaceTB/SmolLM2-1.7B-Instruct",
# TODO: implmement RandomNormalLike
"stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", 'SimianLuo/LCM_Dreamshaper_v7',
# TODO: implement NonZero
"mangoapps/fb_zeroshot_mnli_onnx",
# TODO huge Concat in here with 1024 (1, 3, 32, 32) Tensors, and maybe a MOD bug with const folding
"briaai/RMBG-2.0",
]
def get_top_repos(n: int, sort: str) -> list[str]: # list["FacebookAI/xlm-roberta-large", ...]
  """Return the ids of the top `n` onnx-tagged repos on the Hub, ordered by `sort`.

  Repos in SKIPPED_REPO_PATHS are skipped and do not count toward `n`.
  """
  print(f"** Getting top {n} models sorted by {sort} **")
  repos: list[str] = []
  for model in list_models(filter="onnx", sort=sort):
    if model.id in SKIPPED_REPO_PATHS: continue
    print(f"{len(repos)+1}/{n}: {model.id} ({getattr(model, sort)})")
    repos.append(model.id)
    if len(repos) == n: break
  return repos
def get_metadata(repos:list[str]) -> dict:
  """Collect onnx file metadata (name + size) for each repo.

  Returns {"repositories": {repo_id: {"url", "download_path", "files"}},
  "total_size", "created_at"}. "download_path" starts as None and is filled
  in later by download_models.py.
  """
  api = HfApi()
  repos_metadata = {"repositories": {}}
  total_size = 0
  # TODO: speed head requests up with async?
  for repo in tqdm(repos, desc="Getting metadata"):
    files_metadata = []
    model_info = api.model_info(repo)
    for file in model_info.siblings:
      filename = file.rfilename
      if not (filename.endswith('.onnx') or filename.endswith('.onnx_data')): continue
      if any(skip_str in filename for skip_str in SKIPPED_FILES): continue
      file_size = file.size
      if not file_size:
        # the API sometimes omits sizes; fall back to a HEAD request on the file itself
        # (fix: URL must point at the actual file, and only issue the request when needed)
        head = requests.head(f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}", allow_redirects=True)
        file_size = int(head.headers.get('Content-Length', 0))
      files_metadata.append({"file": filename, "size": f"{file_size/1e6:.2f}MB"})
      total_size += file_size
    repos_metadata["repositories"][repo] = {
      "url": f"{HUGGINGFACE_URL}/{repo}",
      "download_path": None,
      "files": files_metadata,
    }
  repos_metadata['total_size'] = f"{total_size/1e9:.2f}GB"
  repos_metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
  return repos_metadata
if __name__ == "__main__":
  # Produces a YAML report of the top-N huggingface onnx repos and their file metadata.
  sort = "downloads" # recent 30 days downloads
  huggingface_onnx_dir = Path(__file__).parent
  parser = argparse.ArgumentParser(description="Produces a YAML file with metadata of top huggingface onnx models")
  parser.add_argument("--limit", type=int, required=True, help="Number of top repositories to process (e.g., 100)")
  parser.add_argument("--output", type=str, default="huggingface_repos.yaml", help="Output YAML file name to save the report")
  args = parser.parse_args()
  # query the Hub, then collect per-file metadata for each selected repo
  top_repos = get_top_repos(args.limit, sort)
  metadata = get_metadata(top_repos)
  # the report is written next to this script (sort_keys=False keeps insertion order)
  yaml_path = huggingface_onnx_dir / args.output
  with open(yaml_path, 'w') as f:
    yaml.dump(metadata, f, sort_keys=False)
  print(f"YAML saved to: {str(yaml_path)}")

View File

@@ -0,0 +1,29 @@
import yaml, argparse
from pathlib import Path
from huggingface_hub import snapshot_download
def download_models(yaml_file: str, download_dir: str) -> None:
  """Download every model listed in the YAML metadata file into `download_dir`.

  Also writes each repo's local snapshot path back into the YAML under
  "download_path" so later tooling can find the files.
  """
  with open(yaml_file, 'r') as f: metadata = yaml.safe_load(f)
  repositories = metadata["repositories"]
  total = len(repositories)
  for idx, (model_id, model_data) in enumerate(repositories.items(), start=1):
    print(f"Downloading {idx}/{total}: {model_id}...")
    wanted_files = [info["file"] for info in model_data["files"]]
    root_path = Path(snapshot_download(repo_id=model_id, allow_patterns=wanted_files, cache_dir=download_dir))
    # download configs too (the sizes are small)
    snapshot_download(repo_id=model_id, allow_patterns=["*config.json"], cache_dir=download_dir)
    print(f"Downloaded model files to: {root_path}")
    model_data["download_path"] = str(root_path)
  # Save the updated metadata back to the YAML file
  with open(yaml_file, 'w') as f: yaml.dump(metadata, f, sort_keys=False)
  print("Download completed according to YAML file.")
if __name__ == "__main__":
  # CLI entry point: downloads all models referenced by a metadata YAML
  # (as produced by the sibling metadata script) into ./models next to this file.
  parser = argparse.ArgumentParser(description="Download models from Huggingface Hub based on a YAML configuration file.")
  parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
  args = parser.parse_args()
  models_folder = Path(__file__).parent / "models"
  models_folder.mkdir(parents=True, exist_ok=True)
  download_models(args.input, str(models_folder))

View File

@@ -0,0 +1,136 @@
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
from pathlib import Path
from extra.onnx import OnnxRunner, get_onnx_ops
from extra.onnx_helpers import validate, get_example_inputs
def get_config(root_path: Path):
  """Merge every *config.json found under `root_path` into one dict.

  Non-dict JSON payloads (e.g. top-level lists) are ignored; when the same
  key appears in several files, the last one visited wins. Returns {} when
  no config files are present.
  """
  merged = {}
  for cfg_path in root_path.rglob("*config.json"):
    cfg = json.loads(cfg_path.read_text())
    if isinstance(cfg, dict): merged.update(cfg)
  return merged
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
  """Validate one onnx model: build config-derived example inputs and compare
  tinygrad's outputs against the reference runner within rtol/atol."""
  runner = OnnxRunner(onnx.load(onnx_model_path))
  example_inputs = get_example_inputs(runner.graph_inputs, config)
  validate(onnx_model_path, example_inputs, rtol=rtol, atol=atol)
def get_tolerances(file_name): # -> rtol, atol
  """Pick validation tolerances from hints in the onnx file name.

  fp16 and quantized variants get much looser bounds than full-precision models.
  """
  # TODO very high rtol atol
  if "fp16" in file_name:
    return 9e-2, 9e-2
  quantized_markers = ("int8", "uint8", "quantized")
  if any(marker in file_name for marker in quantized_markers):
    return 4, 4
  return 4e-3, 3e-2
def validate_repos(models:dict[str, tuple[Path, Path]]):
  """Run numerical validation for every onnx model in `models`.

  `models` maps "repo_id/relative/file.onnx" -> (download root, path relative
  to that root). Raises on the first model that fails validation.
  """
  # fix: was len(model_paths) — a __main__-only global; use the parameter
  print(f"** Validating {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"validating model {model_id}")
    model_path = root_path / relative_path
    onnx_file_name = model_path.stem
    # tolerances depend on precision/quantization hints in the file name
    config = get_config(root_path)
    rtol, atol = get_tolerances(onnx_file_name)
    st = time.time()
    run_huggingface_validate(model_path, config, rtol, atol)
    et = time.time() - st
    print(f"passed, took {et:.2f}s")
def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict:
  """Count onnx op usage across `models` and find ops tinygrad doesn't implement.

  Returns {"unsupported_ops": {op_name: [model_ids...]}, "op_counter": [(op, count), ...]}.
  """
  ret = {}
  op_counter = collections.Counter()
  unsupported_ops = collections.defaultdict(set)
  supported_ops = get_onnx_ops()
  # fix: was len(model_paths) — a __main__-only global; use the parameter
  print(f"** Retrieving stats from {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"examining {model_id}")
    model_path = root_path / relative_path
    onnx_runner = OnnxRunner(onnx.load(model_path))
    for node in onnx_runner.graph_nodes:
      op_counter[node.op] += 1
      if node.op not in supported_ops:
        unsupported_ops[node.op].add(model_id)
    # free the loaded model before moving to the next one
    del onnx_runner
  ret["unsupported_ops"] = {k:list(v) for k, v in unsupported_ops.items()}
  ret["op_counter"] = op_counter.most_common()
  return ret
def debug_run(model_path, truncate, config, rtol, atol):
  # Validate `model_path`, optionally truncating the graph so intermediate
  # results can be checked. truncate == -1 means run the full model;
  # otherwise nodes [0, truncate] are kept and the last kept node's outputs
  # become the graph outputs.
  if truncate != -1:
    model = onnx.load(model_path)
    nodes_up_to_limit = list(model.graph.node)[:truncate + 1]
    # expose the last retained node's outputs as (shapeless) graph outputs
    new_output_values = [onnx.helper.make_empty_tensor_value_info(output_name) for output_name in nodes_up_to_limit[-1].output]
    model.graph.ClearField("node")
    model.graph.node.extend(nodes_up_to_limit)
    model.graph.ClearField("output")
    model.graph.output.extend(new_output_values)
    # round-trip through a temp file so the validator sees an ordinary .onnx on disk
    # NOTE(review): assumes any external-data tensors still resolve from the temp
    # location — confirm for models with .onnx_data files
    with tempfile.NamedTemporaryFile(suffix=model_path.suffix) as tmp:
      onnx.save(model, tmp.name)
      run_huggingface_validate(tmp.name, config, rtol, atol)
  else:
    run_huggingface_validate(model_path, config, rtol, atol)
if __name__ == "__main__":
  # CLI: op-support stats and/or numerical validation for huggingface onnx models,
  # either from a prepared YAML (--check_ops/--validate) or ad-hoc (--debug).
  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
  parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
  parser.add_argument("--check_ops", action="store_true", default=False,
                      help="Check support for ONNX operations in models from the YAML file")
  parser.add_argument("--validate", action="store_true", default=False,
                      help="Validate correctness of models from the YAML file")
  parser.add_argument("--debug", type=str, default="",
                      help="""Validates without explicitly needing a YAML or models pre-installed.
provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
""")
  parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
  args = parser.parse_args()
  if not (args.check_ops or args.validate or args.debug):
    parser.error("Please provide either --validate, --check_ops, or --debug.")
  if args.truncate != -1 and not args.debug:
    parser.error("--truncate and --debug should be used together for debugging")
  if args.check_ops or args.validate:
    with open(args.input, 'r') as f:
      data = yaml.safe_load(f)
    assert all(repo["download_path"] is not None for repo in data["repositories"].values()), "please run `download_models.py` for this yaml"
    # map "repo_id/relative/file.onnx" -> (download root, path relative to that root)
    model_paths = {
      model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
      for model_id, repo in data["repositories"].items()
      for model in repo["files"]
      if model["file"].endswith(".onnx")
    }
    if args.check_ops:
      pprint.pprint(retrieve_op_stats(model_paths))
    if args.validate:
      validate_repos(model_paths)
  if args.debug:
    from huggingface_hub import snapshot_download
    download_dir = Path(__file__).parent / "models"
    path:list[str] = args.debug.split("/")
    if len(path) == 2:
      # repo id
      # validates all onnx models inside repo
      repo_id = "/".join(path)
      # NOTE(review): ".onnx_data" only matches a file literally named ".onnx_data";
      # "*.onnx_data" was probably intended — confirm against external-data models
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", ".onnx_data"], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      for onnx_model in root_path.rglob("*.onnx"):
        rtol, atol = get_tolerances(onnx_model.name)
        print(f"validating {onnx_model.relative_to(root_path)} with truncate={args.truncate}, {rtol=}, {atol=}")
        # whole-repo mode always runs full models (truncate is passed as -1 here)
        debug_run(onnx_model, -1, config, rtol, atol)
    else:
      # model id
      # only validate the specified onnx model
      onnx_model = path[-1]
      assert path[-1].endswith(".onnx")
      repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      rtol, atol = get_tolerances(onnx_model)
      print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
      debug_run(root_path / relative_path, args.truncate, config, rtol, atol)

View File

@@ -5,12 +5,43 @@ import onnx
import numpy as np
import onnxruntime as ort
def get_example_inputs(graph_inputs:dict[str, OnnxValue]):
def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
def _get_shape(onnx_shape: tuple[str|int]):
shape = []
for onnx_dim in onnx_shape:
match onnx_dim:
case int(): shape.append(onnx_dim)
case "width" | "height":
size = config.get("size", {})
shape.append(size) if isinstance(size, int) else shape.append(size.get(onnx_dim, 224))
case "sequence" | "sequence_length" | "decoder_sequence_length": shape.append(64)
case "encoder_sequence_length": shape.append(config.get("nb_max_frames", 64))
case "past_decoder_sequence_length" | "encoder_sequence_length_out": shape.append(64)
case "encoder_sequence_length / 2": shape.append(32)
case "batch_size": shape.append(1)
case "num_channels": shape.append(config.get("in_channels", 3))
case "num_channels_latent": shape.append(config.get("latent_channels", 4))
case "height_latent" | "width_latent": shape.append(config.get("sample_size", 1024) // 8)
case "feature_size": shape.append(config.get("num_mel_bins", 128))
case _: shape.append(1)
return shape
def _get_value(name, shape, dtype):
match name:
case "input_ids":
vocab_size = config.get("text_config", {}).get("vocab_size") or config.get("vocab_size", 32)
val = np.random.randint(0, vocab_size-1, shape)
case "attention_mask": val = np.random.randint(0, 2, size=shape)
case "token_type_ids": val = np.random.randint(0, config.get("type_vocab_size", 2), shape)
case "image_tensor": val = np.random.randint(0, 256, shape)
case "task_id": return Tensor(0, dtype=dtype)
case _: val = np.random.uniform(size=shape) * 8
return Tensor(val.astype(_to_np_dtype(dtype))).realize()
ret: dict[str, Tensor] = {}
for name, spec in graph_inputs.items():
assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now"
shape = tuple(dim if isinstance(dim, int) else 1 for dim in spec.shape)
value = Tensor(np.random.uniform(size=shape).astype(_to_np_dtype(spec.dtype)) * 8).realize()
shape = _get_shape(spec.shape)
value = _get_value(name, shape, spec.dtype)
ret.update({name:value})
return ret

View File

@@ -2,7 +2,7 @@
# A001 Variable `input` is shadowing a Python builtin
# A002 Function argument `input` is shadowing a Python builtin
# A006 Lambda argument `input` is shadowing a Python builtin
from tinygrad import Tensor, dtypes
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv, prod
import torch.lib
TORCH_DEBUG = getenv("TORCH_DEBUG")
@@ -12,9 +12,12 @@ from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
# https://pytorch.org/docs/stable/torch.compiler_ir.html
def _from_torch_device(device: torch.device): return f"{Device.DEFAULT}:{device.index or 0}"
def _to_torch_device(device: str): return torch.device("tiny", int(device.partition(":")[2] or 0))
import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")])
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
def unwrap(x:torch.Tensor) -> Tensor:
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
return mod.unwrap(x)
@@ -24,6 +27,7 @@ class TinyBackend:
def current_device(self): return 0
def _is_in_bad_fork(self): return False
def manual_seed_all(self, seed: int): Tensor.manual_seed(seed)
def device_count(self): return getenv("GPUS", 1) # TODO: device count in tiny?
torch.utils.rename_privateuse1_backend("tiny")
torch._register_device_module("tiny", TinyBackend())
torch.utils.generate_methods_for_privateuse1_backend()
@@ -31,7 +35,7 @@ torch.utils.generate_methods_for_privateuse1_backend()
# in place operations with views
def is_view(self: torch.Tensor) -> bool: return getattr(self, "_base", None) is not None
def realize_with_views(self: torch.Tensor, views: list[torch.Tensor]):
assert self.device.type == "tiny"
assert self.is_tiny
self = unwrap(self)
if not self.lazydata.st.contiguous: raise ValueError("base of view must be contiguous") # TODO: support?
self.replace(self.clone().realize())
@@ -63,18 +67,18 @@ def inplace_fn(outvars: str|list[str]):
@torch.library.impl("aten::masked_select", "privateuseone")
def masked_select(self, mask):
# err, bad
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()], device=_from_torch_device(self.device)))
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
@inplace_fn("self")
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
# TODO: move to tinygrad
ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).tiny()
ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device)
return wrap(unwrap(self).assign(unwrap(ret)))
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y):
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device)
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
@@ -121,13 +125,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype), device=_from_torch_device(device))
return wrap(ret)
@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())).contiguous()
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()), device=_from_torch_device(device)).contiguous()
return wrap(ret)
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
@@ -170,7 +174,7 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
if TORCH_DEBUG >= 1:
print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0])
grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0], device=_from_torch_device(weight.device))
out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out)
return tuple([wrap(grads.pop(0)) if m else None for m in output_mask])
@@ -183,17 +187,19 @@ for i,pre in enumerate(["", "bi", "tri"]):
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src: torch.Tensor, dest, non_blocking=False):
realize = str(dest.device) == "tiny" and maybe_realize_storage(dest)
realize = dest.is_tiny and maybe_realize_storage(dest)
cast_dtype = _from_torch_dtype(dest.dtype)
if str(src.device) == "tiny" and str(dest.device) == "tiny":
unwrap(dest).assign(unwrap(src).cast(cast_dtype))
if src.is_tiny and dest.is_tiny:
to_device = _from_torch_device(dest.device)
unwrap(dest).assign(unwrap(src).cast(cast_dtype).to(to_device))
if realize: Tensor.realize(unwrap(dest))
elif str(src.device) == "tiny" and str(dest.device) == "cpu":
elif src.is_tiny and dest.is_cpu:
# TODO: is there a better way?
dest.resize_(src.numel()).resize_(src.shape)
dest.copy_(torch.from_numpy(unwrap(src).cast(cast_dtype).numpy()))
elif str(src.device) == "cpu" and str(dest.device) == "tiny":
unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype))
elif src.is_cpu and dest.is_tiny:
to_device = _from_torch_device(dest.device)
unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype).to(to_device))
if realize: Tensor.realize(unwrap(dest))
else:
raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
@@ -341,6 +347,7 @@ def wrap_out(f):
assigned = f(*args, **kwargs)
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
if out.lazydata.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
return out.assign(assigned)
@@ -462,7 +469,7 @@ def get_real_tinygrad_buffers():
res = set()
for mod in _torch_modules_with_buffers:
for _,b in mod.named_buffers(recurse=False):
if b is not None and str(b.device) == "tiny":
if b is not None and b.is_tiny:
res.add(unwrap(b))
return res
torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
@@ -476,7 +483,7 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
for state_dict in optimizer.state.values():
for _, value in state_dict.items():
if torch.is_tensor(value): tinygrad_tensors.append(value)
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if x.is_tiny]
real_tinygrad_tensors += get_real_tinygrad_buffers()
if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)

View File

@@ -0,0 +1,29 @@
import unittest
from tinygrad.helpers import getenv
import torch
import tinygrad.frontend.torch
torch.set_default_device("tiny")
import numpy as np
@unittest.skipIf(getenv("GPUS",1)<=1, "only single GPU")
class TestTorchBackendMultiGPU(unittest.TestCase):
  # Exercises the tinygrad torch backend with tensors placed on multiple
  # "tiny:N" devices; skipped unless GPUS > 1.
  def test_transfer(self):
    # moving a tensor between tiny devices must preserve its contents
    a = torch.Tensor([[1,2],[3,4]]).to("tiny:0")
    b = torch.Tensor([[3,2],[1,0]]).to("tiny:1")
    self.assertNotEqual(a.device, b.device)
    np.testing.assert_array_equal(a.cpu(), a.to("tiny:1").cpu())
    np.testing.assert_array_equal(b.cpu(), b.to("tiny:1").cpu())
  def test_basic_ops(self):
    # adding operands moved to a common device works regardless of which
    # device hosts the computation, and both orders agree
    a = torch.Tensor([[1,2],[3,4]]).to("tiny:0")
    b = torch.Tensor([[3,2],[1,0]]).to("tiny:1")
    c1 = a + b.to("tiny:0")
    c2 = b + a.to("tiny:1")
    np.testing.assert_array_equal(c1.cpu(), torch.full((2,2),4).cpu())
    np.testing.assert_array_equal(c1.cpu(), c2.cpu())
  # TODO: torch.distributed functions
if __name__ == "__main__":
  unittest.main()

View File

@@ -109,7 +109,7 @@ int register_hook() {
}
int temp_register_hook = register_hook();
at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) {
at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceIndex device_index) {
// TODO: we have to get the dtype and the shape from the tinygrad Tensor
std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
@@ -127,7 +127,7 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) {
return at::detail::make_tensor<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
at::DispatchKeySet(at::DispatchKey::PrivateUse1),
c10::scalarTypeToTypeMeta(dtype),
at::Device(at::kPrivateUse1),
at::Device(at::kPrivateUse1, device_index),
std::make_shared<c10::SafePyObject>(py_obj.release().ptr(), getPyInterpreter()),
sizes, strides);
}

View File

@@ -3,7 +3,7 @@ import numpy as np
from typing import List, Callable
import torch
import warnings
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, DEVECTORIZE
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
@@ -1490,6 +1490,7 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7)
@unittest.skipIf(not DEVECTORIZE, "broken without DEVECTORIZE. TODO: fix this")
def test_logcumsumexp_numerical(self):
helper_test_op(None, lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7, vals=[[0.0, 100.0]])

View File

@@ -2527,5 +2527,12 @@ class TestUOpBecome(unittest.TestCase):
assert b.lazydata.is_realized
assert b.lazydata.base.buffer._base is None
def test_setitem_offset(self):
a = Tensor.full((16,), 0.).contiguous().realize()
b = Tensor.full((16,), 1.).contiguous().realize()
a_view = a[4:].reshape(3, 4).shrink(((0,2),(0,2))).reshape((4,))
b.shrink(((0,4),)).assign(a_view).realize()
self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -414,7 +414,7 @@ class TestTinygrad(unittest.TestCase):
def test_tensor_dtype_errors(self):
with self.assertRaises(AttributeError): Tensor([3], dtype="typo")
with self.assertRaises(TypeError): Tensor([3], dtype=(dtypes.int,))
with self.assertRaises(AttributeError): Tensor([3], dtype=(dtypes.int,))
def test_tensor_bytes(self):
data = b"abc123"

View File

@@ -1,86 +1,83 @@
from typing import Optional, Any, Callable
import functools, operator
from typing import Optional, Any, Callable, cast
import functools, operator, itertools
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
from tinygrad.ops import graph_rewrite, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym
from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
# ***** float4/image store handling *****
def fold_expanded(ex, buf):
new_srcs = dedup(list(ex.src))
old_new_srcs = new_srcs[:]
is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
# TODO: get the device from the buffer somehow
# NOTE: this can't be Device.DEFAULT because it opens devices
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
# first, extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
for i,s in enumerate(new_srcs):
idx = s.src[0].src[1]
if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
for i in range(vec.dtype.count):
idx = vec.gep(i).simplify()
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
# add gates for gated
if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
offsets_rootsrc[root_src][arg] = i
if mask is not None: root_src = (mask.gep(i).simplify(), root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
# then rewrite everything we can
used: set[tuple[UOp, UOp]] = set()
# the buf.dtype is always a pointer
ptrdtype = cast(PtrDType, buf.dtype)
# then rewrite everything we can into groups
ret = []
idxs: list[int|None] = [None]*vec.dtype.count
global_offset = 0
for rootsrc, offsets in offsets_rootsrc.items():
for o in offsets:
for fold_length in lengths:
if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
load_1 = new_srcs[offsets[o]]
new_src = list(load_1.src)
oidx = new_src[0].src[1]
if oidx.divides(fold_length) is None: continue
if is_image:
# for images, we rewrite the index. it must evenly divide 4 from the above check
new_src[0] = buf.index(
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
rootsrc[0] if isinstance(rootsrc, tuple) else None)
else:
# for non image, we upcast the index pointer
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
# generate the folded new_srcs
if is_load:
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
else: # vectorize the store
new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
used.update((rootsrc,o+i) for i in range(fold_length))
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same
oidx = vec.gep(offsets[grp[0]][0])
lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx))
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
# set the idxs of the output
for i,g in enumerate(grp):
for oo in offsets[g]: idxs[oo] = global_offset+i
# add this lidx to the CAT
ret.append(lidx)
global_offset += len(grp)
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs)))
# dedup expand for LOAD
if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
# remove Nones for STORE
return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
def cat_after_store(cat:UOp, data:UOp):
# TODO: this is written in many places
offset = 0
ret = []
for s in cat.src:
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
offset += s.dtype.count
return UOp.sink(ret[0], *ret[1:])
def fix_unfoldable_image_load(load:UOp, buf:UOp):
if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
id4 = oidx % 4
new_src = list(load.src)
# TODO: copied logic from above
new_src[0] = load.src[0].src[0].index(
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
load.src[0].src[2] if len(load.src[0].src) == 3 else None)
vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
def gep_on_store(gep:UOp, st:UOp):
# NOTE: we need to invert the gep here, but it may be an expanding gep
# fake argsort. TODO: handle duplicates
a = {}
for i,x in enumerate(gep.arg): a[x] = i
new_arg = tuple(x[1] for x in sorted(a.items()))
return UOp(Ops.STORE, src=(gep.src[0], st.gep(new_arg)))
buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
float4_folding = PatternMatcher([
(UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
load_store_folding = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"),
UPat.var("mask"))), expand_index),
# GEP after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
# put CAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put CAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store),
])
# ***** image load valid simplification *****
@@ -140,6 +137,64 @@ def get_late_rewrite_patterns(ops, force_transcendental=False):
if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
return PatternMatcher(pat)
# *** correct load/store ***
def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
if (sz:=ls.src[0].dtype.count) == 1: return None
lengths = []
buf = idx.src[0]
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])
lengths.append(1) # worst case, it's not folded
ptrdtype = cast(PtrDType, buf.dtype)
global_offset = 0
ret = []
while global_offset < sz:
for fold_length in lengths:
if global_offset+fold_length > sz: continue
oidx = idx.src[1] + global_offset
if oidx.simplify().divides(fold_length) is None: continue
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
global_offset += fold_length
break
if len(ret) == 1: return None
return UOp(Ops.CAT, ls.dtype, tuple(ret))
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
idx = ls.src[0].src[0]
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
return ls.replace(src=(idx,)+ls.src[1:])
# this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
idx = ls.src[0]
id4 = idx.src[1] % 4
oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
return None
correct_load_store = PatternMatcher([
# split LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.CAST, src=(UPat(Ops.INDEX, name="idx"),)),), name="ls", allow_any_len=True), split_load_store),
# image indexing, including unfoldable images
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])
# *** uop expander ***
@@ -160,13 +215,6 @@ def no_vectorized_alu(alu):
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.VECTORIZE, alu.dtype, alus)
def no_vectorized_load_store(ls:UOp):
idx = ls.src[0]
assert isinstance(idx.dtype, PtrDType)
if idx.dtype.v == 1: return None
tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
def no_vectorized_acc(acc:UOp):
if acc.dtype.count == 1: return None
alus = tuple(UOp(acc.op, acc.dtype.scalar(),
@@ -175,16 +223,9 @@ def no_vectorized_acc(acc:UOp):
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
])
devectorize_load_store = PatternMatcher([
# TODO: add vectorized support to transcendental
(UPat((Ops.INDEX), name="alu"), no_vectorized_alu),
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
])
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
@@ -193,8 +234,6 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|N
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
load_store_indexing = PatternMatcher([
# late fixup of unfoldable image loads
(UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
# image load valid idx simplification
@@ -231,12 +270,10 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
if DEVECTORIZE:
# devectorize + load_store_indexing
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
else:
# new devectorize only for load/store
sink = graph_rewrite(sink, sym+devectorize_load_store)
# devectorize is optional
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
# optional pre matcher
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)

View File

@@ -230,6 +230,17 @@ symbolic = symbolic_simple+PatternMatcher([
# ** mod **
# mod folding
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) else None),
])
symbolic_flat = symbolic+PatternMatcher([
@@ -399,17 +410,6 @@ sym = symbolic_flat+PatternMatcher([
# VECTORIZE void is SINK
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) else None),
# push some GEPs through WMMAs
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)

View File

@@ -156,7 +156,7 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
DTypeLike = Union[str, DType]
def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower())
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
@@ -180,7 +180,7 @@ def sum_acc_dtype(dt:DType):
# default acc dtype for sum
if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
return least_upper_dtype(dt, dtypes.float)
return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
def truncate_fp16(x):
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]

View File

@@ -106,6 +106,9 @@ sym = symbolic_simple+PatternMatcher([
# put CAST after expanding BUFFER
(UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
# remove CONST/BIND/VIEW from SINK
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
])
# **** UOp realization
@@ -259,11 +262,12 @@ create_kernels = merge_views+PatternMatcher([
lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
# walk back the local graph until we reach a buffer/assign parent
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# remove CONST/BIND from SINK
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
# remove downstream reshapes from SINK
(UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
])
DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK}
# **** fix kernel AST
# ** create buffer ops + enumerate buffers
@@ -273,8 +277,18 @@ add_buffer_ops = PatternMatcher([
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
# STORE (except for COPY/BUFFER_VIEW)
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
# partial assign can store to a non-contiguous ShapeTracker
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
# otherwise the store is contiguous
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
# if the last child is a VIEW we merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
])
# ** push views to buffer ops
@@ -290,23 +304,24 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp):
strides = strides_for_shape(rshape)
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
# update input_st and axis
# create a new reduceop for the swizzled input
new_input_st = tmp + ShapeTracker(tuple(nv))
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
return UOp(Ops.REDUCE_AXIS, r.dtype, (apply_swizzle(src.view(src.arg+new_input_st if src.op is Ops.VIEW else new_input_st)),),
(r.arg[0], new_axis)).view(ShapeTracker.from_shape(st.shape))
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))
def elementwise_view_right(root:UOp) -> UOp|None:
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW]): return None
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
# place view after applying the elementwise op
new_shape = swizzles[0].base.shape
ret = root.replace(src=tuple(x.base if x.base.shape == new_shape else apply_swizzle(x.view(ShapeTracker.from_shape(new_shape))) for x in root.src))
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(x.arg+new_st) if x.op is Ops.VIEW else x.view(new_st)) for x in root.src]
# reshape to match downstream shapes
return ret.reshape(root.shape)
return root.replace(src=tuple(new_src)).reshape(root.shape)
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -315,17 +330,12 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
# push VIEW to children
view_right = merge_views+PatternMatcher([
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# STORE is the last child, so we just merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
# push a non contiguous ShapeTracker through reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), reduceop_view_right),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-GroupOp.Buffer, name="root"), elementwise_view_right),
(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
# double reduce op collapses to a single reduce op
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
@@ -359,9 +369,6 @@ def check_load_st(glbl:UOp, view:UOp):
fix_kernel_ops = PatternMatcher([
# BIND in shapetracker becomes DEFINE_VAR
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
# remove CONTIGUOUS/DEVICE
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# remove unmasked valid
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
# no ImageDType after load
@@ -378,15 +385,15 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
if s.op is Ops.ASSIGN:
for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
ast = k.arg.ast.substitute(parents_rep)
# add buffer ops
ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
# fix_kernel_ops
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
# add buffer ops
ast = graph_rewrite(ast, view_left+add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# create subbuffer (TODO: this does not belong here)
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
# fix_kernel_ops
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
return k.replace(arg=Kernel(ast, k.arg.metadata))
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}

View File

@@ -191,8 +191,8 @@ class ClangRenderer(CStyleLanguage):
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
# LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
CStyleLanguage.extra_matcher
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
(UPat(Ops.SQRT, name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher
if sys.platform == 'win32':
kernel_prefix = "__attribute__((ms_abi)) "

View File

@@ -163,6 +163,9 @@ class AM_GFX(AM_IP):
self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000)
self.adev.regRLC_SRM_CNTL.update(srm_enable=1, auto_incr_addr=1)
self.adev.regS2A_DOORBELL_ENTRY_0_CTRL.write(s2a_doorbell_port0_enable=1, s2a_doorbell_port0_awid=0x3, s2a_doorbell_port0_awaddr_31_28_value=0x3)
self.adev.regS2A_DOORBELL_ENTRY_3_CTRL.write(s2a_doorbell_port3_enable=1, s2a_doorbell_port3_awid=0x6, s2a_doorbell_port3_awaddr_31_28_value=0x3)
self.adev.regGRBM_CNTL.update(read_timeout=0xff)
for i in range(0, 16):
self._grbm_select(vmid=i)
@@ -297,6 +300,9 @@ class AM_IH(AM_IP):
for _, rwptr_vm, suf, ring_id in self.rings:
self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
self.adev.regS2A_DOORBELL_ENTRY_1_CTRL.update(s2a_doorbell_port1_enable=1, s2a_doorbell_port1_awid=0x0, s2a_doorbell_port1_awaddr_31_28_value=0x0,
s2a_doorbell_port1_range_offset=am.AMDGPU_NAVI10_DOORBELL_IH*2, s2a_doorbell_port1_range_size=2)
class AM_SDMA(AM_IP):
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
# Setup the ring
@@ -320,6 +326,8 @@ class AM_SDMA(AM_IP):
self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
self.adev.regS2A_DOORBELL_ENTRY_2_CTRL.update(s2a_doorbell_port2_enable=1, s2a_doorbell_port2_awid=0xe, s2a_doorbell_port2_awaddr_31_28_value=0x3,
s2a_doorbell_port2_range_offset=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2, s2a_doorbell_port2_range_size=4)
def fini(self):
self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)

View File

@@ -246,11 +246,12 @@ class HCQSignal(Generic[DeviceType]):
Args:
value: The value to wait for.
timeout: Maximum time to wait in milliseconds. Defaults to 10s.
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
"""
start_time = int(time.perf_counter() * 1000)
while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
while (prev_value:=self.value) < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
self._sleep(time_spent)
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
@contextlib.contextmanager