Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2025-01-21 13:06:04 -08:00
37 changed files with 578 additions and 384 deletions

View File

@@ -298,7 +298,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
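
The first step above shows the gate values before and after this commit: `ALLOWED_KERNEL_COUNT` goes from 208 to 209 and `ALLOWED_READ_IMAGE` from 2104 to 2105. As a hedged sketch (the real assertions live in `examples/openpilot/compile3.py` and may be shaped differently), such gates are typically enforced by reading the env vars through tinygrad's `getenv` and comparing against `GlobalCounters`:

```python
# sketch only: illustrative enforcement of the ALLOWED_* env vars above
from tinygrad.helpers import getenv, GlobalCounters

if (allowed := getenv("ALLOWED_KERNEL_COUNT", -1)) != -1:
  assert GlobalCounters.kernel_count <= allowed, \
    f"kernel count {GlobalCounters.kernel_count} exceeds ALLOWED_KERNEL_COUNT={allowed}"
```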

View File

@@ -27,7 +27,7 @@ AM binds compute queues directly to MEC (bypassing MES). Tinygrad uses only one
The GPU being passed can be in one of several states:
1. Not initialized
2. Initialized by AMDGPU
2. Initialized by amdgpu
3. Initialized by AM
The first and second states require a full GPU setup since their states are unknown. The second state also requires a mode1 reset to reinitialize all components.
@@ -36,4 +36,4 @@ The third state can be set up partially to optimize boot time. In this case, onl
### VM Management
Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512TB of virtual addresses. All AM devices are located in one virtual address space.
Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512GB of virtual addresses. All AM devices are located in one virtual address space.
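
The hunk above corrects the addressable range of the 3-level page directory. Assuming the usual 512-entry (9-bit) table levels and 4KB (12-bit) base pages, the arithmetic gives 512GB, not 512TB:

```python
# sanity check for the 512GB figure, assuming 9 address bits per level and 4KB pages
levels, bits_per_level, page_bits = 3, 9, 12
va_bytes = 1 << (levels * bits_per_level + page_bits)  # 2**39 bytes
assert va_bytes == 512 * 2**30  # 512GB (2**39), not 512TB
```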

View File

@@ -1,14 +1,12 @@
import sys, onnx, time
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
from tinygrad.tensor import _from_np_dtype
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
def load_onnx_model(fn):
onnx_file = fetch(fn)
onnx_model = onnx.load(onnx_file)
Tensor.no_grad = True
Tensor.training = False
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
# find preinitted tensors and ignore them
initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer}

View File

@@ -8,7 +8,7 @@ import numpy as np
import subprocess
import tensorflow as tf
import tf2onnx
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor
from extra.export_model import export_model_clang, compile_net, jit_model
@@ -25,7 +25,7 @@ class TinyOnnx:
def __init__(self, keras_model):
input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
self.run_onnx = get_run_onnx(onnx_model)
self.run_onnx = OnnxRunner(onnx_model)
def forward(self, x):
return self.run_onnx({"x": x}, debug=False)['predictions']

View File

@@ -848,9 +848,9 @@ def train_bert():
model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)
parameters = get_parameters(model)
for p in parameters:
p.to_(GPUS)
# ** Log run config **
for key, value in config.items(): print(f'HParam: "{key}": {value}')

View File

@@ -12,17 +12,14 @@ from tinygrad.engine.realize import CompiledRunner
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
from extra.onnx import get_run_onnx # TODO: port to main tinygrad
from extra.onnx import OnnxRunner # TODO: port to main tinygrad
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
def compile(onnx_file):
onnx_model = onnx.load(onnx_file)
Tensor.no_grad = True
Tensor.training = False
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
print("loaded model")
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}

View File

@@ -3,7 +3,7 @@ import os
from ultralytics import YOLO
import onnx
from pathlib import Path
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor
os.chdir("/tmp")
@@ -14,5 +14,5 @@ onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
# TODO: move get example inputs to onnx
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
print(input_shapes)
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True)

View File

@@ -1,53 +1,42 @@
from typing import Callable, Any, Sequence
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes
import importlib, functools, dataclasses
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.dtype import DType, ConstType
from tinygrad.dtype import DType, ConstType, dtypes
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper
from google.protobuf.json_format import MessageToDict
cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
if t.dtype is dtypes.uint8: return t.data().tobytes()
if 0 in t.shape: return []
return t.tolist()
# ***** protobuf parsing ******
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper
import numpy as np
# Tensor -> python value cache for parameters
def to_python_const(t) -> list[ConstType]|ConstType|bytes:
if not isinstance(t, Tensor): return t
global cache_misses
ret = _cached_to_python_const(t)
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
print(f"Cache miss for {t}")
cache_misses = info.misses
return ret
# TODO: use real float16
# src: onnx/mapping.py
DTYPE_MAP: dict[int, DType] = {
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
}
def dtype_parse(onnx_dtype: int) -> DType:
if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float
supported: dict[int, DType] = {
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
}
unsupported = {
TensorProto.UNDEFINED, TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128, TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E4M3FNUZ,
TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4
}
if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float
# src: onnx/onnx_ml_pb2.pyi
ATTRIBUTE_MAP: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
def attribute_parse(onnx_attribute: AttributeProto):
if onnx_attribute.type not in ATTRIBUTE_MAP:
supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
unsupported = {
AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS,
AttributeProto.GRAPHS, AttributeProto.SPARSE_TENSORS, AttributeProto.TYPE_PROTOS
}
if onnx_attribute.type in unsupported:
raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)
return supported[onnx_attribute.type](onnx_attribute)
def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
@@ -62,116 +51,137 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
return Tensor(np_buffer, dtype=dtype)
return Tensor(None)
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
# model initialization data
model_tensors = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
model_expected_inputs = {inp.name:inp for inp in onnx_model.graph.input if inp.name not in model_tensors}
model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}
def type_parse(onnx_type: TypeProto):
elem_type = onnx_type
if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
if elem_type.HasField("tensor_type"):
shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
dtype = dtype_parse(elem_type.tensor_type.elem_type)
return OnnxValue(shape, dtype, is_optional, is_sequence)
raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
# model descriptions
# TODO: need a better way of controlling training vs non-training
is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
onnx_model_version = onnx_model.opset_import[0].version
# ***** onnx spec *****
@dataclasses.dataclass(frozen=True)
class OnnxValue:
shape: tuple[str|int]
dtype: DType
is_optional: bool
is_sequence: bool
# used to check the validity of user inputs against their dimension variables
variable_dims = {}
@dataclasses.dataclass(frozen=True)
class OnnxNode:
num: int
op: str
inputs: tuple[str]
outputs: tuple[str]
opts: dict[str, Any]
# mapping from onnx ops to tensor.py ops
tensor_methods = {
op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
"Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")
}
# ***** python const *****
required_input_python_consts: dict[str, tuple[int, ...]] = {
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}
# these values are expected to be python consts
required_input_python_consts: dict[str, tuple[int, ...]] = {
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}
cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
if t.dtype is dtypes.uint8: return t.data().tobytes()
if 0 in t.shape: return []
return t.tolist()
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
# parses and validates inputs based on their shape and dtype specified by model
def prepare_input(user_input:Any, model_input:ValueInfoProto):
type_proto = model_input.type
if type_proto.HasField("optional_type"):
if user_input is None: return None
type_proto = type_proto.optional_type.elem_type
if type_proto.HasField("sequence_type"):
if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
# TODO: need true float16 for dtype checking
# if not all(t.dtype is dtype for t in sequence):
# raise RuntimeError(f"{model_input.name} has dtype mismatch for sequence type. Expected {dtype}, received {tensor.dtype}.")
# Tensor -> python value cache for parameters
def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t
global cache_misses
ret = _cached_to_python_const(t)
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
print(f"Cache miss for {t}")
cache_misses = info.misses
return ret
# ***** runner ******
debug = int(getenv("DEBUGONNX", "0"))
limit = int(getenv("ONNXLIMIT", "-1"))
class OnnxRunner:
def __init__(self, model: ModelProto):
# parse model protobuf
self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
Tensor.training = True if self.is_training else False
Tensor.no_grad = False if self.is_training else True
self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer}
self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
for num,n in enumerate(model.graph.node))
self.opset_version = model.opset_import[0].version
self.variable_dims: dict[str, int] = {}
# TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import
# TODO: clean up opset stuff after moving extra.onnx_ops here
self.onnx_ops_module = importlib.import_module('extra.onnx_ops')
self.onnx_ops = {
**{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
"Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
"Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
}
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
if spec.is_optional and value is None: return None
# TODO: need true float16 for dtype checking
if spec.is_sequence:
if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type")
sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous")
return sequence
if type_proto.HasField("tensor_type"):
dtype = dtype_parse(type_proto.tensor_type.elem_type)
tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
# TODO: need true float16 for dtype checking
# if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} has mismatch for dtype. Expected {dtype}, received {tensor.dtype}.")
for dim, onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
dim_param, dim_value = onnx_dim.dim_param, onnx_dim.dim_value
user_dim_input = tensor.shape[dim]
if dim_param: dim_value = variable_dims[dim_param] if dim_param in variable_dims else variable_dims.setdefault(dim_param, user_dim_input)
if user_dim_input != dim_value:
raise RuntimeError(f"{model_input.name} has mismatch for dim={dim_param or dim}. Expected {dim_value}, received {user_dim_input}.")
return tensor
type_field_names = [field.name for field,_ in type_proto.ListFields()]
raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
if isinstance(onnx_dim, str):
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
return tensor
def run_onnx(inputs={}, debug=0):
debug = getenv("DEBUGONNX") or debug
if debug >= 3: print("Model initialization data:\n" + "\n".join(f"\t{i.name} - {model_tensors[i.name]}" for i in onnx_model.graph.initializer))
def _dispatch_op(self, op, inps, opts):
if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts)
if hasattr(self.onnx_ops_module, op):
fxn = getattr(self.onnx_ops_module, op)
if isinstance(fxn, dict):
for k in sorted(fxn.keys()):
if k <= self.opset_version:
real_fxn = fxn[k]
else: real_fxn = fxn
return real_fxn(*inps, **opts)
raise NotImplementedError(f"{op=} not supported")
if debug >= 1: print("Model input:")
for name, value_info in model_expected_inputs.items():
def __call__(self, inputs:dict[str, Any], debug=debug):
for name, input_spec in self.graph_inputs.items():
if name not in inputs: raise RuntimeError(f"Please provide input data for {name}")
model_tensors[name] = prepare_input(inputs[name], value_info)
if debug >= 1: print(f"\t{name} - {model_tensors[name]}")
if debug >= 2: print(f"\t\t{MessageToDict(value_info.type)}")
self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)
for num,n in enumerate(onnx_model.graph.node):
inp_tensors = [model_tensors.get(x) for x in n.input]
required_consts = required_input_python_consts.get(n.op_type, ())
inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)]
opt = model_attributes[num]
for node in self.graph_nodes:
inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)]
opts = node.opts
if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}")
if debug >= 3:
print("\tinputs:")
print("\n".join(f"\t\t{x} - {t!r}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp))))
# provide additional opts
if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values
# provide additional arguments
if n.op_type == "Split" and 'num_outputs' not in opt: opt['num_outputs'] = len(n.output)
if n.op_type == "Gradient": opt['intermediate_tensors'] = model_tensors
if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}")
if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
ret = self._dispatch_op(node.op, inps, opts)
ret = ret if isinstance(ret, tuple) else (ret,)
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret)))
# run op
if n.op_type in tensor_methods: ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
elif hasattr(onnx_ops, n.op_type):
fxn = getattr(onnx_ops, n.op_type)
if isinstance(fxn, dict):
for k in sorted(fxn.keys()):
if k <= onnx_model_version:
real_fxn = fxn[k]
else:
real_fxn = fxn
ret = real_fxn(*inp, **opt)
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise NotImplementedError(f"op_type {n.op_type} not supported")
self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True)))
# finalization after running the op
if not isinstance(ret, tuple): ret = (ret, )
if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}")
for i in range(len(n.output)): model_tensors[n.output[i]] = ret[i]
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))
if num == ONNXLIMIT: return {name:model_tensors[name] for name in n.output}
return {x.name:model_tensors[x.name] for x in onnx_model.graph.output}
return run_onnx
if node.num == limit:
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
return {name:self.graph_values[name] for name in node.outputs}
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
return {name:self.graph_values[name] for name in self.graph_outputs}
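
Every call site in this commit swaps the removed `get_run_onnx` factory for the new `OnnxRunner` class; the runner object is then called exactly like the old closure. A minimal sketch of the new API, with the input name and shape borrowed from the yolov8 example earlier in this diff (the model path is illustrative):

```python
import onnx
from extra.onnx import OnnxRunner
from tinygrad import Tensor

onnx_model = onnx.load("yolov8n-seg.onnx")  # illustrative path
run_onnx = OnnxRunner(onnx_model)           # was: run_onnx = get_run_onnx(onnx_model)
# note debug is now an integer verbosity level (see the debug=True -> debug=2 change below)
outputs = run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=2)
```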

View File

@@ -3,7 +3,7 @@ from typing import cast, Literal
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten, make_tuple
from extra.onnx import dtype_parse, to_python_const
from extra.onnx import dtype_parse, _cached_to_python_const
import numpy as np
# **************** Free Ops ****************
@@ -282,7 +282,7 @@ def Gather(x:Tensor, indices:Tensor, axis:int=0):
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)] # type: ignore
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
@@ -575,12 +575,9 @@ from tinygrad.nn.optim import SGD
def onnx_training(input_group_size):
def _decorator(func):
def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
old_training = Tensor.training
Tensor.training = True
R = R.detach()
groups = len(inputs) // input_group_size
ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
Tensor.training = old_training
return tuple(flatten(zip(*ret)))
return __wrapper
return _decorator

View File

@@ -2,7 +2,7 @@ import time, sys, hashlib
from pathlib import Path
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad import Tensor, dtypes, TinyJit
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
from tinygrad.tensor import _from_np_dtype
@@ -11,11 +11,8 @@ import numpy as np
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
if __name__ == "__main__":
Tensor.no_grad = True
Tensor.training = False
onnx_model = onnx.load(onnx_path := fetch(OPENPILOT_MODEL))
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
Tensor.manual_seed(100)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}

View File

@@ -0,0 +1,39 @@
import subprocess
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_test(i, full_run=False):
  print(f"\rRunning iteration {i}...", end=" ", flush=True)
  p = subprocess.Popen(['python3', 'test/test_tiny.py', 'TestTiny.test_plus'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  if not full_run:
    # kill the test at a random point in its first 1.2 seconds to stress interrupted runs
    time.sleep(random.uniform(0, 1200) / 1000)
    p.kill()
    _, stderr = p.communicate()
  else:
    _, stderr = p.communicate()
  if full_run:
    stderr_text = stderr.decode()
    print(stderr_text)
    assert "Ran 1 test in" in stderr_text and "OK" in stderr_text

max_workers = 4
with ThreadPoolExecutor(max_workers=max_workers) as executor:
  futures = []
  for i in range(1000000):
    if i % 100 == 0:
      # every 100th iteration: drain outstanding futures, then do one full (unkilled) run
      for future in as_completed(futures):
        try: future.result()
        except Exception as e:
          print(f"\nError in iteration: {e}")
      futures = []
      run_test(i, True)
    else:
      future = executor.submit(run_test, i, False)
      futures.append(future)
    if len(futures) > max_workers * 2: futures = [f for f in futures if not f.done()]

View File

@@ -6,7 +6,7 @@ import onnx
from onnx.helper import tensor_dtype_to_np_dtype
import onnxruntime as ort
from onnx2torch import convert
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.helpers import OSX, DEBUG, fetch
from tinygrad import Tensor, Device
from tinygrad.device import CompileError
@@ -65,7 +65,7 @@ def benchmark_model(m, devices, validate_outs=False):
try:
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
tinygrad_model = OnnxRunner(onnx_model)
benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()})
from tinygrad.engine.jit import TinyJit
@@ -115,7 +115,7 @@ def benchmark_model(m, devices, validate_outs=False):
rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
tinygrad_model = OnnxRunner(onnx_model)
tinygrad_out = tinygrad_model(inputs)
ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"])

View File

@@ -1,5 +1,6 @@
import unittest
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext
from tinygrad.helpers import mv_address
class FakeGMC:
def __init__(self): self.vm_base = 0x0
@@ -19,6 +20,8 @@ class FakeAM:
self.gmc = FakeGMC()
self.mm = AMMemoryManager(self, vram_size=4 << 30)
self.is_booting = False
def paddr2cpu(self, paddr:int) -> int: return paddr + mv_address(self.vram)
def paddr2mc(self, paddr:int) -> int: return paddr
# * PTE format:
# * 63:59 reserved

View File

@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
# pip3 install tabulate
pytest_plugins = 'onnx.backend.test.report',
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
class TinygradModel(BackendRep):
def __init__(self, run_onnx, input_names):
@@ -20,7 +20,7 @@ class TinygradModel(BackendRep):
def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
real_inputs = dict(zip(self.input_names, inputs))
ret = self.fxn(real_inputs, debug=True)
ret = self.fxn(real_inputs, debug=2)
return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values())
class TinygradBackend(Backend):
@@ -30,7 +30,7 @@ class TinygradBackend(Backend):
input_initializer = [x.name for x in model.graph.initializer]
net_feed_input = [x for x in input_all if x not in input_initializer]
print("prepare", cls, device, net_feed_input)
run_onnx = get_run_onnx(model)
run_onnx = OnnxRunner(model)
return TinygradModel(run_onnx, net_feed_input)
@classmethod

View File

@@ -7,7 +7,7 @@ try:
import onnx
except ModuleNotFoundError:
raise unittest.SkipTest("onnx not installed, skipping onnx test")
from extra.onnx import get_run_onnx
from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI, fetch, temp
@@ -26,7 +26,7 @@ np.random.seed(1337)
class TestOnnxModel(unittest.TestCase):
def test_benchmark_openpilot_model(self):
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
def get_inputs():
np_inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
@@ -70,7 +70,7 @@ class TestOnnxModel(unittest.TestCase):
def test_openpilot_model(self):
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
print("got run_onnx")
inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
@@ -124,7 +124,7 @@ class TestOnnxModel(unittest.TestCase):
onnx_model = onnx.load(fn)
print("onnx loaded")
from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
run_onnx = get_run_onnx(onnx_model)
run_onnx = OnnxRunner(onnx_model)
def run(img):
inputs = {input_name: preprocess(img, new=input_new)}

View File

@@ -166,7 +166,7 @@ class TestIndexing(unittest.TestCase):
GlobalCounters.reset()
z = emb(x).realize()
self.assertLessEqual(GlobalCounters.global_ops, op_limit)
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(GlobalCounters.kernel_count, 3)
if getenv("CHECK", 1):
import torch
with torch.no_grad():

View File

@@ -220,7 +220,9 @@ class TestMultiConstFolding(unittest.TestCase):
t = Tensor.arange(16).float().realize().to(ds)
# non const folding case creates one ast on each shard
_check_ast_count(4, t + 1)
# NOTE: there are extra contiguous kernels here since it's realizing both the CONTIGUOUS and its parent COPY
# why does multi call contiguous on a COPY?
_check_ast_count(7, t + 1)
_check_ast_count(4, 1 + t)
_check_ast_count(4, t * 2)
_check_ast_count(4, 2 * t)

View File

@@ -113,6 +113,8 @@ class TestImageDType(unittest.TestCase):
assert it.lazydata.base.realized._buf != b1
# issue caused by: don't realize image to image casts. this is part of a larger problem
#@unittest.expectedFailure
# update: passing after tensor_map
def test_lil_model(self):
with Context(IMAGE=2):
x = Tensor.zeros(1, 1)
@@ -121,7 +123,10 @@ class TestImageDType(unittest.TestCase):
loss = x.image_dot(w1).image_dot(w2).float().max()
loss.backward()
sched = unwrap(w1.grad).schedule()
self.assertEqual(len(sched), 9)
# NOTE: the w1 grad must realize to a separate kernel
assert w1.grad.lazydata.is_realized, f"never realized {w1.grad}"
self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32)
self.assertEqual(len(sched), 10)
for s,ei in zip(sched, lower_schedule(sched[:])):
ei.run()
if s.outputs[0].dtype == dtypes.float:

View File

@@ -318,6 +318,7 @@ class TestJit(unittest.TestCase):
assert len(res3) == 10, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"
@unittest.expectedFailure # requires contiguous folding
def test_jit_random_after_unrealized_random(self):
@TinyJit
def f(): return Tensor.rand()

View File

@@ -63,7 +63,11 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
a, b = Tensor.randn(4), Tensor.randn(4)
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
# without contiguous folding, rand.to("CLANG") and rand.contiguous().to("CLANG") are different UOps.
# this test asserts they are the same Buffer
# having different buffers is fine for correctness, because the outputs match.
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
np_a, np_b = a.numpy(), b.numpy()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
lowered = list(lower_schedule(c.schedule()))
@@ -1690,6 +1694,7 @@ class TestHandCodedOpts(unittest.TestCase):
# should upcast the two Tensor.stacks
assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
@unittest.expectedFailure # requires contiguous folding
def test_masked_upcast_wino_full(self):
with Context(WINO=1):
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()

View File

@@ -734,7 +734,7 @@ class TestMultiTensor(unittest.TestCase):
zeros = Tensor.zeros(3).realize()
b = a.to(devices_2)*zeros.to(devices_2)
sched = b.schedule()
self.assertEqual(len(sched), 6)
self.assertEqual(len(sched), 8)
# notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy
self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2)
# all these kernels are just because multi calls contiguous on every single shard

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
@track_rewrites(named=True)
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
class TestSchedule(unittest.TestCase):
def test_basic_binop_fusion(self):
@@ -220,7 +220,7 @@ class TestSchedule(unittest.TestCase):
GlobalCounters.reset()
expr = (a*b)/b
expr.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now!
self.assertEqual(GlobalCounters.global_ops, 0)
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
@@ -229,7 +229,7 @@ class TestSchedule(unittest.TestCase):
GlobalCounters.reset()
expr = a/a
expr.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertEqual(GlobalCounters.global_ops, 0)
np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
@@ -972,6 +972,26 @@ class TestSchedule(unittest.TestCase):
expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_softmax_upcast(self):
# input half, softmax in float
Tensor.manual_seed(0)
x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
out = x.softmax(dtype=dtypes.float)
sched = out.schedule()
self.assertEqual(len(sched), 3)
self.assertEqual(len(sched[0].outputs), 1)
self.assertEqual(sched[0].outputs[0].dtype, dtypes.half)
# input float, softmax in float
Tensor.manual_seed(0)
x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize()
out = x.softmax(dtype=dtypes.float)
sched = out.schedule()
self.assertEqual(len(sched), 3)
self.assertEqual(len(sched[0].outputs), 1)
self.assertEqual(sched[0].outputs[0].dtype, dtypes.float)
def test_softmax_backward(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize()
@@ -1804,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.
# these pattern matchers should move to engine/schedule.py
sym = symbolic_simple+PatternMatcher([
ops_folding = symbolic_simple+PatternMatcher([
(UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
])
@@ -1822,8 +1842,8 @@ def run_tensor_ast(r:Tensor):
output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
run_schedule([si])
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
@@ -2184,7 +2204,7 @@ class TestConst(unittest.TestCase):
sched = add.schedule()
self.assertEqual(len(sched), 0)
# b+0 and b share the same underlying device memory
self.assertIs(add.lazydata.realized, b.lazydata.realized)
self.assertIs(add.lazydata.buffer, b.lazydata.buffer)
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
def test_src_masked_const_folding(self):
@@ -2238,6 +2258,17 @@ class TestCopyFolding(unittest.TestCase):
add = schedule_graph_rewrite(add)
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
def test_copy_to_same_device(self):
a = Tensor.empty(4).lazydata
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
def test_clone(self):
a = Tensor.empty(4).lazydata
check_schedule(a.clone(), 1, filter_sink=False)
class TestTensorUOpSpec(unittest.TestCase):
def test_const_must_be_unmasked(self):
a = Tensor.ones((4, 4)).pad((2, 2))
@@ -2253,6 +2284,12 @@ class TestTensorUOpSpec(unittest.TestCase):
t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views)
create_schedule_with_vars(t)
def test_symbolic_shape_ok(self):
a = Tensor.ones(4)
vi = UOp.variable("i", 1, 10).bind(4)
t = graph_rewrite(a.reshape(vi).sum().lazydata, remove_movement_ops+merge_views)
create_schedule_with_vars(t)
class TestBufferUOp(unittest.TestCase):
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
def test_buffer_has_buffer(self):
@@ -2316,34 +2353,80 @@ class TestBufferUOp(unittest.TestCase):
class TestContiguous(unittest.TestCase):
def test_contiguous_buffer(self):
a = Tensor.empty(4).lazydata
b = a.alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
a = Tensor.empty(4)
b = a.contiguous()
check_schedule(b, 0)
def test_contiguous_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
a = Tensor.empty(4)
b = a.reshape((2, 2)).contiguous()
check_schedule(b, 0)
def test_non_contiguous_buffer_view(self):
a = Tensor.empty(4, 1).lazydata
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
a = Tensor.empty(4, 1)
b = a.expand((4, 4)).contiguous()
check_schedule(b, 1)
def test_size_change_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
a = Tensor.empty(4)
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
check_schedule(b, 1)
def test_double_contiguous_realizes_once(self):
a = Tensor.empty(4, 1).lazydata
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
a = Tensor.empty(4, 1)
b = a.expand((4, 4)).contiguous().contiguous()
check_schedule(b, 1)
class TestUOpBecome(unittest.TestCase):
# the simplest case, if we create a new BUFFER for this UOp
def test_new_buffer(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = a+b
check_schedule(add, 1)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
def test_new_buffer_view(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = (a+b).reshape(8, 2)
check_schedule(add, 1)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
# VIEW is preserved after the becomes rewrite.
self.assertEqual(add.lazydata.shape, (8, 2))
assert add.lazydata is not add.lazydata.base
def test_become_existing_buffer(self):
a = Tensor.empty(4, 4)
b = a*1
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_become_const_in_base(self):
a = Tensor.empty(4)
b = a*0
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling folds the tensor lazydata to a CONST
def test_become_const_in_view(self):
# if we shrink the base down to size 0, only the VIEW becomes CONST; the base is unchanged.
add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
b = add.shrink(((0, 1), (0, 0)))
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata, {})
self.assertEqual(b.shape, (1, 0))
# the base is untouched.
assert UPat(Ops.ADD).match(add.lazydata, {})
def test_become_const_from_const(self):
const_add = Tensor(1)+Tensor(2)
assert UPat(Ops.ADD).match(const_add.lazydata, {})
check_schedule(const_add, 0)
assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -69,7 +69,8 @@ class TestSetitem(unittest.TestCase):
t[1] ^= 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])
@unittest.expectedFailure
#@unittest.expectedFailure
# update: passing after delete_forced_realize
def test_setitem_consecutive_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2

View File

@@ -2,6 +2,7 @@ import unittest
from tinygrad import Device, dtypes, Tensor
from tinygrad.device import Buffer
from tinygrad.ops import view_supported_devices
from tinygrad.helpers import Context
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
class TestSubBuffer(unittest.TestCase):
@@ -47,5 +48,22 @@ class TestSubBuffer(unittest.TestCase):
out = vt.to(f"{Device.DEFAULT}:1").realize().tolist()
assert out == [2, 3, 4]
def test_subbuffer_deallocate(self):
with Context(LRU=0):
vbuf = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated()
self.buf.deallocate()
vbuf.deallocate()
# Allocate a fake one in the same place
_ = Buffer(Device.DEFAULT, 10, dtypes.uint8).ensure_allocated()
self.buf.ensure_allocated()
self.buf.copyin(memoryview(bytearray(range(10, 20))))
vbuf.ensure_allocated()
tst = vbuf.as_buffer().tolist()
assert tst == [13, 14]
if __name__ == '__main__':
unittest.main()

View File

@@ -14,6 +14,7 @@ from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_ker
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.rewriter import full_graph_rewrite, sym
from tinygrad.device import is_dtype_supported
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@@ -365,6 +366,17 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.SHR, ops)
self.assertIn(Ops.IDIV, ops)
def test_mulacc_unrolled(self):
# test that acc = acc + a0*b0 + a1*b1 + a2*b2 + a3*b3
# is not acc = acc + (a0*b0 + a1*b1 + a2*b2 + a3*b3)
a = Tensor.empty(1024)
b = Tensor.empty(1024)
c = (a*b).sum()
k = Kernel(c.schedule()[-1].ast)
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
uops = k.linearize().uops
self.assertEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
class TestUOpMethod(unittest.TestCase):
@unittest.skip("uops lt no longer ordered")
def test_compare_alu_same_src_different_arg(self):

View File

@@ -164,6 +164,7 @@ class TestSafetensors(unittest.TestCase):
def test_save_all_dtypes(self):
for dtype in dtypes.fields().values():
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
if dtype in [dtypes.double] and Device.DEFAULT == "METAL": continue # not supported on METAL
path = temp(f"ones.{dtype}.safetensors")
ones = Tensor(np.random.rand(10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)

View File

@@ -104,7 +104,8 @@ class TestRealizeMeansRealize(unittest.TestCase):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
self.assertEqual(x.lazydata.op, Ops.VIEW)
@unittest.expectedFailure
#@unittest.expectedFailure
# update: passing after delete_forced_realize
def test_uniform_realizes(self):
x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
print(x.lazydata)

View File

@@ -2,7 +2,7 @@ import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.rewriter import full_graph_rewrite
from tinygrad.codegen.rewriter import full_graph_rewrite, mulacc_unrolled
# Helper function to apply the graph rewrite
def apply_rewrite(expr):
@@ -274,5 +274,41 @@ class TestSubstitute(unittest.TestCase):
ret = substitute(ret, {a.sin():a.sqrt(), n1.sin():n1.sqrt()})
self.assertIs(ret, a.sqrt().sqrt())
class TestMulaccUnrolledAcc(unittest.TestCase):
def test_unrolled2(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))
acc = UOp(Ops.DEFINE_ACC, dtypes.int, (UOp.const(dtypes.int, 0),) + acc_range, (0,))
a = UOp.variable('a', 0, 10)
b = UOp.variable('b', 0, 10)
expr = acc.assign(acc + (a*2 + b*3))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)
self.assertIs(expr_with_mulacc, acc.assign(acc + a*2 + b*3))
def test_unrolled4_float(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3))
acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,))
a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]
b = [UOp.variable(f'b{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]
expr = acc.assign(acc + (a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)
# Verify it unrolls into individual multiply-accumulate operations
expected = acc.assign(acc + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3])
self.assertIs(expr_with_mulacc, expected)
def test_unrolled4_float_const(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3))
acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,))
a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]
expr = acc.assign(acc + (a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)
# Verify it unrolls into individual multiply-accumulate operations
expected = acc.assign(acc + a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0)
self.assertIs(expr_with_mulacc, expected)
if __name__ == '__main__':
unittest.main()

View File

@@ -53,7 +53,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
b = Tensor([4.,5,6]).realize()
c = a+b
print(c.lazydata)
is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
#is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))
def test_const_pattern(self):
a = Tensor(1)
@@ -71,9 +72,9 @@ class TestTensorUopRepresentation(unittest.TestCase):
def test_viewed_consts_do_not_realize(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
pre_realize = a.lazydata
a.realize()
assert a.lazydata is pre_realize
is_pattern(a, const_pattern)
self.assertEqual(a.lazydata.shape, (10, 10))
# currently, CONSTs have a "fake" BUFFER. this should be fixed
# current:
@@ -111,7 +112,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
c = a.to("TEST") # NOTE: this isn't checked
print(c.lazydata)
# TODO: COPY on a Tensor becomes a VIEW(COPY), this should be done in the scheduler not in ops
is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
#is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
if __name__ == '__main__':
unittest.main()

View File

@@ -239,6 +239,9 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld")
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
# this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])
# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
@@ -504,8 +507,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# expand
sink = graph_rewrite(sink, sym+expander)
# devectorize + load_store_indexing
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
# devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
mulacc_unrolled)
# final rules for the renderer (without sym)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
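
The `mulacc_unrolled` matcher added above rewrites `(x+y)+acc` into `(acc+x)+y` whenever `y` is not itself the `DEFINE_ACC`; applied repeatedly, this pairs the running accumulator with one multiply per ADD so the backend can emit multiply-accumulate ops, which is what the new `test_mulacc_unrolled` tests check. A toy model of the re-association using plain tuples rather than real UOps:

```python
# toy model of the rule (x+y)+acc -> (acc+x)+y, on tuples instead of tinygrad UOps
def sink_acc_left(expr):
  if isinstance(expr, tuple) and expr[0] == 'add':
    lhs, rhs = sink_acc_left(expr[1]), expr[2]
    # rule fires when the accumulator is the right operand of an add-of-adds
    if rhs == 'acc' and isinstance(lhs, tuple) and lhs[0] == 'add':
      x, y = lhs[1], lhs[2]
      return sink_acc_left(('add', ('add', 'acc', x), y))
    return ('add', lhs, rhs)
  return expr

# (a0*b0 + a1*b1) + acc  ->  (acc + a0*b0) + a1*b1
before = ('add', ('add', ('mul', 'a0', 'b0'), ('mul', 'a1', 'b1')), 'acc')
after = ('add', ('add', 'acc', ('mul', 'a0', 'b0')), ('mul', 'a1', 'b1'))
assert sink_acc_left(before) == after
```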

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from typing import Optional, Any, Iterator, Generator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
cpu_time_execution
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer
@@ -129,7 +129,7 @@ class Buffer:
if self._base is None and (self.options is None or self.options.external_ptr is None):
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
del self._buf
del self._buf
def __reduce__(self):
buf = None
if self._base is not None:
@@ -202,7 +202,7 @@ class LRUAllocator(Allocator):
for opaque in opaques: super().free(opaque, sz, options)
opaques.clear()
def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator):
@@ -310,7 +310,7 @@ if PROFILE:
for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize()
with open(temp("profile.pkl"), "wb") as f: pickle.dump(Compiled.profile_events, f)
with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f)
from tinygrad.ops import launch_viz
launch_viz("PROFILE", temp("profile.pkl"))
launch_viz("PROFILE", fn)
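
The `getenv("LRU", 1)` lookup in `LRUAllocator.free` is replaced by an `LRU` flag imported from `tinygrad.helpers`. Presumably it is defined there as a `ContextVar` like the other flags, which is what lets the subbuffer test earlier in this diff write `with Context(LRU=0):` to disable allocator caching; a sketch of the assumed definition:

```python
# assumed addition to tinygrad/helpers.py (sketch; exact placement may differ)
LRU = ContextVar("LRU", 1)  # like DEBUG etc., overridable per-scope via Context(LRU=0)
```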

View File

@@ -2,7 +2,7 @@ import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
from tinygrad.dtype import DType, ImageDType, dtypes
@@ -31,9 +31,9 @@ tensor_uop_spec = PatternMatcher([
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
(UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, arg=ShapeTracker.from_shape(()))), arg=None), lambda: True),
# Tensor const has an unmasked ShapeTracker of stride 0 and a device
# Tensor const has a device and an unmasked ShapeTracker of stride 0 or a ShapeTracker with symbolic shape
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
lambda st: len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides) and st.st.views[0].mask is None),
lambda st: st.st.views[0].mask is None and ((len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)) or not all_int(st.shape))),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
@@ -88,15 +88,15 @@ class ScheduleContext:
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# SINK is passthrough
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
# VIEW is passthrough
if buf is not buf.base:
cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(unwrap(buf.st))
cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
return ret
# make things that can't be images not images
dtype = buf.dtype
@@ -105,11 +105,11 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
dtype = buf.dtype.base
# ASSIGN already has a target buffer, otherwise we create a new one
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# track the underlying tensor uop for this buffer
ctx.tensor_uops[buf_uop] = [buf]
ctx.tensor_uops[buf_uop] = tensor_map[buf]
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
# **** AST graph rewrite
@@ -242,7 +242,7 @@ if CAPTURE_PROCESS_REPLAY:
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
def uval(u:UOp) -> UOp:
assert is_scheduled(u), f"must be a scheduled op {u}"
return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r
return u.src[1]
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
@@ -329,7 +329,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
# maybe fuse arange with its children
for rbuf in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue
if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
if len(kernel_children) == 0: continue
for tr in group: del ctx.realizes[tr]
@@ -340,10 +340,6 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
# **** Schedule creation and BFS toposort
class UPatScheduled(UPat):
def __init__(self, *args, **kwargs):
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
# ** this is schedule level const folding
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
@@ -358,22 +354,18 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
case _: return None
return reduce.const_like(ret)
def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
old_base = contig.src[0].src[0]
if old_base.op is Ops.VIEW and (sti:=unwrap(contig.src[0].st).invert(old_base.shape)) is not None: ctx.contiguous[old_base] = base.view(sti)
def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
ops_folding = symbolic_simple+PatternMatcher([
# op with size 0 is zero
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# if the uop folded to a CONST we can delete the BUFFER
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
# DETACH is a NOOP here
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
# reduce of size 0 is the identity element
@@ -386,13 +378,16 @@ ops_folding = symbolic_simple+PatternMatcher([
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# double contiguous is one contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER/VIEW from SINK
(UPat(Ops.SINK, name="root"),
@@ -400,36 +395,12 @@ ops_folding = symbolic_simple+PatternMatcher([
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
# ** buffer merging
def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}"
# if b2 is realized also realize b1
if b2 in ctx.realizes:
ctx.realizes[b1] = b1
del ctx.realizes[b2]
# ops referring to b2 now ref to b1
ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
del ctx.tensor_uops[b2]
# merge
return v1
def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
# early become
for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): ctx.becomes_map[luop] = b1.view(unwrap(luop.st))
return v1
merge_bufs = PatternMatcher([
# merge base
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))), merge),
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))), merge_realized),
# collapse view
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))), lambda mv:mv),
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))), lambda mv:mv),
])
# ** this decides which ops get realized
class UPatScheduled(UPat):
def __init__(self, *args, **kwargs):
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
@@ -448,8 +419,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs)
return x.view(unwrap(view.st))
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
if not root.device.startswith("DISK"): return None
if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone
if not b.device.startswith("DISK"): return None
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
@@ -482,7 +452,7 @@ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
if b not in ctx.realizes: return x # collapse BUFFER
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
@@ -523,28 +493,36 @@ remove_movement_ops = PatternMatcher([
@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
# if using VIZ, do a graph rewrite to visualize the Tensor graph
if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
# to_uop removes (many of) the movement ops
sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={})
# const folding and fusion
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx)
sink = graph_rewrite(sink, merge_bufs, ctx)
# create the scheduler context
graph_rewrite(sink, create_ctx, ctx)
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
rev_tensor_map: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
# add BUFFER uops
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
# add realizes
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
# group realizes into kernels
store_groups = group_realizes(ctx)
graph_rewrite(sink, break_sched, ctx)
# preschedule realize groups
prescheduled: list[ScheduleItem] = []
for store_uops in store_groups:
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
prescheduled.append(schedule_uop(small_sink, ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
for k,v in tensor_map.items():
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: ctx.becomes_map[k] = v
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
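Unlike `graph_rewrite`, `graph_rewrite_map` returns the full old-to-new node mapping, and the `rev_tensor_map` loop near the top of this hunk inverts it into a multimap so each rewritten UOp remembers every tensor UOp that collapsed into it. The inversion is the standard setdefault pattern; a self-contained sketch with stand-in values:

tensor_map = {"t1": "a", "t2": "a", "t3": "b"}   # many-to-one after rewrites
rev: dict[str, list[str]] = {}
for k, v in tensor_map.items(): rev.setdefault(v, []).append(k)
assert rev == {"a": ["t1", "t2"], "b": ["t3"]}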

View File

@@ -78,7 +78,8 @@ def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+
def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
@functools.lru_cache(maxsize=None)
def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
def temp(x:str, append_user:bool=False) -> str:
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{os.getenv('USERNAME', os.getlogin())}" if append_user else x)).as_posix()
class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs
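For reference, the per-user suffix keeps temp artifacts from different users of a shared machine from clobbering each other. A minimal sketch of the resulting paths, assuming the login name resolves to "alice" and a /tmp tempdir:

from tinygrad.helpers import temp
temp("profile.pkl")                    # -> /tmp/profile.pkl
temp("profile.pkl", append_user=True)  # -> /tmp/profile.pkl.alice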
@@ -107,7 +108,7 @@ WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1),
USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ"))
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
@dataclass(frozen=True)
class Metadata:

View File

@@ -233,7 +233,6 @@ class UOpMetaClass(type):
# some uops map to other stuff
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
@@ -409,11 +408,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self):
if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST:
return self.alu(Ops.CONTIGUOUS)
forced_realize.add(self.base)
return self
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
# *** from LazyBuffer ***
@@ -432,19 +427,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# otherwise it's just a VIEW(BUFFER)
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
def copy_to_device(self, device:str, clone:bool=False) -> UOp:
# no COPY
if self.device == device and not clone: return self
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
# COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
return UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone).view(unwrap(self.st))
ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
op_arg = []
mop = self
while mop is not self.base:
op_arg.append((mop.op, mop.arg))
mop = mop.src[0]
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
return ret
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
@property
def lbs(self): return [self]
@property
def metadata(self): return all_metadata.get(self, None)
@property
def forced_realize(self): return self in forced_realize
# *** uop movement ops ***
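The rewritten `copy_to_device` no longer wraps the COPY in a single VIEW of the source's ShapeTracker; it walks from the view down to its base, records each movement op's `(op, arg)`, and replays the chain on top of the fresh COPY. A generic sketch of that walk-and-replay pattern, with a hypothetical `rebuild` standing in for the UOp constructor:

def replay_movement(view, base, start, rebuild):
    chain = []
    node = view
    while node is not base:           # walk view -> base, recording each movement op
        chain.append((node.op, node.arg))
        node = node.src[0]
    ret = start                       # e.g. the COPY of the base
    for op, arg in reversed(chain):   # replay in base-to-view order
        ret = rebuild(op, ret, arg)
    return ret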
@@ -822,10 +820,10 @@ if TRACK_MATCH_STATS:
@atexit.register
def print_match_stats():
if TRACK_MATCH_STATS >= 2:
with open(fn:=temp("rewrites.pkl"), "wb") as f:
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl"))
if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if getenv("PRINT_MATCH_STATS", 1):
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
from tinygrad.runtime.support.allocator import TLSFAllocator
@@ -98,20 +98,14 @@ class AMFirmware:
def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size])
class AMPhysicalMemoryBlock:
def __init__(self, adev:AMDev, paddr:int, size:int): self.adev, self.paddr, self.size = adev, paddr, size
def mc_addr(self): return self.adev.gmc.mc_base + self.paddr
def cpu_addr(self): return mv_address(self.adev.vram) + self.paddr
def cpu_view(self): return to_mv(self.cpu_addr(), self.size)
@dataclasses.dataclass(frozen=True)
class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
class AMPageTableEntry:
def __init__(self, pm, lv): self.pm, self.view, self.lv = pm, pm.cpu_view().cast('Q'), lv
def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv
def set_table(self, entry_id, pte:AMPageTableEntry, valid=True):
self.view[entry_id] = (pte.pm.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0)
self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0)
def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True):
f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \
@@ -133,11 +127,11 @@ class AMPageTableTraverseContext:
def level_down(self):
pt, pte_idx, _ = self.pt_stack[-1]
if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID:
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.pm.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(AMPhysicalMemoryBlock(pt.pm.adev, entry & 0x0000FFFFFFFFF000, 0x1000), lv=pt.lv+1)
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
else:
assert self.create_pts, "Not allowed to create a new page table"
pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1))
pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1))
self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
return self.pt_stack[-1]
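As `set_table` above shows, a directory entry is just the child table's 4KB-aligned physical address masked into bits 12..47, OR'd with the valid bit. Packing and unpacking in isolation; the flag value is an assumption taken from the amdgpu headers:

AMDGPU_PTE_VALID = 1 << 0        # assumed flag value, per the amdgpu headers
ADDR_MASK = 0x0000FFFFFFFFF000   # bits 12..47 carry the physical page address

def pack_pde(child_paddr: int, valid: bool = True) -> int:
    return (child_paddr & ADDR_MASK) | (AMDGPU_PTE_VALID if valid else 0)

def unpack_pde(entry: int) -> tuple[int, bool]:
    return entry & ADDR_MASK, bool(entry & AMDGPU_PTE_VALID)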
@@ -145,7 +139,7 @@ class AMPageTableTraverseContext:
def _try_free_pt(self) -> bool:
pt, _, _ = self.pt_stack[-1]
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
self.adev.mm.pfree(AMPhysicalMemoryBlock(self.adev, pt.pm.paddr, 0x1000))
self.adev.mm.pfree(pt.paddr)
parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
parent_pt.set_page(parent_pte_idx, 0x0, valid=False)
return True
@@ -179,7 +173,7 @@ class AMMemoryManager:
self.adev, self.vram_size = adev, vram_size
self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
self.root_page_table = AMPageTableEntry(self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
@@ -213,12 +207,12 @@ class AMMemoryManager:
# Alloc physical memory and map it to the virtual address
va = self.alloc_vaddr(size, align)
if contigous: paddrs = [(self.palloc(size, zero=True).paddr, size)]
if contigous: paddrs = [(self.palloc(size, zero=True), size)]
else:
paddrs = []
try:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False).paddr, seg_size) for _ in range(seg_cnt)]
for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)]
except MemoryError:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise
@@ -230,13 +224,13 @@ class AMMemoryManager:
self.va_allocator.free(vm.va_addr)
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
def palloc(self, size, align=0x1000, zero=True, boot=False) -> AMPhysicalMemoryBlock:
def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
pm = AMPhysicalMemoryBlock(self.adev, (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align), size)
if zero: ctypes.memset(pm.cpu_addr(), 0, pm.size)
return pm
paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size)
return paddr
def pfree(self, pm:AMPhysicalMemoryBlock): self.pa_allocator.free(pm.paddr)
def pfree(self, paddr:int): self.pa_allocator.free(paddr)
class AMDev:
def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
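With `AMPhysicalMemoryBlock` gone, `palloc` returns a bare physical address and the CPU and GPU views are derived from it on demand. The new idiom, sketched with a hypothetical size:

paddr = adev.mm.palloc(0x2000, zero=True)      # physical address as a plain int
host = to_mv(adev.paddr2cpu(paddr), 0x2000)    # CPU view only when needed
mc_addr = adev.paddr2mc(paddr)                 # GPU-visible (MC) address
adev.mm.pfree(paddr)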
@@ -285,13 +279,10 @@ class AMDev:
self.partial_boot = False
if not self.partial_boot:
try: # do not interrupt the boot process
signal.signal(signal.SIGINT, signal.SIG_IGN)
if self.psp.is_sos_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
finally: signal.signal(signal.SIGINT, signal.default_int_handler)
if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
# Booting done
self.is_booting = False
@@ -309,6 +300,7 @@ class AMDev:
for ip in [self.sdma, self.gfx]: ip.fini()
def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
@@ -337,8 +329,8 @@ class AMDev:
self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)
def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff) -> int:
for _ in range(10000):
def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
for _ in range(timeout):
if ((rval:=reg.read()) & mask) == value: return rval
time.sleep(0.001)
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval:X}')
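Each poll iteration sleeps 1ms, so `timeout` is roughly a millisecond budget. A hypothetical call bounding the wait to ~100ms instead of the default ~10s, as the SMU liveness probe further down does:

adev.wait_reg(adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=100)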
@@ -348,9 +340,8 @@ class AMDev:
# The table is located 64KB before the end of VRAM and is 10KB in size.
mmRCC_CONFIG_MEMSIZE = 0xde3
self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20
self.discovery_pm = AMPhysicalMemoryBlock(self, self.vram_size - (64 << 10), 10 << 10)
bhdr = am.struct_binary_header.from_address(self.discovery_pm.cpu_addr())
bhdr = am.struct_binary_header.from_address(self.paddr2cpu(self.vram_size - (64 << 10)))
ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset)
assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X}"

View File

@@ -1,4 +1,4 @@
import ctypes, time
import ctypes, time, contextlib
from typing import Literal
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
@@ -25,8 +25,8 @@ class AM_GMC(AM_IP):
self.vm_base = self.adev.mm.va_allocator.base
self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1
self.memscratch_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.dummy_page_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.hub_initted = {"MM": False, "GC": False}
def init(self): self.init_hub("MM")
@@ -55,7 +55,7 @@ class AM_GMC(AM_IP):
def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid):
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.pm.paddr | 1)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1)
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv))
def init_hub(self, ip:Literal["MM", "GC"]):
@@ -66,8 +66,8 @@ class AM_GMC(AM_IP):
self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18)
self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18)
self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_pm.paddr >> 12)
self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_pm.paddr >> 12)
self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12)
self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12)
self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1)
@@ -106,22 +106,26 @@ class AM_SMU(AM_IP):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck, poll=True)
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck, poll=True)
def is_smu_alive(self):
with contextlib.suppress(RuntimeError): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0
def mode1_reset(self):
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
time.sleep(0.5) # 500ms
def _smu_cmn_poll_stat(self): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1)
def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
def _smu_cmn_send_msg(self, msg, param=0):
self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
self.adev.mmMP1_SMN_C2PMSG_82.write(param)
self.adev.mmMP1_SMN_C2PMSG_66.write(msg)
def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False):
if poll: self._smu_cmn_poll_stat()
def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s
if poll: self._smu_cmn_poll_stat(timeout=timeout)
self._smu_cmn_send_msg(msg, param)
self._smu_cmn_poll_stat()
self._smu_cmn_poll_stat(timeout=timeout)
return self.adev.rreg(self.adev.mmMP1_SMN_C2PMSG_82) if read_back_arg else None
class AM_GFX(AM_IP):
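The SMU mailbox protocol visible in these helpers: zero the response register (C2PMSG_90), stage the argument (C2PMSG_82), write the message id (C2PMSG_66) to fire the command, then poll the response register until it reads 1. Condensed into one function, with hypothetical `wreg`/`rreg` register accessors:

import time

def send_smu_msg(wreg, rreg, msg: int, param: int = 0, timeout: int = 10000):
    wreg("C2PMSG_90", 0)       # clear the response mailbox
    wreg("C2PMSG_82", param)   # stage the argument
    wreg("C2PMSG_66", msg)     # writing the message id kicks off the command
    for _ in range(timeout):   # SMU writes 1 back on success
        if rreg("C2PMSG_90") == 1: return rreg("C2PMSG_82")  # optional read-back
        time.sleep(0.001)
    raise RuntimeError(f"SMU message {msg:#x} timed out")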
@@ -232,27 +236,28 @@ class AM_GFX(AM_IP):
class AM_IH(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.rings = [(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]
self.ring_size = 512 << 10
self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]
def interrupt_handler(self):
ring_vm, rwptr_vm, suf, _ = self.rings[0]
wptr = to_mv(rwptr_vm.cpu_addr(), 8).cast('Q')[0]
_, rwptr_vm, suf, _ = self.rings[0]
wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]
if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)
self.adev.regIH_RB_RPTR.write(wptr % ring_vm.size)
self.adev.regIH_RB_RPTR.write(wptr % self.ring_size)
def init(self):
for ring_vm, rwptr_vm, suf, ring_id in self.rings:
self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", ring_vm.mc_addr() >> 8)
self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", self.adev.paddr2mc(ring_vm) >> 8)
self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(ring_vm.size//4).bit_length(),
self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(self.ring_size//4).bit_length(),
mc_snoop=1, mc_ro=0, mc_vmid=0, **({'wptr_overflow_enable': 1, 'rptr_rearm': 1} if ring_id == 0 else {'rb_full_drain_enable': 1}))
if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", rwptr_vm.mc_addr())
if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", self.adev.paddr2mc(rwptr_vm))
self.adev.reg(f"regIH_RB_WPTR{suf}").write(0)
self.adev.reg(f"regIH_RB_RPTR{suf}").write(0)
@@ -303,10 +308,12 @@ class AM_PSP(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.msg1_pm = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True)
self.cmd_pm = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
self.fence_pm = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
self.ring_pm = self.adev.mm.palloc(0x10000, zero=not self.adev.partial_boot, boot=True)
self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True)
self.cmd_paddr = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
self.fence_paddr = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
self.ring_size = 0x10000
self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True)
def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0
def init(self):
@@ -316,8 +323,9 @@ class AM_PSP(AM_IP):
(am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV),
(am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)]
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
while not self.is_sos_alive(): time.sleep(0.01)
if not self.is_sos_alive():
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
while not self.is_sos_alive(): time.sleep(0.01)
self._ring_create()
self._tmr_init()
@@ -332,8 +340,8 @@ class AM_PSP(AM_IP):
def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000)
def _prep_msg1(self, data):
ctypes.memset(self.msg1_pm.cpu_addr(), 0, self.msg1_pm.size)
self.msg1_pm.cpu_view()[:len(data)] = data
ctypes.memset(cpu_addr:=self.adev.paddr2cpu(self.msg1_paddr), 0, am.PSP_1_MEG)
to_mv(cpu_addr, len(data))[:] = data
self.adev.gmc.flush_hdp()
def _bootloader_load_component(self, fw, compid):
@@ -342,7 +350,7 @@ class AM_PSP(AM_IP):
self._wait_for_bootloader()
self._prep_msg1(self.adev.fw.sos_fw[fw])
self.adev.regMP0_SMN_C2PMSG_36.write(self.msg1_pm.mc_addr() >> 20)
self.adev.regMP0_SMN_C2PMSG_36.write(self.adev.paddr2mc(self.msg1_paddr) >> 20)
self.adev.regMP0_SMN_C2PMSG_35.write(compid)
return self._wait_for_bootloader()
@@ -350,16 +358,22 @@ class AM_PSP(AM_IP):
def _tmr_init(self):
# Load TOC and calculate TMR size
self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC])
resp = self._load_toc_cmd(len(fwm))
self.tmr_pm = self.adev.mm.palloc(resp.resp.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size
self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
def _ring_create(self):
# If the ring is already created, destroy it
if self.adev.regMP0_SMN_C2PMSG_71.read() != 0:
self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS)
# There might be a handshake issue with the hardware, which needs a delay
time.sleep(0.02)
# Wait until the sOS is ready
self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000)
self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.ring_pm.mc_addr())
self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_pm.size)
self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.adev.paddr2mc(self.ring_paddr))
self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_size)
self.adev.regMP0_SMN_C2PMSG_64.write(am.PSP_RING_TYPE__KM << 16)
# There might be a handshake issue with the hardware, which needs a delay
@@ -369,28 +383,28 @@ class AM_PSP(AM_IP):
def _ring_submit(self):
prev_wptr = self.adev.regMP0_SMN_C2PMSG_67.read()
ring_entry_addr = self.ring_pm.cpu_addr() + prev_wptr * 4
ring_entry_addr = self.adev.paddr2cpu(self.ring_paddr) + prev_wptr * 4
ctypes.memset(ring_entry_addr, 0, ctypes.sizeof(am.struct_psp_gfx_rb_frame))
write_loc = am.struct_psp_gfx_rb_frame.from_address(ring_entry_addr)
write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.cmd_pm.mc_addr())
write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.fence_pm.mc_addr())
write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.adev.paddr2mc(self.cmd_paddr))
write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.adev.paddr2mc(self.fence_paddr))
write_loc.fence_value = prev_wptr
# Move the wptr
self.adev.regMP0_SMN_C2PMSG_67.write(prev_wptr + ctypes.sizeof(am.struct_psp_gfx_rb_frame) // 4)
while self.fence_pm.cpu_view().cast('I')[0] != prev_wptr: pass
while to_mv(self.adev.paddr2cpu(self.fence_paddr), 4).cast('I')[0] != prev_wptr: pass
time.sleep(0.005)
resp = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr())
resp = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
if resp.resp.status != 0: raise RuntimeError(f"PSP command failed {resp.cmd_id} {resp.resp.status}")
return resp
def _prep_ring_cmd(self, hdr):
ctypes.memset(self.cmd_pm.cpu_addr(), 0, 0x1000)
cmd = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr())
ctypes.memset(self.adev.paddr2cpu(self.cmd_paddr), 0, 0x1000)
cmd = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
cmd.cmd_id = hdr
return cmd
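Completion here is fence-based: the frame is stamped with the pre-bump write pointer as its fence value, moving the wptr hands the frame to the PSP, and the driver spins until the PSP echoes that value into the fence buffer. The pattern in isolation, with hypothetical helpers and FRAME_DWORDS standing in for sizeof(struct_psp_gfx_rb_frame) // 4:

FRAME_DWORDS = 64  # assumption; the real value is ctypes.sizeof(...) // 4

def psp_submit(read_wptr, write_wptr, write_frame, read_fence):
    prev = read_wptr()
    write_frame(prev)                 # frame.fence_value = prev
    write_wptr(prev + FRAME_DWORDS)   # moving the wptr submits the frame
    while read_fence() != prev: pass  # PSP writes fence_value back when done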
@@ -400,22 +414,22 @@ class AM_PSP(AM_IP):
self._prep_msg1(fw_bytes)
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW)
cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.msg1_pm.mc_addr())
cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes)
cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
return self._ring_submit()
def _tmr_load_cmd(self):
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR)
cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.tmr_pm.mc_addr())
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_pm.paddr)
cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr))
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr)
cmd.cmd.cmd_setup_tmr.bitfield.virt_phy_addr = 1
cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_pm.size
cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size
return self._ring_submit()
def _load_toc_cmd(self, toc_size):
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_TOC)
cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.msg1_pm.mc_addr())
cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
cmd.cmd.cmd_load_toc.toc_size = toc_size
return self._ring_submit()

View File

@@ -1856,8 +1856,8 @@ class Tensor(SimpleMathTrait):
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
x = self.cast(dtype) if dtype is not None else self
m = x - x.max(axis=axis, keepdim=True).detach()
m = self - self.max(axis=axis, keepdim=True).detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
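The reorder matters numerically: the row max is now taken on, and subtracted from, the uncast input before any dtype conversion, the usual log-sum-exp stabilization. The idea in plain NumPy, illustrative only and not the tinygrad API:

import numpy as np

def stable_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    m = x - x.max(axis=axis, keepdims=True)  # largest logit becomes 0
    e = np.exp(m)                            # exp can no longer overflow
    return e / e.sum(axis=axis, keepdims=True)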