diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 4865ffbd2d..053b66f4d8 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -447,7 +447,7 @@ jobs:
   testmoreamdbenchmark:
     name: tinybox red Training Benchmark
    runs-on: [self-hosted, Linux, tinybox]
-    timeout-minutes: 20
+    timeout-minutes: 30
     defaults:
       run:
         shell: bash -o pipefail {0}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 337cbf6422..a0b4038507 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -180,6 +180,8 @@ jobs:
       run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20 || true
     - name: Test in-place operations on views
       run: PYTHONPATH=. TORCH_DEBUG=1 python3 extra/torch_backend/test_inplace.py
+    - name: Test multi-gpu
+      run: PYTHONPATH=. LLVM=1 GPUS=4 TORCH_DEBUG=1 python3 extra/torch_backend/test_multigpu.py

   torchbackendmore:
     name: Torch Backend Tests More
@@ -328,8 +330,8 @@ jobs:
       run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
     - name: Run unit tests
       run: PYTHONPATH="." python -m pytest -n=auto test/unit/
-    - name: Repo line count < 11500 lines
-      run: MAX_LINE_COUNT=11500 python sz.py
+    - name: Repo line count < 12000 lines
+      run: MAX_LINE_COUNT=12000 python sz.py

   fuzzing:
     name: Fuzzing
diff --git a/.gitignore b/.gitignore
index e94ef15d7e..b3e0505f49 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,6 +34,8 @@ extra/datasets/open-images-v6-mlperf
 extra/datasets/kits/
 extra/datasets/COCO/
 extra/datasets/audio*
+extra/huggingface_onnx/models/*
+extra/huggingface_onnx/*.yaml
 extra/weights
 venv
 examples/**/net.*[js,json]
@@ -56,5 +58,5 @@ weights
 comgr_*
 *.pkl
 site/
-master_schedule.py
 profile_stats
+*.log
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index ad4d5887f6..7d16b46fe0 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -925,7 +925,7 @@ def train_bert():
   # ** hyperparameters **
   BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
-  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00011 * math.sqrt(BS/66))
+  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96))

   train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS)
   warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
@@ -936,7 +936,7 @@ def train_bert():
   save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
   init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)

-  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0)
+  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**11 if dtypes.default_float == dtypes.float16 else 1.0)
   decay = config["DECAY"] = getenv("DECAY", 0.01)
   epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
   poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh
index 6f17109784..f2f1cb8e45 100755
--- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh
+++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh
@@ -2,9 +2,9 @@
 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96

-export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
+export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh
index 05c5a75619..269f478428 100755
--- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh
+++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh
@@ -2,9 +2,9 @@
 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96

-export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
+export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
index 6a77928d89..d9aa9eddfe 100755
--- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
+++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
@@ -3,9 +3,9 @@ export PYTHONPATH="."
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_green"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96

-export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
+export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh
index f72acd8942..bd32390b17 100755
--- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh
+++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh
@@ -2,9 +2,9 @@
 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96

-export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
+export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh
index 503c91aa93..6b2c4e6925 100755
--- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh
+++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh
@@ -2,9 +2,9 @@
 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96

-export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
+export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
index ec2554b25f..caa380fc19 100755
--- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
+++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
@@ -3,9 +3,9 @@ export PYTHONPATH="."
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_red"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96

-export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
+export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
diff --git a/extra/huggingface_onnx/collect_metadata.py b/extra/huggingface_onnx/collect_metadata.py
new file mode 100644
index 0000000000..0a3d1eda21
--- /dev/null
+++ b/extra/huggingface_onnx/collect_metadata.py
@@ -0,0 +1,85 @@
+import yaml, time, requests, argparse
+from pathlib import Path
+from huggingface_hub import list_models, HfApi
+from tinygrad.helpers import tqdm
+
+HUGGINGFACE_URL = "https://huggingface.co"
+SKIPPED_FILES = [
+  "fp16", "int8", "uint8", "quantized",  # numerical accuracy issues
+  "avx2", "arm64", "avx512", "avx512_vnni",  # numerical accuracy issues
+  "q4", "q4f16", "bnb4",  # unimplemented quantization
+  "model_O4",  # requires non cpu ort runner and MemcpyFromHost op
+  "merged",  # TODO implement attribute with graph type and Loop op
+]
+SKIPPED_REPO_PATHS = [
+  # Invalid model-index
+  "AdamCodd/vit-base-nsfw-detector",
+  # TODO: implement attribute with graph type and Loop op
+  "minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
+  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
+  "HuggingFaceTB/SmolLM2-360M-Instruct",
+  # TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
+  "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+  # TODO: implement RandomNormalLike
+  "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", 'SimianLuo/LCM_Dreamshaper_v7',
+  # TODO: implement NonZero
+  "mangoapps/fb_zeroshot_mnli_onnx",
+  # TODO huge Concat in here with 1024 (1, 3, 32, 32) Tensors, and maybe a MOD bug with const folding
+  "briaai/RMBG-2.0",
+]
+
+def get_top_repos(n: int, sort: str) -> list[str]:  # list["FacebookAI/xlm-roberta-large", ...]
+  print(f"** Getting top {n} models sorted by {sort} **")
+  repos = []
+  i = 0
+  for model in list_models(filter="onnx", sort=sort):
+    if model.id in SKIPPED_REPO_PATHS: continue
+    print(f"{i+1}/{n}: {model.id} ({getattr(model, sort)})")
+    repos.append(model.id)
+    i += 1
+    if i == n: break
+  return repos
+
+def get_metadata(repos:list[str]) -> dict:
+  api = HfApi()
+  repos_metadata = {"repositories": {}}
+  total_size = 0
+
+  # TODO: speed head requests up with async?
+  for repo in tqdm(repos, desc="Getting metadata"):
+    files_metadata = []
+    model_info = api.model_info(repo)
+
+    for file in model_info.siblings:
+      filename = file.rfilename
+      if not (filename.endswith('.onnx') or filename.endswith('.onnx_data')): continue
+      if any(skip_str in filename for skip_str in SKIPPED_FILES): continue
+      head = requests.head(f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}", allow_redirects=True)
+      file_size = file.size or int(head.headers.get('Content-Length', 0))
+      files_metadata.append({"file": filename, "size": f"{file_size/1e6:.2f}MB"})
+      total_size += file_size
+
+    repos_metadata["repositories"][repo] = {
+      "url": f"{HUGGINGFACE_URL}/{repo}",
+      "download_path": None,
+      "files": files_metadata,
+    }
+  repos_metadata['total_size'] = f"{total_size/1e9:.2f}GB"
+  repos_metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+  return repos_metadata
+
+if __name__ == "__main__":
+  sort = "downloads"  # recent 30 days downloads
+  huggingface_onnx_dir = Path(__file__).parent
+
+  parser = argparse.ArgumentParser(description="Produces a YAML file with metadata of top huggingface onnx models")
+  parser.add_argument("--limit", type=int, required=True, help="Number of top repositories to process (e.g., 100)")
+  parser.add_argument("--output", type=str, default="huggingface_repos.yaml", help="Output YAML file name to save the report")
+  args = parser.parse_args()
+
+  top_repos = get_top_repos(args.limit, sort)
+  metadata = get_metadata(top_repos)
+  yaml_path = huggingface_onnx_dir / args.output
+  with open(yaml_path, 'w') as f:
+    yaml.dump(metadata, f, sort_keys=False)
+  print(f"YAML saved to: {str(yaml_path)}")
diff --git a/extra/huggingface_onnx/download_models.py b/extra/huggingface_onnx/download_models.py
new file mode 100644
index 0000000000..e79e0e85ff
--- /dev/null
+++ b/extra/huggingface_onnx/download_models.py
@@ -0,0 +1,29 @@
+import yaml, argparse
+from pathlib import Path
+from huggingface_hub import snapshot_download
+
+def download_models(yaml_file: str, download_dir: str) -> None:
+  with open(yaml_file, 'r') as f: metadata = yaml.safe_load(f)
+  n = len(metadata["repositories"])
+
+  for i, (model_id, model_data) in enumerate(metadata["repositories"].items()):
+    print(f"Downloading {i+1}/{n}: {model_id}...")
+    allow_patterns = [file_info["file"] for file_info in model_data["files"]]
+    root_path = Path(snapshot_download(repo_id=model_id, allow_patterns=allow_patterns, cache_dir=download_dir))
+    # download configs too (the sizes are small)
+    snapshot_download(repo_id=model_id, allow_patterns=["*config.json"], cache_dir=download_dir)
+    print(f"Downloaded model files to: {root_path}")
+    model_data["download_path"] = str(root_path)
+
+  # Save the updated metadata back to the YAML file
+  with open(yaml_file, 'w') as f: yaml.dump(metadata, f, sort_keys=False)
+  print("Download completed according to YAML file.")
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description="Download models from Huggingface Hub based on a YAML configuration file.")
+  parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
+  args = parser.parse_args()
+
+  models_folder = Path(__file__).parent / "models"
+  models_folder.mkdir(parents=True, exist_ok=True)
+  download_models(args.input, str(models_folder))
\ No newline at end of file
diff --git a/extra/huggingface_onnx/run_models.py b/extra/huggingface_onnx/run_models.py
new file mode 100644
index 0000000000..a0acc2af72
--- /dev/null
+++ b/extra/huggingface_onnx/run_models.py
@@ -0,0 +1,136 @@
+import onnx, yaml, tempfile, time, collections, pprint, argparse, json
+from pathlib import Path
+from extra.onnx import OnnxRunner, get_onnx_ops
+from extra.onnx_helpers import validate, get_example_inputs
+
+def get_config(root_path: Path):
+  ret = {}
+  for path in root_path.rglob("*config.json"):
+    config = json.load(path.open())
+    if isinstance(config, dict):
+      ret.update(config)
+  return ret
+
+def run_huggingface_validate(onnx_model_path, config, rtol, atol):
+  onnx_model = onnx.load(onnx_model_path)
+  onnx_runner = OnnxRunner(onnx_model)
+  inputs = get_example_inputs(onnx_runner.graph_inputs, config)
+  validate(onnx_model_path, inputs, rtol=rtol, atol=atol)
+
+def get_tolerances(file_name):  # -> rtol, atol
+  # TODO very high rtol atol
+  if "fp16" in file_name: return 9e-2, 9e-2
+  if any(q in file_name for q in ["int8", "uint8", "quantized"]): return 4, 4
+  return 4e-3, 3e-2
+
+def validate_repos(models:dict[str, tuple[Path, Path]]):
+  print(f"** Validating {len(models)} models **")
+  for model_id, (root_path, relative_path) in models.items():
+    print(f"validating model {model_id}")
+    model_path = root_path / relative_path
+    onnx_file_name = model_path.stem
+    config = get_config(root_path)
+    rtol, atol = get_tolerances(onnx_file_name)
+    st = time.time()
+    run_huggingface_validate(model_path, config, rtol, atol)
+    et = time.time() - st
+    print(f"passed, took {et:.2f}s")
+
+def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict:
+  ret = {}
+  op_counter = collections.Counter()
+  unsupported_ops = collections.defaultdict(set)
+  supported_ops = get_onnx_ops()
+  print(f"** Retrieving stats from {len(models)} models **")
+  for model_id, (root_path, relative_path) in models.items():
+    print(f"examining {model_id}")
+    model_path = root_path / relative_path
+    onnx_runner = OnnxRunner(onnx.load(model_path))
+    for node in onnx_runner.graph_nodes:
+      op_counter[node.op] += 1
+      if node.op not in supported_ops:
+        unsupported_ops[node.op].add(model_id)
+    del onnx_runner
+  ret["unsupported_ops"] = {k:list(v) for k, v in unsupported_ops.items()}
+  ret["op_counter"] = op_counter.most_common()
+  return ret
+
+def debug_run(model_path, truncate, config, rtol, atol):
+  if truncate != -1:
+    model = onnx.load(model_path)
+    nodes_up_to_limit = list(model.graph.node)[:truncate + 1]
+    new_output_values = [onnx.helper.make_empty_tensor_value_info(output_name) for output_name in nodes_up_to_limit[-1].output]
+    model.graph.ClearField("node")
+    model.graph.node.extend(nodes_up_to_limit)
+    model.graph.ClearField("output")
+    model.graph.output.extend(new_output_values)
+    with tempfile.NamedTemporaryFile(suffix=model_path.suffix) as tmp:
+      onnx.save(model, tmp.name)
+      run_huggingface_validate(tmp.name, config, rtol, atol)
+  else:
+    run_huggingface_validate(model_path, config, rtol, atol)
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
+  parser.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
+  parser.add_argument("--check_ops", action="store_true", default=False,
+    help="Check support for ONNX operations in models from the YAML file")
+  parser.add_argument("--validate", action="store_true", default=False,
+    help="Validate correctness of models from the YAML file")
+  parser.add_argument("--debug", type=str, default="",
+    help="""Validates without explicitly needing a YAML or models pre-installed.
+    provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
+    provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
+    """)
+  parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
+  args = parser.parse_args()
+
+  if not (args.check_ops or args.validate or args.debug):
+    parser.error("Please provide either --validate, --check_ops, or --debug.")
+  if args.truncate != -1 and not args.debug:
+    parser.error("--truncate and --debug should be used together for debugging")
+
+  if args.check_ops or args.validate:
+    with open(args.input, 'r') as f:
+      data = yaml.safe_load(f)
+    assert all(repo["download_path"] is not None for repo in data["repositories"].values()), "please run `download_models.py` for this yaml"
+    model_paths = {
+      model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
+      for model_id, repo in data["repositories"].items()
+      for model in repo["files"]
+      if model["file"].endswith(".onnx")
+    }
+
+    if args.check_ops:
+      pprint.pprint(retrieve_op_stats(model_paths))
+
+    if args.validate:
+      validate_repos(model_paths)
+
+  if args.debug:
+    from huggingface_hub import snapshot_download
+    download_dir = Path(__file__).parent / "models"
+    path:list[str] = args.debug.split("/")
+    if len(path) == 2:
+      # repo id
+      # validates all onnx models inside repo
+      repo_id = "/".join(path)
+      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
+      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
+      config = get_config(root_path)
+      for onnx_model in root_path.rglob("*.onnx"):
+        rtol, atol = get_tolerances(onnx_model.name)
+        print(f"validating {onnx_model.relative_to(root_path)} with truncate={args.truncate}, {rtol=}, {atol=}")
+        debug_run(onnx_model, -1, config, rtol, atol)
+    else:
+      # model id
+      # only validate the specified onnx model
+      onnx_model = path[-1]
+      assert path[-1].endswith(".onnx")
+      repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
+      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
+      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
+      config = get_config(root_path)
+      rtol, atol = get_tolerances(onnx_model)
+      print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
+      debug_run(root_path / relative_path, args.truncate, config, rtol, atol)
\ No newline at end of file
diff --git a/extra/onnx_helpers.py b/extra/onnx_helpers.py
index 5c3b64bd38..cce053f059 100644
--- a/extra/onnx_helpers.py
+++ b/extra/onnx_helpers.py
@@ -5,12 +5,43 @@ import onnx
 import numpy as np
 import onnxruntime as ort

-def get_example_inputs(graph_inputs:dict[str, OnnxValue]):
+def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
+  def _get_shape(onnx_shape: tuple[str|int]):
+    shape = []
+    for onnx_dim in onnx_shape:
+      match onnx_dim:
+        case int(): shape.append(onnx_dim)
+        case "width" | "height":
+          size = config.get("size", {})
+          shape.append(size) if isinstance(size, int) else shape.append(size.get(onnx_dim, 224))
+        case "sequence" | "sequence_length" | "decoder_sequence_length": shape.append(64)
+        case "encoder_sequence_length": shape.append(config.get("nb_max_frames", 64))
+        case "past_decoder_sequence_length" | "encoder_sequence_length_out": shape.append(64)
+        case "encoder_sequence_length / 2": shape.append(32)
+        case "batch_size": shape.append(1)
+        case "num_channels": shape.append(config.get("in_channels", 3))
+        case "num_channels_latent": shape.append(config.get("latent_channels", 4))
+        case "height_latent" | "width_latent": shape.append(config.get("sample_size", 1024) // 8)
+        case "feature_size": shape.append(config.get("num_mel_bins", 128))
+        case _: shape.append(1)
+    return shape
+  def _get_value(name, shape, dtype):
+    match name:
+      case "input_ids":
+        vocab_size = config.get("text_config", {}).get("vocab_size") or config.get("vocab_size", 32)
+        val = np.random.randint(0, vocab_size-1, shape)
+      case "attention_mask": val = np.random.randint(0, 2, size=shape)
+      case "token_type_ids": val = np.random.randint(0, config.get("type_vocab_size", 2), shape)
+      case "image_tensor": val = np.random.randint(0, 256, shape)
+      case "task_id": return Tensor(0, dtype=dtype)
+      case _: val = np.random.uniform(size=shape) * 8
+    return Tensor(val.astype(_to_np_dtype(dtype))).realize()
+
   ret: dict[str, Tensor] = {}
   for name, spec in graph_inputs.items():
     assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now"
-    shape = tuple(dim if isinstance(dim, int) else 1 for dim in spec.shape)
-    value = Tensor(np.random.uniform(size=shape).astype(_to_np_dtype(spec.dtype)) * 8).realize()
+    shape = _get_shape(spec.shape)
+    value = _get_value(name, shape, spec.dtype)
     ret.update({name:value})
   return ret
diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index ae94c11d92..ac8f7c7c76 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -2,7 +2,7 @@
 # A001 Variable `input` is shadowing a Python builtin
 # A002 Function argument `input` is shadowing a Python builtin
 # A006 Lambda argument `input` is shadowing a Python builtin
-from tinygrad import Tensor, dtypes
+from tinygrad import Tensor, dtypes, Device
 from tinygrad.helpers import getenv, prod
 import torch.lib
 TORCH_DEBUG = getenv("TORCH_DEBUG")
@@ -12,9 +12,12 @@ from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype

 # https://pytorch.org/docs/stable/torch.compiler_ir.html

+def _from_torch_device(device: torch.device): return f"{Device.DEFAULT}:{device.index or 0}"
+def _to_torch_device(device: str): return torch.device("tiny", int(device.partition(":")[2] or 0))
+
 import torch.utils.cpp_extension
 mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")])
-def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
+def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
 def unwrap(x:torch.Tensor) -> Tensor:
   assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
   return mod.unwrap(x)
@@ -24,6 +27,7 @@ class TinyBackend:
   def current_device(self): return 0
   def _is_in_bad_fork(self): return False
   def manual_seed_all(self, seed: int): Tensor.manual_seed(seed)
+  def device_count(self): return getenv("GPUS", 1)  # TODO: device count in tiny?
 torch.utils.rename_privateuse1_backend("tiny")
 torch._register_device_module("tiny", TinyBackend())
 torch.utils.generate_methods_for_privateuse1_backend()
@@ -31,7 +35,7 @@ torch.utils.generate_methods_for_privateuse1_backend()
 # in place operations with views
 def is_view(self: torch.Tensor) -> bool: return getattr(self, "_base", None) is not None
 def realize_with_views(self: torch.Tensor, views: list[torch.Tensor]):
-  assert self.device.type == "tiny"
+  assert self.is_tiny
   self = unwrap(self)
   if not self.lazydata.st.contiguous: raise ValueError("base of view must be contiguous")  # TODO: support?
   self.replace(self.clone().realize())
@@ -63,18 +67,18 @@ def inplace_fn(outvars: str|list[str]):

 @torch.library.impl("aten::masked_select", "privateuseone")
 def masked_select(self, mask):
   # err, bad
-  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
+  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()], device=_from_torch_device(self.device)))

 @torch.library.impl("aten::_index_put_impl_", "privateuseone")
 @inplace_fn("self")
 def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
   # TODO: move to tinygrad
-  ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).tiny()
+  ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device)
   return wrap(unwrap(self).assign(unwrap(ret)))

 @torch.library.impl("aten::index.Tensor", "privateuseone")
 def index_tensor(x, y):
-  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
+  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device)

 @torch.library.impl("aten::randperm.generator_out", "privateuseone")
 def randperm_generator(n, generator=None, out=None):
   out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
@@ -121,13 +125,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
 @torch.library.impl("aten::empty_strided", "privateuseone")
 def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
   if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
-  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype), device=_from_torch_device(device))
   return wrap(ret)

 @torch.library.impl("aten::empty.memory_format", "privateuseone")
 def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
   if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
-  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())).contiguous()
+  ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()), device=_from_torch_device(device)).contiguous()
   return wrap(ret)

 @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
@@ -170,7 +174,7 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
 def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
   if TORCH_DEBUG >= 1: print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
-  grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0])
+  grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0], device=_from_torch_device(weight.device))
   out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
   grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out)
   return tuple([wrap(grads.pop(0)) if m else None for m in output_mask])
@@ -183,17 +187,19 @@ for i,pre in enumerate(["", "bi", "tri"]):

 @torch.library.impl("aten::_copy_from", "privateuseone")
 def _copy_from(src: torch.Tensor, dest, non_blocking=False):
-  realize = str(dest.device) == "tiny" and maybe_realize_storage(dest)
+  realize = dest.is_tiny and maybe_realize_storage(dest)
   cast_dtype = _from_torch_dtype(dest.dtype)
-  if str(src.device) == "tiny" and str(dest.device) == "tiny":
-    unwrap(dest).assign(unwrap(src).cast(cast_dtype))
+  if src.is_tiny and dest.is_tiny:
+    to_device = _from_torch_device(dest.device)
+    unwrap(dest).assign(unwrap(src).cast(cast_dtype).to(to_device))
     if realize: Tensor.realize(unwrap(dest))
-  elif str(src.device) == "tiny" and str(dest.device) == "cpu":
+  elif src.is_tiny and dest.is_cpu:
     # TODO: is there a better way?
     dest.resize_(src.numel()).resize_(src.shape)
     dest.copy_(torch.from_numpy(unwrap(src).cast(cast_dtype).numpy()))
-  elif str(src.device) == "cpu" and str(dest.device) == "tiny":
-    unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype))
+  elif src.is_cpu and dest.is_tiny:
+    to_device = _from_torch_device(dest.device)
+    unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype).to(to_device))
     if realize: Tensor.realize(unwrap(dest))
   else: raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
@@ -341,6 +347,7 @@ def wrap_out(f):
     assigned = f(*args, **kwargs)
     if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
     assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
+    assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
     assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
     if out.lazydata.is_realized: assigned = assigned.contiguous()  # TODO: how does this map to torch's semantics
     return out.assign(assigned)
@@ -462,7 +469,7 @@ def get_real_tinygrad_buffers():
   res = set()
   for mod in _torch_modules_with_buffers:
     for _,b in mod.named_buffers(recurse=False):
-      if b is not None and str(b.device) == "tiny":
+      if b is not None and b.is_tiny:
         res.add(unwrap(b))
   return res
 torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
@@ -476,7 +483,7 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
   for state_dict in optimizer.state.values():
     for _, value in state_dict.items():
       if torch.is_tensor(value): tinygrad_tensors.append(value)
-  real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
+  real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if x.is_tiny]
   real_tinygrad_tensors += get_real_tinygrad_buffers()
   if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
diff --git a/extra/torch_backend/test_multigpu.py b/extra/torch_backend/test_multigpu.py
new file mode 100644
index 0000000000..9a21898132
--- /dev/null
+++ b/extra/torch_backend/test_multigpu.py
@@ -0,0 +1,29 @@
+import unittest
+from tinygrad.helpers import getenv
+import torch
+import tinygrad.frontend.torch
+torch.set_default_device("tiny") +import numpy as np + +@unittest.skipIf(getenv("GPUS",1)<=1, "only single GPU") +class TestTorchBackendMultiGPU(unittest.TestCase): + def test_transfer(self): + a = torch.Tensor([[1,2],[3,4]]).to("tiny:0") + b = torch.Tensor([[3,2],[1,0]]).to("tiny:1") + self.assertNotEqual(a.device, b.device) + np.testing.assert_array_equal(a.cpu(), a.to("tiny:1").cpu()) + np.testing.assert_array_equal(b.cpu(), b.to("tiny:1").cpu()) + + def test_basic_ops(self): + a = torch.Tensor([[1,2],[3,4]]).to("tiny:0") + b = torch.Tensor([[3,2],[1,0]]).to("tiny:1") + c1 = a + b.to("tiny:0") + c2 = b + a.to("tiny:1") + np.testing.assert_array_equal(c1.cpu(), torch.full((2,2),4).cpu()) + np.testing.assert_array_equal(c1.cpu(), c2.cpu()) + + # TODO: torch.distributed functions + +if __name__ == "__main__": + unittest.main() + diff --git a/extra/torch_backend/wrapped_tensor.cpp b/extra/torch_backend/wrapped_tensor.cpp index 658dc41597..bdf3e926b6 100644 --- a/extra/torch_backend/wrapped_tensor.cpp +++ b/extra/torch_backend/wrapped_tensor.cpp @@ -109,7 +109,7 @@ int register_hook() { } int temp_register_hook = register_hook(); -at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) { +at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceIndex device_index) { // TODO: we have to get the dtype and the shape from the tinygrad Tensor std::vector sizes = py_obj.attr("shape").cast>(); @@ -127,7 +127,7 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) { return at::detail::make_tensor>>( at::DispatchKeySet(at::DispatchKey::PrivateUse1), c10::scalarTypeToTypeMeta(dtype), - at::Device(at::kPrivateUse1), + at::Device(at::kPrivateUse1, device_index), std::make_shared(py_obj.release().ptr(), getPyInterpreter()), sizes, strides); } diff --git a/test/test_ops.py b/test/test_ops.py index 61a86873a8..be1699391e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,7 +3,7 @@ import numpy as np from typing import List, Callable import torch import warnings -from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL +from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, DEVECTORIZE from tinygrad import Tensor, Device, dtypes from tinygrad.tensor import _to_np_dtype from tinygrad.device import is_dtype_supported @@ -1490,6 +1490,7 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7) + @unittest.skipIf(not DEVECTORIZE, "broken without DEVECTORIZE. 
TODO: fix this") def test_logcumsumexp_numerical(self): helper_test_op(None, lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7, vals=[[0.0, 100.0]]) diff --git a/test/test_schedule.py b/test/test_schedule.py index cb0ad9e3e8..3ac485dab5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2527,5 +2527,12 @@ class TestUOpBecome(unittest.TestCase): assert b.lazydata.is_realized assert b.lazydata.base.buffer._base is None + def test_setitem_offset(self): + a = Tensor.full((16,), 0.).contiguous().realize() + b = Tensor.full((16,), 1.).contiguous().realize() + a_view = a[4:].reshape(3, 4).shrink(((0,2),(0,2))).reshape((4,)) + b.shrink(((0,4),)).assign(a_view).realize() + self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_tensor.py b/test/test_tensor.py index 2b4d758848..fb9bbc72cb 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -414,7 +414,7 @@ class TestTinygrad(unittest.TestCase): def test_tensor_dtype_errors(self): with self.assertRaises(AttributeError): Tensor([3], dtype="typo") - with self.assertRaises(TypeError): Tensor([3], dtype=(dtypes.int,)) + with self.assertRaises(AttributeError): Tensor([3], dtype=(dtypes.int,)) def test_tensor_bytes(self): data = b"abc123" diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index 39eb7e6852..a0a9d4e87c 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -1,86 +1,83 @@ -from typing import Optional, Any, Callable -import functools, operator +from typing import Optional, Any, Callable, cast +import functools, operator, itertools from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve from tinygrad.ops import graph_rewrite, GroupOp from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym -from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE +from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES from tinygrad.renderer import Renderer -# ***** float4/image store handling ***** - -def fold_expanded(ex, buf): - new_srcs = dedup(list(ex.src)) - old_new_srcs = new_srcs[:] - is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType) - - # TODO: get the device from the buffer somehow - # NOTE: this can't be Device.DEFAULT because it opens devices - if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None - lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])) +# ***** load/store grouping ***** +def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): # first, extract all the relevant offsets - offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict) - for i,s in enumerate(new_srcs): - idx = s.src[0].src[1] - if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue + offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) + for i in range(vec.dtype.count): + idx = vec.gep(i).simplify() if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = 
idx.src[0], idx.src[1].arg + elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 - # add gates for gated - if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src) - assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources" - offsets_rootsrc[root_src][arg] = i + if mask is not None: root_src = (mask.gep(i).simplify(), root_src) + offsets_rootsrc[root_src].setdefault(arg, []).append(i) - # then rewrite everything we can - used: set[tuple[UOp, UOp]] = set() + # the buf.dtype is always a pointer + ptrdtype = cast(PtrDType, buf.dtype) + + # then rewrite everything we can into groups + ret = [] + idxs: list[int|None] = [None]*vec.dtype.count + global_offset = 0 for rootsrc, offsets in offsets_rootsrc.items(): - for o in offsets: - for fold_length in lengths: - if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)): - load_1 = new_srcs[offsets[o]] - new_src = list(load_1.src) - oidx = new_src[0].src[1] - if oidx.divides(fold_length) is None: continue - if is_image: - # for images, we rewrite the index. it must evenly divide 4 from the above check - new_src[0] = buf.index( - UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))), - rootsrc[0] if isinstance(rootsrc, tuple) else None) - else: - # for non image, we upcast the index pointer - new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local)) - # generate the folded new_srcs - if is_load: - new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src)) - for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i) - else: # vectorize the store - new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length))) - for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None - used.update((rootsrc,o+i) for i in range(fold_length)) + grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])] + for grp in grouped_offsets: + # get the index offset for this element. 
using [0] is okay, because they are the same + oidx = vec.gep(offsets[grp[0]][0]) + lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx)) + if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local)) + # set the idxs of the output + for i,g in enumerate(grp): + for oo in offsets[g]: idxs[oo] = global_offset+i + # add this lidx to the CAT + ret.append(lidx) + global_offset += len(grp) + assert None not in idxs, f"some idxs are missing {idxs}" + # this base thing is for image, we want the CAT to be a normal pointer + return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs))) - # dedup expand for LOAD - if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src] - # remove Nones for STORE - return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None +def cat_after_store(cat:UOp, data:UOp): + # TODO: this is written in many places + offset = 0 + ret = [] + for s in cat.src: + ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))))) + offset += s.dtype.count + return UOp.sink(ret[0], *ret[1:]) -def fix_unfoldable_image_load(load:UOp, buf:UOp): - if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None - id4 = oidx % 4 - new_src = list(load.src) - # TODO: copied logic from above - new_src[0] = load.src[0].src[0].index( - UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))), - load.src[0].src[2] if len(load.src[0].src) == 3 else None) - vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src)) - return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan'))) +def gep_on_store(gep:UOp, st:UOp): + # NOTE: we need to invert the gep here, but it may be an expanding gep + # fake argsort. 
TODO: handle duplicates + a = {} + for i,x in enumerate(gep.arg): a[x] = i + new_arg = tuple(x[1] for x in sorted(a.items())) + return UOp(Ops.STORE, src=(gep.src[0], st.gep(new_arg))) -buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True) -float4_folding = PatternMatcher([ - (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded), - (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded), +load_store_folding = PatternMatcher([ + (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index), + (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"), + UPat.var("mask"))), expand_index), + # GEP after LOAD + (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True), + lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)), + # GEP on data of STORE + (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store), + # put CAT after LOAD + (UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True), + lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))), + # put CAT after STORE + (UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store), ]) # ***** image load valid simplification ***** @@ -140,6 +137,64 @@ def get_late_rewrite_patterns(ops, force_transcendental=False): if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))] return PatternMatcher(pat) +# *** correct load/store *** + +def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): + if (sz:=ls.src[0].dtype.count) == 1: return None + lengths = [] + buf = idx.src[0] + if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): + pass + elif isinstance(buf.dtype, ImageDType): + lengths = [4] + elif ctx is not None and ctx.supports_float4: + # TODO: a better way to get this than ctx + lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]) + lengths.append(1) # worst case, it's not folded + ptrdtype = cast(PtrDType, buf.dtype) + global_offset = 0 + ret = [] + while global_offset < sz: + for fold_length in lengths: + if global_offset+fold_length > sz: continue + oidx = idx.src[1] + global_offset + if oidx.simplify().divides(fold_length) is None: continue + lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None) + if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local)) + if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:])) + else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length))) + global_offset += fold_length + break + if len(ret) == 1: return None + return UOp(Ops.CAT, ls.dtype, tuple(ret)) + +def image_fixup(ls:UOp): + # normal image load or store, with the CAST from expand_index + if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType): + assert ls.src[0].dtype.count == 4, "image must be casted to 4" + idx = ls.src[0].src[0] + oidx = UOp(Ops.VECTORIZE, 
dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1])))) + idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:]) + return ls.replace(src=(idx,)+ls.src[1:]) + + # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores + if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2): + assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it" + idx = ls.src[0] + id4 = idx.src[1] % 4 + oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1])))) + idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:]) + vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:]) + return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan'))) + + return None + +correct_load_store = PatternMatcher([ + # split LOAD/STORE + (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.CAST, src=(UPat(Ops.INDEX, name="idx"),)),), name="ls", allow_any_len=True), split_load_store), + # image indexing, including unfoldable images + (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup), +]) # *** uop expander *** @@ -160,13 +215,6 @@ def no_vectorized_alu(alu): alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount)) return UOp(Ops.VECTORIZE, alu.dtype, alus) -def no_vectorized_load_store(ls:UOp): - idx = ls.src[0] - assert isinstance(idx.dtype, PtrDType) - if idx.dtype.v == 1: return None - tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)] - return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv)) - def no_vectorized_acc(acc:UOp): if acc.dtype.count == 1: return None alus = tuple(UOp(acc.op, acc.dtype.scalar(), @@ -175,16 +223,9 @@ def no_vectorized_acc(acc:UOp): devectorize = PatternMatcher([ # no ALU on vectorized dtypes - (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu), + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu), (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma), (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc), - (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store), -]) - -devectorize_load_store = PatternMatcher([ - # TODO: add vectorized support to transcendental - (UPat((Ops.INDEX), name="alu"), no_vectorized_alu), - (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store), ]) def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None: @@ -193,8 +234,6 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|N return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val) load_store_indexing = PatternMatcher([ - # late fixup of unfoldable image loads - (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), # simplify valid (UPat(Ops.AND, name="valid"), simplify_valid), # image load valid idx simplification @@ -231,12 +270,10 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else () extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([]) - if DEVECTORIZE: - # devectorize + load_store_indexing - sink = 
graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing) - else: - # new devectorize only for load/store - sink = graph_rewrite(sink, sym+devectorize_load_store) + # devectorize is optional + if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts) + elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts) + else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts) # optional pre matcher if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher) diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index dec7be249b..82fb566ebc 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -230,6 +230,17 @@ symbolic = symbolic_simple+PatternMatcher([ # ** mod ** # mod folding (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)), + # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST + (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'), + lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))), + (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"), + lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]), + (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)), + (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))), + # push all GEPs through ALUs (fix arange stuff) + (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), + lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \ + if not isinstance(gep.dtype, PtrDType) else None), ]) symbolic_flat = symbolic+PatternMatcher([ @@ -399,17 +410,6 @@ sym = symbolic_flat+PatternMatcher([ # VECTORIZE void is SINK (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b), (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)), - # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST - (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'), - lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))), - (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"), - lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]), - (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)), - (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))), - # push all GEPs through ALUs (fix arange stuff) - (UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'), - lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \ - if not isinstance(gep.dtype, PtrDType) else None), # push some GEPs through WMMAs (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma), # CAT can't be rendered. 
it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later) diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 1ee9bfa8f4..c0b80d5528 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -156,7 +156,7 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")): assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype" DTypeLike = Union[str, DType] -def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype) +def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower()) # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type @@ -180,7 +180,7 @@ def sum_acc_dtype(dt:DType): # default acc dtype for sum if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint) if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int) - return least_upper_dtype(dt, dtypes.float) + return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32"))) def truncate_fp16(x): try: return struct.unpack("@e", struct.pack("@e", float(x)))[0] diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 5e90c22b15..f80ebf7358 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -106,6 +106,9 @@ sym = symbolic_simple+PatternMatcher([ # put CAST after expanding BUFFER (UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND") and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None), + # remove CONST/BIND/VIEW from SINK + (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src) + if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None), ]) # **** UOp realization @@ -259,11 +262,12 @@ create_kernels = merge_views+PatternMatcher([ lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None), # walk back the local graph until we reach a buffer/assign parent (UPat(Ops.KERNEL, name="x"), append_to_kernel), - # remove CONST/BIND from SINK - (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src) - if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None), + # remove downstream reshapes from SINK + (UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None), ]) +DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK} + # **** fix kernel AST # ** create buffer ops + enumerate buffers @@ -273,8 +277,18 @@ add_buffer_ops = PatternMatcher([ (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))), # STORE (except for COPY/BUFFER_VIEW) (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x), + # partial assign can store to a non-contiguous ShapeTracker + (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)), + lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()), + # otherwise the store is contiguous (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()), + # if the last child 
is a VIEW we merge the ShapeTrackers and store the base + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))), + lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)), + # remove CONTIGUOUS/DEVICE from kernel AST + (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), + (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), ]) # ** push views to buffer ops @@ -290,23 +304,24 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp): strides = strides_for_shape(rshape) nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides, v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views] - # update input_st and axis + # create a new reduceop for the swizzled input new_input_st = tmp + ShapeTracker(tuple(nv)) new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg))) - return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) + return UOp(Ops.REDUCE_AXIS, r.dtype, (apply_swizzle(src.view(src.arg+new_input_st if src.op is Ops.VIEW else new_input_st)),), + (r.arg[0], new_axis)).view(ShapeTracker.from_shape(st.shape)) def reduceop_view_right(src:UOp, v:UOp, r:UOp): assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}" return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape)) def elementwise_view_right(root:UOp) -> UOp|None: - if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW]): return None + if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" # place view after applying the elementwise op - new_shape = swizzles[0].base.shape - ret = root.replace(src=tuple(x.base if x.base.shape == new_shape else apply_swizzle(x.view(ShapeTracker.from_shape(new_shape))) for x in root.src)) + new_st = ShapeTracker.from_shape(swizzles[0].base.shape) + new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(x.arg+new_st) if x.op is Ops.VIEW else x.view(new_st)) for x in root.src] # reshape to match downstream shapes - return ret.reshape(root.shape) + return root.replace(src=tuple(new_src)).reshape(root.shape) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" @@ -315,17 +330,12 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: # push VIEW to children view_right = merge_views+PatternMatcher([ - # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), - lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), - # STORE is the last child, so we just merge the ShapeTrackers and store the base - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), # push a non contiguous ShapeTracker through reduceop (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop), # apply view after reduceops - (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), 
@@ -315,17 +330,12 @@
 # push VIEW to children
 view_right = merge_views+PatternMatcher([
-  # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
-  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
-   lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
-  # STORE is the last child, so we just merge the ShapeTrackers and store the base
-  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
   # push a non contiguous ShapeTracker through reduceop
   (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
   # apply view after reduceops
-  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), reduceop_view_right),
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
   # apply view after elementwise ops
-  (UPat(GroupOp.All-GroupOp.Buffer, name="root"), elementwise_view_right),
+  (UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
   # double reduce op collapses to a single reduce op
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
@@ -359,9 +369,6 @@ def check_load_st(glbl:UOp, view:UOp):
 fix_kernel_ops = PatternMatcher([
   # BIND in shapetracker becomes DEFINE_VAR
   (UPat(Ops.VIEW, name="x"), unbind_shapetracker),
-  # remove CONTIGUOUS/DEVICE
-  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
-  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
   # remove unmasked valid
   (UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
   # no ImageDType after load
@@ -378,15 +385,15 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
     if s.op is Ops.ASSIGN:
       for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
   ast = k.arg.ast.substitute(parents_rep)
-  # add buffer ops
-  ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
-  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
   # unbind_vars + push views to edges
   ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
-  # fix_kernel_ops
-  ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
+  # add buffer ops
+  ast = graph_rewrite(ast, view_left+add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
+  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
   # create subbuffer (TODO: this does not belong here)
   if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
+  # fix_kernel_ops
+  ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
   return k.replace(arg=Kernel(ast, k.arg.metadata))
 
 PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
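The same push-right idea for reduces, which reduceop_view_right handles above: a contiguous VIEW feeding a REDUCE_AXIS can be moved below the reduce once the reduce axes are recomputed from the shape change. Again a numpy stand-in, not tinygrad types:

  import numpy as np

  x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
  before = x.reshape(6, 4).sum(axis=1)  # REDUCE_AXIS(VIEW(x))
  after = x.sum(axis=2).reshape(6)      # VIEW(REDUCE_AXIS(x)), axis recomputed
  assert (before == after).all()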
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index b090cb9ba4..1a805fbbca 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -191,8 +191,8 @@ class ClangRenderer(CStyleLanguage):
   code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
                  Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
   # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
-  extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
-    CStyleLanguage.extra_matcher
+  extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
+    (UPat(Ops.SQRT, name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher
 
   if sys.platform == 'win32':
     kernel_prefix = "__attribute__((ms_abi)) "
diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py
index 455a0f7332..9af5920bde 100644
--- a/tinygrad/runtime/support/am/ip.py
+++ b/tinygrad/runtime/support/am/ip.py
@@ -163,6 +163,9 @@ class AM_GFX(AM_IP):
     self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000)
     self.adev.regRLC_SRM_CNTL.update(srm_enable=1, auto_incr_addr=1)
 
+    self.adev.regS2A_DOORBELL_ENTRY_0_CTRL.write(s2a_doorbell_port0_enable=1, s2a_doorbell_port0_awid=0x3, s2a_doorbell_port0_awaddr_31_28_value=0x3)
+    self.adev.regS2A_DOORBELL_ENTRY_3_CTRL.write(s2a_doorbell_port3_enable=1, s2a_doorbell_port3_awid=0x6, s2a_doorbell_port3_awaddr_31_28_value=0x3)
+
     self.adev.regGRBM_CNTL.update(read_timeout=0xff)
     for i in range(0, 16):
       self._grbm_select(vmid=i)
@@ -297,6 +300,9 @@ class AM_IH(AM_IP):
     for _, rwptr_vm, suf, ring_id in self.rings:
       self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
 
+    self.adev.regS2A_DOORBELL_ENTRY_1_CTRL.update(s2a_doorbell_port1_enable=1, s2a_doorbell_port1_awid=0x0, s2a_doorbell_port1_awaddr_31_28_value=0x0,
+      s2a_doorbell_port1_range_offset=am.AMDGPU_NAVI10_DOORBELL_IH*2, s2a_doorbell_port1_range_size=2)
+
 class AM_SDMA(AM_IP):
   def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
     # Setup the ring
@@ -320,6 +326,8 @@ class AM_SDMA(AM_IP):
     self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
     self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
     self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
+    self.adev.regS2A_DOORBELL_ENTRY_2_CTRL.update(s2a_doorbell_port2_enable=1, s2a_doorbell_port2_awid=0xe, s2a_doorbell_port2_awaddr_31_28_value=0x3,
+      s2a_doorbell_port2_range_offset=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2, s2a_doorbell_port2_range_size=4)
 
   def fini(self):
     self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py
index 42678f55b1..7af43d4590 100644
--- a/tinygrad/runtime/support/hcq.py
+++ b/tinygrad/runtime/support/hcq.py
@@ -246,11 +246,12 @@ class HCQSignal(Generic[DeviceType]):
 
     Args:
       value: The value to wait for.
-      timeout: Maximum time to wait in milliseconds. Defaults to 10s.
+      timeout: Maximum time to wait in milliseconds. Defaults to 30s.
     """
     start_time = int(time.perf_counter() * 1000)
-    while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
+    while (prev_value:=self.value) < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
       self._sleep(time_spent)
+      if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
     if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
 
 @contextlib.contextmanager
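The hcq.py change restated as a self-contained sketch (wait_for and read_value are hypothetical names, simplified from HCQSignal.wait above): the timeout now bounds time without progress rather than total wall time, so a signal that keeps advancing is never killed.

  import time

  def wait_for(read_value, target, timeout_ms=30000):
    start = int(time.perf_counter() * 1000)
    while (prev_value := read_value()) < target and int(time.perf_counter() * 1000) - start < timeout_ms:
      time.sleep(0.001)
      # progress was made: reset the deadline instead of counting total elapsed time
      if read_value() != prev_value: start = int(time.perf_counter() * 1000)
    if read_value() < target: raise RuntimeError(f"wait timeout: {timeout_ms} ms (value is {read_value()}, not {target})")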