mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
@@ -447,7 +447,7 @@ jobs:
|
||||
testmoreamdbenchmark:
|
||||
name: tinybox red Training Benchmark
|
||||
runs-on: [self-hosted, Linux, tinybox]
|
||||
timeout-minutes: 20
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -o pipefail {0}
|
||||
|
||||
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -180,6 +180,8 @@ jobs:
|
||||
run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20 || true
|
||||
- name: Test in-place operations on views
|
||||
run: PYTHONPATH=. TORCH_DEBUG=1 python3 extra/torch_backend/test_inplace.py
|
||||
- name: Test multi-gpu
|
||||
run: PYTHONPATH=. LLVM=1 GPUS=4 TORCH_DEBUG=1 python3 extra/torch_backend/test_multigpu.py
|
||||
|
||||
torchbackendmore:
|
||||
name: Torch Backend Tests More
|
||||
@@ -328,8 +330,8 @@ jobs:
|
||||
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
|
||||
- name: Run unit tests
|
||||
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
|
||||
- name: Repo line count < 11500 lines
|
||||
run: MAX_LINE_COUNT=11500 python sz.py
|
||||
- name: Repo line count < 12000 lines
|
||||
run: MAX_LINE_COUNT=12000 python sz.py
|
||||
|
||||
fuzzing:
|
||||
name: Fuzzing
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,6 +34,8 @@ extra/datasets/open-images-v6-mlperf
|
||||
extra/datasets/kits/
|
||||
extra/datasets/COCO/
|
||||
extra/datasets/audio*
|
||||
extra/huggingface_onnx/models/*
|
||||
extra/huggingface_onnx/*.yaml
|
||||
extra/weights
|
||||
venv
|
||||
examples/**/net.*[js,json]
|
||||
@@ -56,5 +58,5 @@ weights
|
||||
comgr_*
|
||||
*.pkl
|
||||
site/
|
||||
master_schedule.py
|
||||
profile_stats
|
||||
*.log
|
||||
|
||||
@@ -925,7 +925,7 @@ def train_bert():
|
||||
# ** hyperparameters **
|
||||
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
|
||||
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00011 * math.sqrt(BS/66))
|
||||
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96))
|
||||
|
||||
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS)
|
||||
warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
|
||||
@@ -936,7 +936,7 @@ def train_bert():
|
||||
save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
|
||||
init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
|
||||
|
||||
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0)
|
||||
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**11 if dtypes.default_float == dtypes.float16 else 1.0)
|
||||
decay = config["DECAY"] = getenv("DECAY", 0.01)
|
||||
epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
|
||||
poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
|
||||
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
|
||||
export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
|
||||
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
|
||||
export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export SUBMISSION_PLATFORM="tinybox_green"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
|
||||
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
|
||||
export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
|
||||
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
|
||||
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
|
||||
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
|
||||
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export SUBMISSION_PLATFORM="tinybox_red"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78
|
||||
export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
|
||||
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
|
||||
85
extra/huggingface_onnx/collect_metadata.py
Normal file
85
extra/huggingface_onnx/collect_metadata.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import yaml, time, requests, argparse
|
||||
from pathlib import Path
|
||||
from huggingface_hub import list_models, HfApi
|
||||
from tinygrad.helpers import tqdm
|
||||
|
||||
HUGGINGFACE_URL = "https://huggingface.co"
|
||||
SKIPPED_FILES = [
|
||||
"fp16", "int8", "uint8", "quantized", # numerical accuracy issues
|
||||
"avx2", "arm64", "avx512", "avx512_vnni", # numerical accuracy issues
|
||||
"q4", "q4f16", "bnb4", # unimplemented quantization
|
||||
"model_O4", # requires non cpu ort runner and MemcpyFromHost op
|
||||
"merged", # TODO implement attribute with graph type and Loop op
|
||||
]
|
||||
SKIPPED_REPO_PATHS = [
|
||||
# Invalid model-index
|
||||
"AdamCodd/vit-base-nsfw-detector",
|
||||
# TODO: implement attribute with graph type and Loop op
|
||||
"minishlab/potion-base-8M", "minishlab/M2V_base_output", "minishlab/potion-retrieval-32M",
|
||||
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention
|
||||
"HuggingFaceTB/SmolLM2-360M-Instruct",
|
||||
# TODO: implement SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, RotaryEmbedding, MultiHeadAttention
|
||||
"HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
||||
# TODO: implement RandomNormalLike
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", 'SimianLuo/LCM_Dreamshaper_v7',
|
||||
# TODO: implement NonZero
|
||||
"mangoapps/fb_zeroshot_mnli_onnx",
|
||||
# TODO huge Concat in here with 1024 (1, 3, 32, 32) Tensors, and maybe a MOD bug with const folding
|
||||
"briaai/RMBG-2.0",
|
||||
]
|
||||
|
||||
def get_top_repos(n: int, sort: str) -> list[str]: # list["FacebookAI/xlm-roberta-large", ...]
  """Collect the ids of the first `n` repos yielded by `list_models(filter="onnx", sort=sort)`.

  Repos listed in SKIPPED_REPO_PATHS are ignored and do not count towards `n`.
  Progress is printed for every accepted repo, together with the model attribute named by `sort`.

  Args:
    n: number of repo ids to collect; a value <= 0 returns an empty list.
    sort: sort key forwarded to `list_models`, also used as the attribute name shown in the progress line.

  Returns:
    The accepted repo ids in the order they were yielded (fewer than `n` if the listing runs out).
  """
  print(f"** Getting top {n} models sorted by {sort} **")
  repos: list[str] = []
  # the stop condition is only evaluated after an append, so without this guard n <= 0 would never
  # match and the whole listing would be consumed
  if n <= 0: return repos
  for model in list_models(filter="onnx", sort=sort):
    if model.id in SKIPPED_REPO_PATHS: continue
    print(f"{len(repos)+1}/{n}: {model.id} ({getattr(model, sort)})")
    repos.append(model.id)
    if len(repos) == n: break
  return repos
|
||||
|
||||
def get_metadata(repos:list[str]) -> dict:
  """Build the metadata report for the given repo ids.

  For every repo, the `.onnx` / `.onnx_data` siblings whose filename contains none of the
  SKIPPED_FILES markers are recorded together with their size.

  Args:
    repos: repo ids such as "FacebookAI/xlm-roberta-large".

  Returns:
    A dict shaped as
      {"repositories": {repo: {"url": str, "download_path": None, "files": [{"file": str, "size": "<x.xx>MB"}]}},
       "total_size": "<x.xx>GB", "created_at": "<UTC timestamp>"}
    `download_path` is always None here.
  """
  api = HfApi()
  repos_metadata = {"repositories": {}}
  total_size = 0

  # TODO: speed head requests up with async?
  for repo in tqdm(repos, desc="Getting metadata"):
    files_metadata = []
    model_info = api.model_info(repo)

    # `siblings` may be missing on the returned info object; treat that as "no files"
    for file in model_info.siblings or []:
      filename = file.rfilename
      if not filename.endswith(('.onnx', '.onnx_data')): continue
      if any(skip_str in filename for skip_str in SKIPPED_FILES): continue
      file_size = file.size
      if not file_size:
        # only hit the network when the listing did not carry a size; the timeout keeps one
        # unresponsive request from hanging the whole run
        head = requests.head(f"{HUGGINGFACE_URL}/{repo}/resolve/main/{filename}", allow_redirects=True, timeout=30)
        file_size = int(head.headers.get('Content-Length', 0))
      files_metadata.append({"file": filename, "size": f"{file_size/1e6:.2f}MB"})
      total_size += file_size

    repos_metadata["repositories"][repo] = {
      "url": f"{HUGGINGFACE_URL}/{repo}",
      "download_path": None,
      "files": files_metadata,
    }
  repos_metadata['total_size'] = f"{total_size/1e9:.2f}GB"
  repos_metadata['created_at'] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
  return repos_metadata
|
||||
|
||||
if __name__ == "__main__":
  # CLI entry point: list the top ONNX repos, gather their file metadata and dump it as YAML next to this script
  parser = argparse.ArgumentParser(description="Produces a YAML file with metadata of top huggingface onnx models")
  parser.add_argument("--limit", type=int, required=True, help="Number of top repositories to process (e.g., 100)")
  parser.add_argument("--output", type=str, default="huggingface_repos.yaml", help="Output YAML file name to save the report")
  args = parser.parse_args()

  sort = "downloads" # recent 30 days downloads
  # the report is written into the directory that contains this script
  yaml_path = Path(__file__).parent / args.output

  metadata = get_metadata(get_top_repos(args.limit, sort))
  with open(yaml_path, 'w') as f: yaml.dump(metadata, f, sort_keys=False)
  print(f"YAML saved to: {str(yaml_path)}")
|
||||
29
extra/huggingface_onnx/download_models.py
Normal file
29
extra/huggingface_onnx/download_models.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import yaml, argparse
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
def download_models(yaml_file: str, download_dir: str) -> None:
  """Download every repo listed in `yaml_file` and record where each one landed.

  For each entry under "repositories", the files named in its "files" list are fetched with
  `snapshot_download` into `download_dir`, followed by any `*config.json` files. The returned
  snapshot path is stored under the entry's "download_path" key, and the updated metadata is
  written back to `yaml_file` once all repos are done.

  Args:
    yaml_file: path of the metadata YAML; it is read first and overwritten at the end.
    download_dir: directory passed to `snapshot_download` as `cache_dir`.
  """
  with open(yaml_file, 'r') as src:
    metadata = yaml.safe_load(src)
  repositories = metadata["repositories"]
  total = len(repositories)

  for position, (model_id, model_data) in enumerate(repositories.items(), start=1):
    print(f"Downloading {position}/{total}: {model_id}...")
    wanted_files = [entry["file"] for entry in model_data["files"]]
    root_path = Path(snapshot_download(repo_id=model_id, allow_patterns=wanted_files, cache_dir=download_dir))
    # download configs too (the sizes are small)
    snapshot_download(repo_id=model_id, allow_patterns=["*config.json"], cache_dir=download_dir)
    print(f"Downloaded model files to: {root_path}")
    model_data["download_path"] = str(root_path)

  # Save the updated metadata back to the YAML file
  with open(yaml_file, 'w') as dst:
    yaml.dump(metadata, dst, sort_keys=False)
  print("Download completed according to YAML file.")
|
||||
|
||||
if __name__ == "__main__":
  # CLI entry point: download everything described by the YAML file given on the command line
  cli = argparse.ArgumentParser(description="Download models from Huggingface Hub based on a YAML configuration file.")
  cli.add_argument("input", type=str, help="Path to the input YAML configuration file containing model information.")
  yaml_input = cli.parse_args().input

  # downloads are cached in a "models" directory next to this script
  models_folder = Path(__file__).parent / "models"
  models_folder.mkdir(parents=True, exist_ok=True)
  download_models(yaml_input, str(models_folder))
|
||||
136
extra/huggingface_onnx/run_models.py
Normal file
136
extra/huggingface_onnx/run_models.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
|
||||
from pathlib import Path
|
||||
from extra.onnx import OnnxRunner, get_onnx_ops
|
||||
from extra.onnx_helpers import validate, get_example_inputs
|
||||
|
||||
def get_config(root_path: Path):
  """Merge every `*config.json` found anywhere under `root_path` into one flat dict.

  Files whose top-level JSON value is not an object are ignored. When several files define the
  same key, the file visited last by `rglob` wins.

  Args:
    root_path: directory that is searched recursively.

  Returns:
    The merged dict (empty when no matching file exists).
  """
  ret = {}
  for path in root_path.rglob("*config.json"):
    # use a context manager so each file handle is closed right after it is parsed
    with path.open() as f: config = json.load(f)
    if isinstance(config, dict): ret.update(config)
  return ret
|
||||
|
||||
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
  """Load the ONNX model at `onnx_model_path`, build example inputs for it and hand them to `validate`.

  The example inputs are produced by `get_example_inputs` from the runner's graph inputs and
  `config`; `rtol` / `atol` are forwarded to `validate` unchanged.
  """
  runner = OnnxRunner(onnx.load(onnx_model_path))
  example_inputs = get_example_inputs(runner.graph_inputs, config)
  validate(onnx_model_path, example_inputs, rtol=rtol, atol=atol)
|
||||
|
||||
def get_tolerances(file_name): # -> rtol, atol
  """Choose the `(rtol, atol)` pair for a model from markers in its file name.

  "fp16" takes precedence over the quantization markers; names with no marker get the default pair.
  """
  # TODO very high rtol atol
  if "fp16" in file_name:
    return 9e-2, 9e-2
  for marker in ("int8", "uint8", "quantized"):
    if marker in file_name:
      return 4, 4
  return 4e-3, 3e-2
|
||||
|
||||
def validate_repos(models:dict[str, tuple[Path, Path]]):
  """Validate every model in `models`, printing the elapsed time for each.

  Args:
    models: maps a display id to `(root_path, relative_path)`; the model file is
      `root_path / relative_path` and the config is gathered from `root_path`.

  Tolerances are picked from the model file's stem via `get_tolerances`. Any exception raised by
  `run_huggingface_validate` propagates and stops the run.
  """
  # count the argument, not the script-level `model_paths` global, so the function also works when imported
  print(f"** Validating {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"validating model {model_id}")
    model_path = root_path / relative_path
    onnx_file_name = model_path.stem
    config = get_config(root_path)
    rtol, atol = get_tolerances(onnx_file_name)
    st = time.time()
    run_huggingface_validate(model_path, config, rtol, atol)
    et = time.time() - st
    print(f"passed, took {et:.2f}s")
|
||||
|
||||
def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict:
  """Count the ops used across `models` and list those missing from `get_onnx_ops()`.

  Args:
    models: maps a display id to `(root_path, relative_path)`; the model file is `root_path / relative_path`.

  Returns:
    {"unsupported_ops": {op: [ids of models using it]}, "op_counter": [(op, count), ...] most common first}
  """
  ret = {}
  op_counter = collections.Counter()
  unsupported_ops = collections.defaultdict(set)
  supported_ops = get_onnx_ops()
  # count the argument, not the script-level `model_paths` global, so the function also works when imported
  print(f"** Retrieving stats from {len(models)} models **")
  for model_id, (root_path, relative_path) in models.items():
    print(f"examining {model_id}")
    model_path = root_path / relative_path
    onnx_runner = OnnxRunner(onnx.load(model_path))
    for node in onnx_runner.graph_nodes:
      op_counter[node.op] += 1
      if node.op not in supported_ops:
        unsupported_ops[node.op].add(model_id)
    # drop the reference to the runner before the next model is loaded
    del onnx_runner
  ret["unsupported_ops"] = {k:list(v) for k, v in unsupported_ops.items()}
  ret["op_counter"] = op_counter.most_common()
  return ret
|
||||
|
||||
def debug_run(model_path, truncate, config, rtol, atol):
  """Validate a model, optionally cut off after node index `truncate`.

  With `truncate == -1` the model at `model_path` is validated as is. Otherwise the graph keeps
  only nodes `0..truncate`, its outputs are replaced by the outputs of the last kept node, and the
  result is saved to a temporary file that is validated instead.
  """
  if truncate == -1:
    run_huggingface_validate(model_path, config, rtol, atol)
    return

  model = onnx.load(model_path)
  kept_nodes = list(model.graph.node)[:truncate + 1]
  # the last kept node's outputs become the graph outputs so the intermediate values can be compared
  truncated_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in kept_nodes[-1].output]
  model.graph.ClearField("node")
  model.graph.node.extend(kept_nodes)
  model.graph.ClearField("output")
  model.graph.output.extend(truncated_outputs)
  with tempfile.NamedTemporaryFile(suffix=model_path.suffix) as tmp:
    onnx.save(model, tmp.name)
    run_huggingface_validate(tmp.name, config, rtol, atol)
|
||||
|
||||
if __name__ == "__main__":
  # CLI entry point with three modes:
  #   --check_ops / --validate operate on the models described by the input YAML (already downloaded)
  #   --debug downloads a repo (or one model of it) on the fly and validates it, optionally truncated
  parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
  # optional positional: --debug never reads the YAML, so it must be possible to omit it
  parser.add_argument("input", type=str, nargs="?", default=None,
                      help="Path to the input YAML configuration file containing model information.")
  parser.add_argument("--check_ops", action="store_true", default=False,
                      help="Check support for ONNX operations in models from the YAML file")
  parser.add_argument("--validate", action="store_true", default=False,
                      help="Validate correctness of models from the YAML file")
  parser.add_argument("--debug", type=str, default="",
                      help="""Validates without explicitly needing a YAML or models pre-installed.
                      provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
                      provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
                      """)
  parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
  args = parser.parse_args()

  if not (args.check_ops or args.validate or args.debug):
    parser.error("Please provide either --validate, --check_ops, or --debug.")
  if args.truncate != -1 and not args.debug:
    parser.error("--truncate and --debug should be used together for debugging")
  if (args.check_ops or args.validate) and args.input is None:
    parser.error("--validate and --check_ops require the input YAML file")

  if args.check_ops or args.validate:
    with open(args.input, 'r') as f:
      data = yaml.safe_load(f)
      assert all(repo["download_path"] is not None for repo in data["repositories"].values()), "please run `download_models.py` for this yaml"
      # "<repo id>/<file>" -> (snapshot root, path of the .onnx file relative to that root)
      model_paths = {
        model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
        for model_id, repo in data["repositories"].items()
        for model in repo["files"]
        if model["file"].endswith(".onnx")
      }

    if args.check_ops:
      pprint.pprint(retrieve_op_stats(model_paths))

    if args.validate:
      validate_repos(model_paths)

  if args.debug:
    from huggingface_hub import snapshot_download
    download_dir = Path(__file__).parent / "models"
    path:list[str] = args.debug.split("/")
    if len(path) == 2:
      # repo id
      # validates all onnx models inside repo
      repo_id = "/".join(path)
      # both patterns need the wildcard: a bare ".onnx_data" would only match a file with exactly that name
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data"], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      for onnx_model in root_path.rglob("*.onnx"):
        rtol, atol = get_tolerances(onnx_model.name)
        print(f"validating {onnx_model.relative_to(root_path)} with truncate={args.truncate}, {rtol=}, {atol=}")
        # NOTE(review): whole-repo mode always runs untruncated (-1) even though the message prints
        # args.truncate — confirm this is intended
        debug_run(onnx_model, -1, config, rtol, atol)
    else:
      # model id
      # only validate the specified onnx model
      onnx_model = path[-1]
      assert path[-1].endswith(".onnx")
      # the first two path components are the repo id, the remainder is the file inside the repo
      repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
      root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path], cache_dir=download_dir))
      snapshot_download(repo_id=repo_id, allow_patterns=["*config.json"], cache_dir=download_dir)
      config = get_config(root_path)
      rtol, atol = get_tolerances(onnx_model)
      print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
      debug_run(root_path / relative_path, args.truncate, config, rtol, atol)
|
||||
@@ -5,12 +5,43 @@ import onnx
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
def get_example_inputs(graph_inputs:dict[str, OnnxValue]):
|
||||
def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
|
||||
def _get_shape(onnx_shape: tuple[str|int]):
|
||||
shape = []
|
||||
for onnx_dim in onnx_shape:
|
||||
match onnx_dim:
|
||||
case int(): shape.append(onnx_dim)
|
||||
case "width" | "height":
|
||||
size = config.get("size", {})
|
||||
shape.append(size) if isinstance(size, int) else shape.append(size.get(onnx_dim, 224))
|
||||
case "sequence" | "sequence_length" | "decoder_sequence_length": shape.append(64)
|
||||
case "encoder_sequence_length": shape.append(config.get("nb_max_frames", 64))
|
||||
case "past_decoder_sequence_length" | "encoder_sequence_length_out": shape.append(64)
|
||||
case "encoder_sequence_length / 2": shape.append(32)
|
||||
case "batch_size": shape.append(1)
|
||||
case "num_channels": shape.append(config.get("in_channels", 3))
|
||||
case "num_channels_latent": shape.append(config.get("latent_channels", 4))
|
||||
case "height_latent" | "width_latent": shape.append(config.get("sample_size", 1024) // 8)
|
||||
case "feature_size": shape.append(config.get("num_mel_bins", 128))
|
||||
case _: shape.append(1)
|
||||
return shape
|
||||
def _get_value(name, shape, dtype):
|
||||
match name:
|
||||
case "input_ids":
|
||||
vocab_size = config.get("text_config", {}).get("vocab_size") or config.get("vocab_size", 32)
|
||||
val = np.random.randint(0, vocab_size-1, shape)
|
||||
case "attention_mask": val = np.random.randint(0, 2, size=shape)
|
||||
case "token_type_ids": val = np.random.randint(0, config.get("type_vocab_size", 2), shape)
|
||||
case "image_tensor": val = np.random.randint(0, 256, shape)
|
||||
case "task_id": return Tensor(0, dtype=dtype)
|
||||
case _: val = np.random.uniform(size=shape) * 8
|
||||
return Tensor(val.astype(_to_np_dtype(dtype))).realize()
|
||||
|
||||
ret: dict[str, Tensor] = {}
|
||||
for name, spec in graph_inputs.items():
|
||||
assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now"
|
||||
shape = tuple(dim if isinstance(dim, int) else 1 for dim in spec.shape)
|
||||
value = Tensor(np.random.uniform(size=shape).astype(_to_np_dtype(spec.dtype)) * 8).realize()
|
||||
shape = _get_shape(spec.shape)
|
||||
value = _get_value(name, shape, spec.dtype)
|
||||
ret.update({name:value})
|
||||
return ret
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# A001 Variable `input` is shadowing a Python builtin
|
||||
# A002 Function argument `input` is shadowing a Python builtin
|
||||
# A006 Lambda argument `input` is shadowing a Python builtin
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.helpers import getenv, prod
|
||||
import torch.lib
|
||||
TORCH_DEBUG = getenv("TORCH_DEBUG")
|
||||
@@ -12,9 +12,12 @@ from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
|
||||
|
||||
# https://pytorch.org/docs/stable/torch.compiler_ir.html
|
||||
|
||||
def _from_torch_device(device: torch.device): return f"{Device.DEFAULT}:{device.index or 0}"
|
||||
def _to_torch_device(device: str): return torch.device("tiny", int(device.partition(":")[2] or 0))
|
||||
|
||||
import torch.utils.cpp_extension
|
||||
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")])
|
||||
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype))
|
||||
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
|
||||
def unwrap(x:torch.Tensor) -> Tensor:
|
||||
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
|
||||
return mod.unwrap(x)
|
||||
@@ -24,6 +27,7 @@ class TinyBackend:
|
||||
def current_device(self): return 0
|
||||
def _is_in_bad_fork(self): return False
|
||||
def manual_seed_all(self, seed: int): Tensor.manual_seed(seed)
|
||||
def device_count(self): return getenv("GPUS", 1) # TODO: device count in tiny?
|
||||
torch.utils.rename_privateuse1_backend("tiny")
|
||||
torch._register_device_module("tiny", TinyBackend())
|
||||
torch.utils.generate_methods_for_privateuse1_backend()
|
||||
@@ -31,7 +35,7 @@ torch.utils.generate_methods_for_privateuse1_backend()
|
||||
# in place operations with views
|
||||
def is_view(self: torch.Tensor) -> bool: return getattr(self, "_base", None) is not None
|
||||
def realize_with_views(self: torch.Tensor, views: list[torch.Tensor]):
|
||||
assert self.device.type == "tiny"
|
||||
assert self.is_tiny
|
||||
self = unwrap(self)
|
||||
if not self.lazydata.st.contiguous: raise ValueError("base of view must be contiguous") # TODO: support?
|
||||
self.replace(self.clone().realize())
|
||||
@@ -63,18 +67,18 @@ def inplace_fn(outvars: str|list[str]):
|
||||
@torch.library.impl("aten::masked_select", "privateuseone")
|
||||
def masked_select(self, mask):
|
||||
# err, bad
|
||||
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
|
||||
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()], device=_from_torch_device(self.device)))
|
||||
|
||||
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
|
||||
@inplace_fn("self")
|
||||
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
|
||||
# TODO: move to tinygrad
|
||||
ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).tiny()
|
||||
ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device)
|
||||
return wrap(unwrap(self).assign(unwrap(ret)))
|
||||
|
||||
@torch.library.impl("aten::index.Tensor", "privateuseone")
|
||||
def index_tensor(x, y):
|
||||
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
|
||||
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device)
|
||||
|
||||
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
|
||||
def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
|
||||
@@ -121,13 +125,13 @@ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
|
||||
@torch.library.impl("aten::empty_strided", "privateuseone")
|
||||
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
|
||||
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
|
||||
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype))
|
||||
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype), device=_from_torch_device(device))
|
||||
return wrap(ret)
|
||||
|
||||
@torch.library.impl("aten::empty.memory_format", "privateuseone")
|
||||
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
|
||||
if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
|
||||
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())).contiguous()
|
||||
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()), device=_from_torch_device(device)).contiguous()
|
||||
return wrap(ret)
|
||||
|
||||
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
|
||||
@@ -170,7 +174,7 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
|
||||
def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
|
||||
if TORCH_DEBUG >= 1:
|
||||
print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
|
||||
grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0])
|
||||
grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0], device=_from_torch_device(weight.device))
|
||||
out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
|
||||
grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out)
|
||||
return tuple([wrap(grads.pop(0)) if m else None for m in output_mask])
|
||||
@@ -183,17 +187,19 @@ for i,pre in enumerate(["", "bi", "tri"]):
|
||||
|
||||
@torch.library.impl("aten::_copy_from", "privateuseone")
|
||||
def _copy_from(src: torch.Tensor, dest, non_blocking=False):
|
||||
realize = str(dest.device) == "tiny" and maybe_realize_storage(dest)
|
||||
realize = dest.is_tiny and maybe_realize_storage(dest)
|
||||
cast_dtype = _from_torch_dtype(dest.dtype)
|
||||
if str(src.device) == "tiny" and str(dest.device) == "tiny":
|
||||
unwrap(dest).assign(unwrap(src).cast(cast_dtype))
|
||||
if src.is_tiny and dest.is_tiny:
|
||||
to_device = _from_torch_device(dest.device)
|
||||
unwrap(dest).assign(unwrap(src).cast(cast_dtype).to(to_device))
|
||||
if realize: Tensor.realize(unwrap(dest))
|
||||
elif str(src.device) == "tiny" and str(dest.device) == "cpu":
|
||||
elif src.is_tiny and dest.is_cpu:
|
||||
# TODO: is there a better way?
|
||||
dest.resize_(src.numel()).resize_(src.shape)
|
||||
dest.copy_(torch.from_numpy(unwrap(src).cast(cast_dtype).numpy()))
|
||||
elif str(src.device) == "cpu" and str(dest.device) == "tiny":
|
||||
unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype))
|
||||
elif src.is_cpu and dest.is_tiny:
|
||||
to_device = _from_torch_device(dest.device)
|
||||
unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype).to(to_device))
|
||||
if realize: Tensor.realize(unwrap(dest))
|
||||
else:
|
||||
raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
|
||||
@@ -341,6 +347,7 @@ def wrap_out(f):
|
||||
assigned = f(*args, **kwargs)
|
||||
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
|
||||
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
|
||||
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
|
||||
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
|
||||
if out.lazydata.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
|
||||
return out.assign(assigned)
|
||||
@@ -462,7 +469,7 @@ def get_real_tinygrad_buffers():
|
||||
res = set()
|
||||
for mod in _torch_modules_with_buffers:
|
||||
for _,b in mod.named_buffers(recurse=False):
|
||||
if b is not None and str(b.device) == "tiny":
|
||||
if b is not None and b.is_tiny:
|
||||
res.add(unwrap(b))
|
||||
return res
|
||||
torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
|
||||
@@ -476,7 +483,7 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
|
||||
for state_dict in optimizer.state.values():
|
||||
for _, value in state_dict.items():
|
||||
if torch.is_tensor(value): tinygrad_tensors.append(value)
|
||||
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
|
||||
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if x.is_tiny]
|
||||
real_tinygrad_tensors += get_real_tinygrad_buffers()
|
||||
if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
|
||||
|
||||
|
||||
29
extra/torch_backend/test_multigpu.py
Normal file
29
extra/torch_backend/test_multigpu.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import unittest
|
||||
from tinygrad.helpers import getenv
|
||||
import torch
|
||||
import tinygrad.frontend.torch
|
||||
torch.set_default_device("tiny")
|
||||
import numpy as np
|
||||
|
||||
@unittest.skipIf(getenv("GPUS",1)<=1, "only single GPU")
class TestTorchBackendMultiGPU(unittest.TestCase):
  """Multi-device checks for the "tiny" torch backend; skipped unless GPUS > 1."""

  def test_transfer(self):
    # tensors created on different "tiny" devices must keep their values when moved to the other one
    a = torch.Tensor([[1,2],[3,4]]).to("tiny:0")
    b = torch.Tensor([[3,2],[1,0]]).to("tiny:1")
    self.assertNotEqual(a.device, b.device)
    np.testing.assert_array_equal(a.cpu(), a.to("tiny:1").cpu())
    # b already lives on tiny:1, so move it to tiny:0 to exercise the 1 -> 0 direction as well
    np.testing.assert_array_equal(b.cpu(), b.to("tiny:0").cpu())

  def test_basic_ops(self):
    # adding after moving either operand to the other device must give the same result
    a = torch.Tensor([[1,2],[3,4]]).to("tiny:0")
    b = torch.Tensor([[3,2],[1,0]]).to("tiny:1")
    c1 = a + b.to("tiny:0")
    c2 = b + a.to("tiny:1")
    np.testing.assert_array_equal(c1.cpu(), torch.full((2,2),4).cpu())
    np.testing.assert_array_equal(c1.cpu(), c2.cpu())
|
||||
|
||||
# TODO: torch.distributed functions
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -109,7 +109,7 @@ int register_hook() {
|
||||
}
|
||||
int temp_register_hook = register_hook();
|
||||
|
||||
at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) {
|
||||
at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceIndex device_index) {
|
||||
// TODO: we have to get the dtype and the shape from the tinygrad Tensor
|
||||
std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
|
||||
|
||||
@@ -127,7 +127,7 @@ at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) {
|
||||
return at::detail::make_tensor<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
|
||||
at::DispatchKeySet(at::DispatchKey::PrivateUse1),
|
||||
c10::scalarTypeToTypeMeta(dtype),
|
||||
at::Device(at::kPrivateUse1),
|
||||
at::Device(at::kPrivateUse1, device_index),
|
||||
std::make_shared<c10::SafePyObject>(py_obj.release().ptr(), getPyInterpreter()),
|
||||
sizes, strides);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import numpy as np
|
||||
from typing import List, Callable
|
||||
import torch
|
||||
import warnings
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, DEVECTORIZE
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -1490,6 +1490,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7)
|
||||
|
||||
@unittest.skipIf(not DEVECTORIZE, "broken without DEVECTORIZE. TODO: fix this")
|
||||
def test_logcumsumexp_numerical(self):
|
||||
helper_test_op(None, lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7, vals=[[0.0, 100.0]])
|
||||
|
||||
|
||||
@@ -2527,5 +2527,12 @@ class TestUOpBecome(unittest.TestCase):
|
||||
assert b.lazydata.is_realized
|
||||
assert b.lazydata.base.buffer._base is None
|
||||
|
||||
def test_setitem_offset(self):
|
||||
a = Tensor.full((16,), 0.).contiguous().realize()
|
||||
b = Tensor.full((16,), 1.).contiguous().realize()
|
||||
a_view = a[4:].reshape(3, 4).shrink(((0,2),(0,2))).reshape((4,))
|
||||
b.shrink(((0,4),)).assign(a_view).realize()
|
||||
self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -414,7 +414,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
|
||||
def test_tensor_dtype_errors(self):
|
||||
with self.assertRaises(AttributeError): Tensor([3], dtype="typo")
|
||||
with self.assertRaises(TypeError): Tensor([3], dtype=(dtypes.int,))
|
||||
with self.assertRaises(AttributeError): Tensor([3], dtype=(dtypes.int,))
|
||||
|
||||
def test_tensor_bytes(self):
|
||||
data = b"abc123"
|
||||
|
||||
@@ -1,86 +1,83 @@
|
||||
from typing import Optional, Any, Callable
|
||||
import functools, operator
|
||||
from typing import Optional, Any, Callable, cast
|
||||
import functools, operator, itertools
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
|
||||
from tinygrad.ops import graph_rewrite, GroupOp
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym
|
||||
from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
|
||||
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# ***** float4/image store handling *****
|
||||
|
||||
def fold_expanded(ex, buf):
|
||||
new_srcs = dedup(list(ex.src))
|
||||
old_new_srcs = new_srcs[:]
|
||||
is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
|
||||
|
||||
# TODO: get the device from the buffer somehow
|
||||
# NOTE: this can't be Device.DEFAULT because it opens devices
|
||||
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
|
||||
lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
|
||||
# ***** load/store grouping *****
|
||||
|
||||
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
# first, extract all the relevant offsets
|
||||
offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
|
||||
for i,s in enumerate(new_srcs):
|
||||
idx = s.src[0].src[1]
|
||||
if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
|
||||
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
|
||||
for i in range(vec.dtype.count):
|
||||
idx = vec.gep(i).simplify()
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
||||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
||||
else: root_src, arg = idx, 0
|
||||
# add gates for gated
|
||||
if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
|
||||
assert arg not in offsets_rootsrc[root_src], f"{offsets_rootsrc[root_src][arg]} != {i} with {len(s.src)} sources"
|
||||
offsets_rootsrc[root_src][arg] = i
|
||||
if mask is not None: root_src = (mask.gep(i).simplify(), root_src)
|
||||
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
|
||||
|
||||
# then rewrite everything we can
|
||||
used: set[tuple[UOp, UOp]] = set()
|
||||
# the buf.dtype is always a pointer
|
||||
ptrdtype = cast(PtrDType, buf.dtype)
|
||||
|
||||
# then rewrite everything we can into groups
|
||||
ret = []
|
||||
idxs: list[int|None] = [None]*vec.dtype.count
|
||||
global_offset = 0
|
||||
for rootsrc, offsets in offsets_rootsrc.items():
|
||||
for o in offsets:
|
||||
for fold_length in lengths:
|
||||
if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
|
||||
load_1 = new_srcs[offsets[o]]
|
||||
new_src = list(load_1.src)
|
||||
oidx = new_src[0].src[1]
|
||||
if oidx.divides(fold_length) is None: continue
|
||||
if is_image:
|
||||
# for images, we rewrite the index. it must evenly divide 4 from the above check
|
||||
new_src[0] = buf.index(
|
||||
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
|
||||
rootsrc[0] if isinstance(rootsrc, tuple) else None)
|
||||
else:
|
||||
# for non image, we upcast the index pointer
|
||||
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size, local=new_src[0].dtype.local))
|
||||
# generate the folded new_srcs
|
||||
if is_load:
|
||||
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
|
||||
for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
|
||||
else: # vectorize the store
|
||||
new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
|
||||
for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
|
||||
used.update((rootsrc,o+i) for i in range(fold_length))
|
||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||
for grp in grouped_offsets:
|
||||
# get the index offset for this element. using [0] is okay, because they are the same
|
||||
oidx = vec.gep(offsets[grp[0]][0])
|
||||
lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx))
|
||||
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
# set the idxs of the output
|
||||
for i,g in enumerate(grp):
|
||||
for oo in offsets[g]: idxs[oo] = global_offset+i
|
||||
# add this lidx to the CAT
|
||||
ret.append(lidx)
|
||||
global_offset += len(grp)
|
||||
assert None not in idxs, f"some idxs are missing {idxs}"
|
||||
# this base thing is for image, we want the CAT to be a normal pointer
|
||||
return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs)))
|
||||
|
||||
# dedup expand for LOAD
|
||||
if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
|
||||
# remove Nones for STORE
|
||||
return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
|
||||
def cat_after_store(cat:UOp, data:UOp):
  """Replace a STORE addressed through a CAT with one STORE per CAT source.

  Each source of `cat` receives the consecutive run of `data` lanes whose length is
  that source's dtype.count. The resulting stores are returned under a single sink.
  """
  # TODO: this is written in many places
  stores, start = [], 0
  for part in cat.src:
    stop = start + part.dtype.count
    stores.append(part.store(data.gep(tuple(range(start, stop)))))
    start = stop
  return UOp.sink(stores[0], *stores[1:])
|
||||
|
||||
def fix_unfoldable_image_load(load:UOp, buf:UOp):
|
||||
if not isinstance(buf.dtype, ImageDType) or (oidx:=load.src[0].src[1]).dtype.count == 2: return None
|
||||
id4 = oidx % 4
|
||||
new_src = list(load.src)
|
||||
# TODO: copied logic from above
|
||||
new_src[0] = load.src[0].src[0].index(
|
||||
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
|
||||
load.src[0].src[2] if len(load.src[0].src) == 3 else None)
|
||||
vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
|
||||
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
|
||||
def gep_on_store(gep:UOp, st:UOp):
  """Rewrite STORE(GEP(x), st) as STORE(x, st.gep(inverted lanes)).

  The GEP on the address is removed and the stored data is reordered with the
  inverse of gep.arg instead.
  """
  # NOTE: we need to invert the gep here, but it may be an expanding gep
  # fake argsort: map each selected lane to its position in gep.arg (a repeated lane keeps
  # its last position). TODO: handle duplicates
  position_of = {lane: pos for pos, lane in enumerate(gep.arg)}
  inverse = tuple(position_of[lane] for lane in sorted(position_of))
  return UOp(Ops.STORE, src=(gep.src[0], st.gep(inverse)))
|
||||
|
||||
buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
|
||||
float4_folding = PatternMatcher([
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
|
||||
(UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
|
||||
load_store_folding = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"),
|
||||
UPat.var("mask"))), expand_index),
|
||||
# GEP after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
|
||||
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
|
||||
# GEP on data of STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
|
||||
# put CAT after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True),
|
||||
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||
# put CAT after STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store),
|
||||
])
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
@@ -140,6 +137,64 @@ def get_late_rewrite_patterns(ops, force_transcendental=False):
|
||||
if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
# *** correct load/store ***
|
||||
|
||||
def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  """Split a vectorized LOAD/STORE into pieces of allowed widths.

  `ls` is the LOAD/STORE, `idx` the INDEX underneath the CAST used as its address and
  `ctx` the renderer (may be None). Candidate widths are tried widest first; a width is
  used at the current offset only if it fits in the remaining lanes and the simplified
  index divides by it. Width 1 is always a candidate.

  Returns None when the access is scalar or ends up as a single piece, otherwise a CAT of
  the pieces with the dtype of `ls`.
  """
  total = ls.src[0].dtype.count
  if total == 1: return None
  buf = idx.src[0]

  # candidate fold widths: images are always 4, float/half buffers depend on the renderer
  widths: list[int] = []
  if isinstance(buf.dtype, ImageDType):
    widths = [4]
  elif (buf.dtype.base == dtypes.float or buf.dtype.base == dtypes.half) and ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8"): widths = [8,4,2]
    elif AMX: widths = [16,8,4,2]
    else: widths = [4,2]
  widths.append(1)  # worst case, it's not folded

  ptrdtype = cast(PtrDType, buf.dtype)
  gate = idx.src[2] if len(idx.src) > 2 else None
  pieces = []
  done = 0
  while done < total:
    for width in widths:
      if done + width > total: continue
      offset_idx = idx.src[1] + done
      if offset_idx.simplify().divides(width) is None: continue
      addr = buf.index(offset_idx, gate)
      if width > 1: addr = addr.cast(ptrdtype.base.vec(width).ptr(size=ptrdtype.size, local=ptrdtype.local))
      if ls.op is Ops.STORE:
        # a store also needs the matching slice of the stored data
        lanes = tuple(range(done, done + width))
        pieces.append(ls.replace(src=(addr, ls.src[1].gep(lanes)) + ls.src[2:]))
      else:
        pieces.append(ls.replace(src=(addr,) + ls.src[1:], dtype=ls.dtype.scalar().vec(width)))
      done += width
      break
  if len(pieces) == 1: return None
  return UOp(Ops.CAT, ls.dtype, tuple(pieces))
|
||||
|
||||
def image_fixup(ls:UOp):
  """Rewrite the index of an image LOAD/STORE from a flat offset to a 2-wide int coordinate.

  Two cases are handled: an address that is a CAST of an image INDEX (the CAST is dropped),
  and an un-cast image INDEX whose index is not already an int vec(2) (loads only). In the
  second case a 4-wide load is issued and the lane given by `offset % 4` is selected.
  Returns None when neither case applies.
  """
  def with_image_coord(index:UOp, image_dtype) -> UOp:
    # flat offset o -> ((o // 4) % shape[1], o // (4 * shape[1])), keeping any extra INDEX sources
    flat, width = index.src[1], image_dtype.shape[1]
    coord = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((flat // 4) % width, (flat // (4*width))))
    return index.replace(src=(index.src[0], coord) + index.src[2:])

  addr = ls.src[0]

  # normal image load or store, with the CAST from expand_index
  if addr.op is Ops.CAST and isinstance(image_dtype:=addr.src[0].dtype, ImageDType):
    assert addr.dtype.count == 4, "image must be casted to 4"
    return ls.replace(src=(with_image_coord(addr.src[0], image_dtype),) + ls.src[1:])

  # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
  if isinstance(image_dtype:=addr.dtype, ImageDType) and addr.src[1].dtype != dtypes.int.vec(2):
    assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
    lane = addr.src[1] % 4
    wide_load = ls.replace(dtype=ls.dtype.vec(4), src=(with_image_coord(addr, image_dtype),) + ls.src[1:])
    # start from NaN and pick element i of the wide load where lane == i
    picked = ls.const_like(float('nan'))
    for i in range(4): picked = lane.ne(i).where(picked, wide_load.gep(i))
    return picked

  return None
|
||||
|
||||
correct_load_store = PatternMatcher([
|
||||
# split LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.CAST, src=(UPat(Ops.INDEX, name="idx"),)),), name="ls", allow_any_len=True), split_load_store),
|
||||
# image indexing, including unfoldable images
|
||||
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
|
||||
])
|
||||
|
||||
# *** uop expander ***
|
||||
|
||||
@@ -160,13 +215,6 @@ def no_vectorized_alu(alu):
|
||||
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
|
||||
return UOp(Ops.VECTORIZE, alu.dtype, alus)
|
||||
|
||||
def no_vectorized_load_store(ls:UOp):
|
||||
idx = ls.src[0]
|
||||
assert isinstance(idx.dtype, PtrDType)
|
||||
if idx.dtype.v == 1: return None
|
||||
tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
|
||||
return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
|
||||
|
||||
def no_vectorized_acc(acc:UOp):
|
||||
if acc.dtype.count == 1: return None
|
||||
alus = tuple(UOp(acc.op, acc.dtype.scalar(),
|
||||
@@ -175,16 +223,9 @@ def no_vectorized_acc(acc:UOp):
|
||||
|
||||
devectorize = PatternMatcher([
|
||||
# no ALU on vectorized dtypes
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu),
|
||||
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
|
||||
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
|
||||
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
|
||||
])
|
||||
|
||||
devectorize_load_store = PatternMatcher([
|
||||
# TODO: add vectorized support to transcendental
|
||||
(UPat((Ops.INDEX), name="alu"), no_vectorized_alu),
|
||||
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
|
||||
])
|
||||
|
||||
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
|
||||
@@ -193,8 +234,6 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|N
|
||||
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
|
||||
|
||||
load_store_indexing = PatternMatcher([
|
||||
# late fixup of unfoldable image loads
|
||||
(UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
# simplify valid
|
||||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
# image load valid idx simplification
|
||||
@@ -231,12 +270,10 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
if DEVECTORIZE:
|
||||
# devectorize + load_store_indexing
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
|
||||
else:
|
||||
# new devectorize only for load/store
|
||||
sink = graph_rewrite(sink, sym+devectorize_load_store)
|
||||
# devectorize is optional
|
||||
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
|
||||
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
|
||||
# optional pre matcher
|
||||
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
|
||||
@@ -230,6 +230,17 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) else None),
|
||||
])
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
@@ -399,17 +410,6 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# VECTORIZE void is SINK
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) else None),
|
||||
# push some GEPs through WMMAs
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
|
||||
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
|
||||
|
||||
@@ -156,7 +156,7 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
|
||||
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
|
||||
|
||||
DTypeLike = Union[str, DType]
|
||||
def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
|
||||
def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower())
|
||||
|
||||
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
|
||||
# we don't support weak type and complex type
|
||||
@@ -180,7 +180,7 @@ def sum_acc_dtype(dt:DType):
|
||||
# default acc dtype for sum
|
||||
if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
|
||||
if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
|
||||
return least_upper_dtype(dt, dtypes.float)
|
||||
return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
|
||||
|
||||
def truncate_fp16(x):
|
||||
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
|
||||
|
||||
@@ -106,6 +106,9 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# put CAST after expanding BUFFER
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
|
||||
and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
|
||||
# remove CONST/BIND/VIEW from SINK
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
|
||||
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
|
||||
])
|
||||
|
||||
# **** UOp realization
|
||||
@@ -259,11 +262,12 @@ create_kernels = merge_views+PatternMatcher([
|
||||
lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
|
||||
# walk back the local graph until we reach a buffer/assign parent
|
||||
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
|
||||
# remove CONST/BIND from SINK
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
|
||||
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
|
||||
# remove downstream reshapes from SINK
|
||||
(UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
|
||||
])
|
||||
|
||||
DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK}
|
||||
|
||||
# **** fix kernel AST
|
||||
|
||||
# ** create buffer ops + enumerate buffers
|
||||
@@ -273,8 +277,18 @@ add_buffer_ops = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
|
||||
# STORE (except for COPY/BUFFER_VIEW)
|
||||
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
|
||||
# partial assign can store to a non-contiguous ShapeTracker
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
|
||||
# otherwise the store is contiguous
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
|
||||
# if the last child is a VIEW we merge the ShapeTrackers and store the base
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
|
||||
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
])
|
||||
|
||||
# ** push views to buffer ops
|
||||
@@ -290,23 +304,24 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp):
|
||||
strides = strides_for_shape(rshape)
|
||||
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
|
||||
# update input_st and axis
|
||||
# create a new reduceop for the swizzled input
|
||||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
|
||||
return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
||||
return UOp(Ops.REDUCE_AXIS, r.dtype, (apply_swizzle(src.view(src.arg+new_input_st if src.op is Ops.VIEW else new_input_st)),),
|
||||
(r.arg[0], new_axis)).view(ShapeTracker.from_shape(st.shape))
|
||||
|
||||
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
  """Re-express reduce `r` over view `v` of `src` as a reduce of `src` followed by a view of r's shape.

  Requires `v` to be contiguous and the same size as `src`.
  """
  assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
  # the axes to reduce are the positions where src's shape and r's shape disagree
  reduce_axes = tuple(axis for axis,(in_dim,out_dim) in enumerate(zip(src.shape, r.shape)) if in_dim != out_dim)
  reduced = src.r(r.arg[0], reduce_axes)
  return reduced.view(ShapeTracker.from_shape(r.shape))
|
||||
|
||||
def elementwise_view_right(root:UOp) -> UOp|None:
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW]): return None
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None
|
||||
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
|
||||
# place view after applying the elementwise op
|
||||
new_shape = swizzles[0].base.shape
|
||||
ret = root.replace(src=tuple(x.base if x.base.shape == new_shape else apply_swizzle(x.view(ShapeTracker.from_shape(new_shape))) for x in root.src))
|
||||
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
|
||||
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(x.arg+new_st) if x.op is Ops.VIEW else x.view(new_st)) for x in root.src]
|
||||
# reshape to match downstream shapes
|
||||
return ret.reshape(root.shape)
|
||||
return root.replace(src=tuple(new_src)).reshape(root.shape)
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
@@ -315,17 +330,12 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
|
||||
# push VIEW to children
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
|
||||
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
|
||||
# STORE is the last child, so we just merge the ShapeTrackers and store the base
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
# apply view after reduceops
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
|
||||
(UPat(GroupOp.All-GroupOp.Buffer, name="root"), elementwise_view_right),
|
||||
(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
|
||||
# double reduce op collapses to a single reduce op
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
@@ -359,9 +369,6 @@ def check_load_st(glbl:UOp, view:UOp):
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# BIND in shapetracker becomes DEFINE_VAR
|
||||
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
|
||||
# remove CONTIGUOUS/DEVICE
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
|
||||
# remove unmasked valid
|
||||
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
|
||||
# no ImageDType after load
|
||||
@@ -378,15 +385,15 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
|
||||
if s.op is Ops.ASSIGN:
|
||||
for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
|
||||
ast = k.arg.ast.substitute(parents_rep)
|
||||
# add buffer ops
|
||||
ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
|
||||
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
|
||||
# unbind_vars + push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
|
||||
# fix_kernel_ops
|
||||
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
|
||||
# add buffer ops
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
|
||||
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
|
||||
# create subbuffer (TODO: this does not belong here)
|
||||
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
# fix_kernel_ops
|
||||
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
|
||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
|
||||
@@ -191,8 +191,8 @@ class ClangRenderer(CStyleLanguage):
|
||||
code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
|
||||
Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
|
||||
# LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
|
||||
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
|
||||
CStyleLanguage.extra_matcher
|
||||
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
|
||||
(UPat(Ops.SQRT, name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher
|
||||
|
||||
if sys.platform == 'win32':
|
||||
kernel_prefix = "__attribute__((ms_abi)) "
|
||||
|
||||
@@ -163,6 +163,9 @@ class AM_GFX(AM_IP):
|
||||
self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000)
|
||||
self.adev.regRLC_SRM_CNTL.update(srm_enable=1, auto_incr_addr=1)
|
||||
|
||||
self.adev.regS2A_DOORBELL_ENTRY_0_CTRL.write(s2a_doorbell_port0_enable=1, s2a_doorbell_port0_awid=0x3, s2a_doorbell_port0_awaddr_31_28_value=0x3)
|
||||
self.adev.regS2A_DOORBELL_ENTRY_3_CTRL.write(s2a_doorbell_port3_enable=1, s2a_doorbell_port3_awid=0x6, s2a_doorbell_port3_awaddr_31_28_value=0x3)
|
||||
|
||||
self.adev.regGRBM_CNTL.update(read_timeout=0xff)
|
||||
for i in range(0, 16):
|
||||
self._grbm_select(vmid=i)
|
||||
@@ -297,6 +300,9 @@ class AM_IH(AM_IP):
|
||||
for _, rwptr_vm, suf, ring_id in self.rings:
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").update(rb_enable=1, **({'enable_intr': 1} if ring_id == 0 else {}))
|
||||
|
||||
self.adev.regS2A_DOORBELL_ENTRY_1_CTRL.update(s2a_doorbell_port1_enable=1, s2a_doorbell_port1_awid=0x0, s2a_doorbell_port1_awaddr_31_28_value=0x0,
|
||||
s2a_doorbell_port1_range_offset=am.AMDGPU_NAVI10_DOORBELL_IH*2, s2a_doorbell_port1_range_size=2)
|
||||
|
||||
class AM_SDMA(AM_IP):
|
||||
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int):
|
||||
# Setup the ring
|
||||
@@ -320,6 +326,8 @@ class AM_SDMA(AM_IP):
|
||||
self.adev.reg(f"regSDMA{pipe}_UTCL1_PAGE").update(rd_l2_policy=0x2, wr_l2_policy=0x3, llc_noalloc=1) # rd=noa, wr=bypass
|
||||
self.adev.reg(f"regSDMA{pipe}_F32_CNTL").update(halt=0, th1_reset=0)
|
||||
self.adev.reg(f"regSDMA{pipe}_CNTL").update(ctxempty_int_enable=1, trap_enable=1)
|
||||
self.adev.regS2A_DOORBELL_ENTRY_2_CTRL.update(s2a_doorbell_port2_enable=1, s2a_doorbell_port2_awid=0xe, s2a_doorbell_port2_awaddr_31_28_value=0x3,
|
||||
s2a_doorbell_port2_range_offset=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0*2, s2a_doorbell_port2_range_size=4)
|
||||
|
||||
def fini(self):
|
||||
self.adev.regSDMA0_QUEUE0_RB_CNTL.update(rb_enable=0)
|
||||
|
||||
@@ -246,11 +246,12 @@ class HCQSignal(Generic[DeviceType]):
|
||||
|
||||
Args:
|
||||
value: The value to wait for.
|
||||
timeout: Maximum time to wait in milliseconds. Defaults to 10s.
|
||||
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
|
||||
"""
|
||||
start_time = int(time.perf_counter() * 1000)
|
||||
while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
|
||||
while (prev_value:=self.value) < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
|
||||
self._sleep(time_spent)
|
||||
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
|
||||
if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
||||
Reference in New Issue
Block a user