Merge branch 'master' into retinanet_mlperf
2
.github/workflows/test.yml
vendored
@@ -298,7 +298,7 @@ jobs:
    - if: ${{ matrix.task == 'optimage' }}
      name: Test openpilot model kernel count and gate usage
      run: |
-       PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
+       PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
    - if: ${{ matrix.task == 'optimage' }}
      name: Test openpilot alt model correctness (float32)
      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
@@ -27,7 +27,7 @@ AM binds compute queues directly to MEC (bypassing MES). Tinygrad uses only one

The GPU being passed can be in one of several states:
1. Not initialized
-2. Initialized by AMDGPU
+2. Initialized by amdgpu
3. Initialized by AM

The first and second states require a full GPU setup since their states are unknown. The second state also requires a mode1 reset to reinitialize all components.
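To make that boot decision concrete, a minimal illustrative sketch (the function names are hypothetical stand-ins, not the actual AM API):

def full_setup(): print("full GPU setup")
def mode1_reset(): print("mode1 reset")
def partial_setup(): print("partial setup, fast boot path")

def boot(prev_state: str):
  if prev_state == "am": partial_setup()        # 3. left initialized by AM: partial setup is enough
  else:
    if prev_state == "amdgpu": mode1_reset()    # 2. initialized by amdgpu also needs a mode1 reset
    full_setup()                                # 1. and 2. require a full GPU setup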
@@ -36,4 +36,4 @@ The third state can be set up partially to optimize boot time. In this case, onl

### VM Management

-Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512TB of virtual addresses. All AM devices are located in one virtual address space.
+Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512GB of virtual addresses. All AM devices are located in one virtual address space.
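A quick sanity check of the 512GB figure (assuming the usual 4 KiB pages and 512-entry, 9-bit tables per level, which is an assumption not stated above):

levels, bits_per_level, page_offset_bits = 3, 9, 12
print(2 ** (levels * bits_per_level + page_offset_bits) // 2 ** 30, "GiB")  # -> 512 GiB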
@@ -1,14 +1,12 @@
import sys, onnx, time
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
from tinygrad.tensor import _from_np_dtype
-from extra.onnx import get_run_onnx
+from extra.onnx import OnnxRunner

def load_onnx_model(fn):
  onnx_file = fetch(fn)
  onnx_model = onnx.load(onnx_file)
-  Tensor.no_grad = True
-  Tensor.training = False
-  run_onnx = get_run_onnx(onnx_model)
+  run_onnx = OnnxRunner(onnx_model)

  # find preinitted tensors and ignore them
  initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer}
@@ -8,7 +8,7 @@ import numpy as np
import subprocess
import tensorflow as tf
import tf2onnx
-from extra.onnx import get_run_onnx
+from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor
from extra.export_model import export_model_clang, compile_net, jit_model

@@ -25,7 +25,7 @@ class TinyOnnx:
  def __init__(self, keras_model):
    input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
    onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
-    self.run_onnx = get_run_onnx(onnx_model)
+    self.run_onnx = OnnxRunner(onnx_model)

  def forward(self, x):
    return self.run_onnx({"x": x}, debug=False)['predictions']
@@ -848,9 +848,9 @@ def train_bert():

  model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)

-  for _, x in get_state_dict(model).items():
-    x.realize().to_(GPUS)
+  parameters = get_parameters(model)
+  for p in parameters:
+    p.to_(GPUS)

  # ** Log run config **
  for key, value in config.items(): print(f'HParam: "{key}": {value}')
@@ -12,17 +12,14 @@ from tinygrad.engine.realize import CompiledRunner

import onnx
from onnx.helper import tensor_dtype_to_np_dtype
-from extra.onnx import get_run_onnx # TODO: port to main tinygrad
+from extra.onnx import OnnxRunner # TODO: port to main tinygrad

OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"

def compile(onnx_file):
  onnx_model = onnx.load(onnx_file)
-  Tensor.no_grad = True
-  Tensor.training = False
-
-  run_onnx = get_run_onnx(onnx_model)
+  run_onnx = OnnxRunner(onnx_model)
  print("loaded model")

  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
@@ -3,7 +3,7 @@ import os
from ultralytics import YOLO
import onnx
from pathlib import Path
-from extra.onnx import get_run_onnx
+from extra.onnx import OnnxRunner
from tinygrad.tensor import Tensor

os.chdir("/tmp")
@@ -14,5 +14,5 @@ onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
# TODO: move get example inputs to onnx
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
print(input_shapes)
-run_onnx = get_run_onnx(onnx_model)
+run_onnx = OnnxRunner(onnx_model)
run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True)
290
extra/onnx.py
@@ -1,53 +1,42 @@
|
||||
from typing import Callable, Any, Sequence
|
||||
import importlib, functools
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
import importlib, functools, dataclasses
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same
|
||||
from tinygrad.dtype import DType, ConstType
|
||||
from tinygrad.dtype import DType, ConstType, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
cache_misses = 0
|
||||
@functools.lru_cache(None)
|
||||
def _cached_to_python_const(t:Tensor):
|
||||
if t.dtype is dtypes.uint8: return t.data().tobytes()
|
||||
if 0 in t.shape: return []
|
||||
return t.tolist()
|
||||
# ***** protobuf parsing ******
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper
|
||||
import numpy as np
|
||||
|
||||
# Tensor -> python value cache for parameters
|
||||
def to_python_const(t) -> list[ConstType]|ConstType|bytes:
|
||||
if not isinstance(t, Tensor): return t
|
||||
global cache_misses
|
||||
ret = _cached_to_python_const(t)
|
||||
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
|
||||
print(f"Cache miss for {t}")
|
||||
cache_misses = info.misses
|
||||
return ret
|
||||
|
||||
# TODO: use real float16
|
||||
# src: onnx/mapping.py
|
||||
DTYPE_MAP: dict[int, DType] = {
|
||||
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
|
||||
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
|
||||
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
|
||||
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
|
||||
}
|
||||
def dtype_parse(onnx_dtype: int) -> DType:
|
||||
if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
|
||||
return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float
|
||||
supported: dict[int, DType] = {
|
||||
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
|
||||
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
|
||||
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
|
||||
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16,
|
||||
}
|
||||
unsupported = {
|
||||
TensorProto.UNDEFINED, TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128, TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E4M3FNUZ,
|
||||
TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4
|
||||
}
|
||||
if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
|
||||
return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float
|
||||
|
||||
# src: onnx/onnx_ml_pb2.pyi
|
||||
ATTRIBUTE_MAP: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
|
||||
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
|
||||
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
|
||||
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
|
||||
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
|
||||
}
|
||||
def attribute_parse(onnx_attribute: AttributeProto):
|
||||
if onnx_attribute.type not in ATTRIBUTE_MAP:
|
||||
supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
|
||||
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
|
||||
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
|
||||
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
|
||||
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
|
||||
}
|
||||
unsupported = {
|
||||
AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS,
|
||||
AttributeProto.GRAPHS, AttributeProto.SPARSE_TENSORS, AttributeProto.TYPE_PROTOS
|
||||
}
|
||||
if onnx_attribute.type in unsupported:
|
||||
raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
|
||||
return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)
|
||||
return supported[onnx_attribute.type](onnx_attribute)
|
||||
|
||||
def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
|
||||
if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
|
||||
@@ -62,116 +51,137 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
|
||||
return Tensor(np_buffer, dtype=dtype)
|
||||
return Tensor(None)
|
||||
|
||||
onnx_ops = importlib.import_module('extra.onnx_ops')
|
||||
ONNXLIMIT = getenv("ONNXLIMIT", -1)
|
||||
def get_run_onnx(onnx_model: ModelProto):
|
||||
# model initialization data
|
||||
model_tensors = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
|
||||
model_expected_inputs = {inp.name:inp for inp in onnx_model.graph.input if inp.name not in model_tensors}
|
||||
model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}
|
||||
def type_parse(onnx_type: TypeProto):
|
||||
elem_type = onnx_type
|
||||
if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
|
||||
raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
|
||||
if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
|
||||
if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
|
||||
if elem_type.HasField("tensor_type"):
|
||||
shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
|
||||
dtype = dtype_parse(elem_type.tensor_type.elem_type)
|
||||
return OnnxValue(shape, dtype, is_optional, is_sequence)
|
||||
raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
|
||||
|
||||
# model descriptions
|
||||
# TODO: need a better way of controlling training vs non-training
|
||||
is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
|
||||
onnx_model_version = onnx_model.opset_import[0].version
|
||||
# ***** onnx spec *****
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class OnnxValue:
|
||||
shape: tuple[str|int]
|
||||
dtype: DType
|
||||
is_optional: bool
|
||||
is_sequence: bool
|
||||
|
||||
# used to check validity of user_input according to their dimension variables
|
||||
variable_dims = {}
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class OnnxNode:
|
||||
num: int
|
||||
op: str
|
||||
inputs: tuple[str]
|
||||
outputs: tuple[str]
|
||||
opts: dict[str, Any]
|
||||
|
||||
# mapping from onnx ops to tensor.py ops
|
||||
tensor_methods = {
|
||||
op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
|
||||
"Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
|
||||
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")
|
||||
}
|
||||
# ***** python const *****
|
||||
required_input_python_consts: dict[str, tuple[int, ...]] = {
|
||||
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
|
||||
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
|
||||
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
|
||||
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
|
||||
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
|
||||
}
|
||||
|
||||
# these values are expected to be python consts
|
||||
required_input_python_consts: dict[str, tuple[int, ...]] = {
|
||||
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
|
||||
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
|
||||
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
|
||||
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
|
||||
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
|
||||
}
|
||||
cache_misses = 0
|
||||
@functools.lru_cache(None)
|
||||
def _cached_to_python_const(t:Tensor):
|
||||
if t.dtype is dtypes.uint8: return t.data().tobytes()
|
||||
if 0 in t.shape: return []
|
||||
return t.tolist()
|
||||
|
||||
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
|
||||
# parses and validates inputs based on their shape and dtype specified by model
|
||||
def prepare_input(user_input:Any, model_input:ValueInfoProto):
|
||||
type_proto = model_input.type
|
||||
if type_proto.HasField("optional_type"):
|
||||
if user_input is None: return None
|
||||
type_proto = type_proto.optional_type.elem_type
|
||||
if type_proto.HasField("sequence_type"):
|
||||
if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
|
||||
dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
|
||||
sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
|
||||
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
|
||||
# TODO: need true float16 for dtype checking
|
||||
# if not all(t.dtype is dtype for t in sequence):
|
||||
# raise RuntimeError(f"{model_input.name} has dtype mismatch for sequence type. Expected {dtype}, received {tensor.dtype}.")
|
||||
# Tensor -> python value cache for parameters
|
||||
def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
|
||||
if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t
|
||||
global cache_misses
|
||||
ret = _cached_to_python_const(t)
|
||||
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
|
||||
print(f"Cache miss for {t}")
|
||||
cache_misses = info.misses
|
||||
return ret
|
||||
|
||||
# ***** runner ******
|
||||
debug = int(getenv("DEBUGONNX", "0"))
|
||||
limit = int(getenv("ONNXLIMIT", "-1"))
|
||||
class OnnxRunner:
|
||||
def __init__(self, model: ModelProto):
|
||||
# parse model protobuf
|
||||
self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node)
|
||||
self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
|
||||
Tensor.training = True if self.is_training else False
|
||||
Tensor.no_grad = False if self.is_training else True
|
||||
self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer}
|
||||
self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
|
||||
self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output}
|
||||
self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
|
||||
for num,n in enumerate(model.graph.node))
|
||||
self.opset_version = model.opset_import[0].version
|
||||
self.variable_dims: dict[str, int] = {}
|
||||
|
||||
# TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import
|
||||
# TODO: clean up opset stuff after moving extra.onnx_ops here
|
||||
self.onnx_ops_module = importlib.import_module('extra.onnx_ops')
|
||||
self.onnx_ops = {
|
||||
**{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
|
||||
"Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
|
||||
"Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")},
|
||||
}
|
||||
|
||||
def _parse_input(self, name: str, value: Any, spec: OnnxValue):
|
||||
if spec.is_optional and value is None: return None
|
||||
# TODO: need true float16 for dtype checking
|
||||
if spec.is_sequence:
|
||||
if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type")
|
||||
sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
|
||||
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous")
|
||||
return sequence
|
||||
if type_proto.HasField("tensor_type"):
|
||||
dtype = dtype_parse(type_proto.tensor_type.elem_type)
|
||||
tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
|
||||
# TODO: need true float16 for dtype checking
|
||||
# if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} has mismatch for dtype. Expected {dtype}, received {tensor.dtype}.")
|
||||
for dim, onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
|
||||
dim_param, dim_value = onnx_dim.dim_param, onnx_dim.dim_value
|
||||
user_dim_input = tensor.shape[dim]
|
||||
if dim_param: dim_value = variable_dims[dim_param] if dim_param in variable_dims else variable_dims.setdefault(dim_param, user_dim_input)
|
||||
if user_dim_input != dim_value:
|
||||
raise RuntimeError(f"{model_input.name} has mismatch for dim={dim_param or dim}. Expected {dim_value}, received {user_dim_input}.")
|
||||
return tensor
|
||||
type_field_names = [field.name for field,_ in type_proto.ListFields()]
|
||||
raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
|
||||
tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
|
||||
for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
|
||||
if isinstance(onnx_dim, str):
|
||||
onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
|
||||
if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
|
||||
return tensor
|
||||
|
||||
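# Illustrative example (not from the source): if the model declares an input with shape
# ("N", 3) and the first call passes a tensor of shape (4, 3), "N" is bound to 4 in
# variable_dims; a second input declared as ("N", 5) that arrives with shape (2, 5) then
# fails the dim check above, since "N" already resolved to 4.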
def run_onnx(inputs={}, debug=0):
|
||||
debug = getenv("DEBUGONNX") or debug
|
||||
if debug >= 3: print("Model initialization data:\n" + "\n".join(f"\t{i.name} - {model_tensors[i.name]}" for i in onnx_model.graph.initializer))
|
||||
def _dispatch_op(self, op, inps, opts):
|
||||
if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts)
|
||||
if hasattr(self.onnx_ops_module, op):
|
||||
fxn = getattr(self.onnx_ops_module, op)
|
||||
if isinstance(fxn, dict):
|
||||
for k in sorted(fxn.keys()):
|
||||
if k <= self.opset_version:
|
||||
real_fxn = fxn[k]
|
||||
else: real_fxn = fxn
|
||||
return real_fxn(*inps, **opts)
|
||||
raise NotImplementedError(f"{op=} not supported")
|
||||
|
||||
if debug >= 1: print("Model input:")
|
||||
for name, value_info in model_expected_inputs.items():
|
||||
def __call__(self, inputs:dict[str, Any], debug=debug):
|
||||
for name, input_spec in self.graph_inputs.items():
|
||||
if name not in inputs: raise RuntimeError(f"Please provide input data for {name}")
|
||||
model_tensors[name] = prepare_input(inputs[name], value_info)
|
||||
if debug >= 1: print(f"\t{name} - {model_tensors[name]}")
|
||||
if debug >= 2: print(f"\t\t{MessageToDict(value_info.type)}")
|
||||
self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)
|
||||
|
||||
for num,n in enumerate(onnx_model.graph.node):
|
||||
inp_tensors = [model_tensors.get(x) for x in n.input]
|
||||
required_consts = required_input_python_consts.get(n.op_type, ())
|
||||
inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)]
|
||||
opt = model_attributes[num]
|
||||
for node in self.graph_nodes:
|
||||
inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)]
|
||||
opts = node.opts
|
||||
|
||||
if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}")
|
||||
if debug >= 3:
|
||||
print("\tinputs:")
|
||||
print("\n".join(f"\t\t{x} - {t!r}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp))))
|
||||
# provide additional opts
|
||||
if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
|
||||
if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values
|
||||
|
||||
# provide additional arguments
|
||||
if n.op_type == "Split" and 'num_outputs' not in opt: opt['num_outputs'] = len(n.output)
|
||||
if n.op_type == "Gradient": opt['intermediate_tensors'] = model_tensors
|
||||
if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}")
|
||||
if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
|
||||
ret = self._dispatch_op(node.op, inps, opts)
|
||||
ret = ret if isinstance(ret, tuple) else (ret,)
|
||||
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret)))
|
||||
|
||||
# run op
|
||||
if n.op_type in tensor_methods: ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
|
||||
elif hasattr(onnx_ops, n.op_type):
|
||||
fxn = getattr(onnx_ops, n.op_type)
|
||||
if isinstance(fxn, dict):
|
||||
for k in sorted(fxn.keys()):
|
||||
if k <= onnx_model_version:
|
||||
real_fxn = fxn[k]
|
||||
else:
|
||||
real_fxn = fxn
|
||||
ret = real_fxn(*inp, **opt)
|
||||
else:
|
||||
print("UNSUPPORTED", n.op_type, n.input, n.output)
|
||||
raise NotImplementedError(f"op_type {n.op_type} not supported")
|
||||
self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True)))
|
||||
|
||||
# finalization after running the op
|
||||
if not isinstance(ret, tuple): ret = (ret, )
|
||||
if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}")
|
||||
for i in range(len(n.output)): model_tensors[n.output[i]] = ret[i]
|
||||
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))
|
||||
|
||||
if num == ONNXLIMIT: return {name:model_tensors[name] for name in n.output}
|
||||
return {x.name:model_tensors[x.name] for x in onnx_model.graph.output}
|
||||
return run_onnx
|
||||
if node.num == limit:
|
||||
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
|
||||
return {name:self.graph_values[name] for name in node.outputs}
|
||||
Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
|
||||
return {name:self.graph_values[name] for name in self.graph_outputs}
|
||||
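For reference, a minimal sketch of how the new runner is called (the model path, input name and shape below are hypothetical, not taken from the diff):

import onnx
from tinygrad import Tensor
from extra.onnx import OnnxRunner

runner = OnnxRunner(onnx.load("model.onnx"))                    # parse the protobuf once
outputs = runner({"x": Tensor.zeros(1, 3, 224, 224)}, debug=2)  # dict of output name -> value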
@@ -3,7 +3,7 @@ from typing import cast, Literal
|
||||
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.helpers import prod, flatten, make_tuple
|
||||
from extra.onnx import dtype_parse, to_python_const
|
||||
from extra.onnx import dtype_parse, _cached_to_python_const
|
||||
import numpy as np
|
||||
|
||||
# **************** Free Ops ****************
|
||||
@@ -282,7 +282,7 @@ def Gather(x:Tensor, indices:Tensor, axis:int=0):
|
||||
x_sh = list(x.shape)
|
||||
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
|
||||
if indices.ndim > 1: indices = indices.flatten()
|
||||
indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)] # type: ignore
|
||||
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
|
||||
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
|
||||
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
|
||||
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
|
||||
@@ -575,12 +575,9 @@ from tinygrad.nn.optim import SGD
|
||||
def onnx_training(input_group_size):
|
||||
def _decorator(func):
|
||||
def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
|
||||
old_training = Tensor.training
|
||||
Tensor.training = True
|
||||
R = R.detach()
|
||||
groups = len(inputs) // input_group_size
|
||||
ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
|
||||
Tensor.training = old_training
|
||||
return tuple(flatten(zip(*ret)))
|
||||
return __wrapper
|
||||
return _decorator
|
||||
|
||||
@@ -2,7 +2,7 @@ import time, sys, hashlib
|
||||
from pathlib import Path
|
||||
import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from extra.onnx import get_run_onnx
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad import Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
|
||||
from tinygrad.tensor import _from_np_dtype
|
||||
@@ -11,11 +11,8 @@ import numpy as np
|
||||
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
Tensor.training = False
|
||||
|
||||
onnx_model = onnx.load(onnx_path := fetch(OPENPILOT_MODEL))
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
|
||||
Tensor.manual_seed(100)
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
|
||||
39
test/external/external_fuzz_am_interrupts.py
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
import subprocess
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
def run_test(i, full_run=False):
|
||||
print(f"\rRunning iteration {i}...", end=" ", flush=True)
|
||||
|
||||
p = subprocess.Popen(['python3', 'test/test_tiny.py', 'TestTiny.test_plus'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
if not full_run:
|
||||
time.sleep(random.uniform(0, 1200) / 1000)
|
||||
p.kill()
|
||||
_, stderr = p.communicate()
|
||||
else:
|
||||
_, stderr = p.communicate()
|
||||
|
||||
if full_run:
|
||||
stderr_text = stderr.decode()
|
||||
print(stderr_text)
|
||||
assert "Ran 1 test in" in stderr_text and "OK" in stderr_text
|
||||
|
||||
max_workers = 4
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = []
|
||||
for i in range(1000000):
|
||||
if i % 100 == 0:
|
||||
for future in as_completed(futures):
|
||||
try: future.result()
|
||||
except Exception as e:
|
||||
print(f"\nError in iteration: {e}")
|
||||
futures = []
|
||||
|
||||
run_test(i, True)
|
||||
else:
|
||||
future = executor.submit(run_test, i, False)
|
||||
futures.append(future)
|
||||
|
||||
if len(futures) > max_workers * 2: futures = [f for f in futures if not f.done()]
|
||||
6
test/external/external_model_benchmark.py
vendored
@@ -6,7 +6,7 @@ import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
import onnxruntime as ort
|
||||
from onnx2torch import convert
|
||||
from extra.onnx import get_run_onnx
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.helpers import OSX, DEBUG, fetch
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.device import CompileError
|
||||
@@ -65,7 +65,7 @@ def benchmark_model(m, devices, validate_outs=False):
|
||||
try:
|
||||
Device.DEFAULT = device
|
||||
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
|
||||
tinygrad_model = get_run_onnx(onnx_model)
|
||||
tinygrad_model = OnnxRunner(onnx_model)
|
||||
benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()})
|
||||
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
@@ -115,7 +115,7 @@ def benchmark_model(m, devices, validate_outs=False):
|
||||
rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models
|
||||
Device.DEFAULT = device
|
||||
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
|
||||
tinygrad_model = get_run_onnx(onnx_model)
|
||||
tinygrad_model = OnnxRunner(onnx_model)
|
||||
tinygrad_out = tinygrad_model(inputs)
|
||||
|
||||
ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"])
|
||||
|
||||
3
test/external/external_test_am.py
vendored
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext
|
||||
from tinygrad.helpers import mv_address
|
||||
|
||||
class FakeGMC:
|
||||
def __init__(self): self.vm_base = 0x0
|
||||
@@ -19,6 +20,8 @@ class FakeAM:
|
||||
self.gmc = FakeGMC()
|
||||
self.mm = AMMemoryManager(self, vram_size=4 << 30)
|
||||
self.is_booting = False
|
||||
def paddr2cpu(self, paddr:int) -> int: return paddr + mv_address(self.vram)
|
||||
def paddr2mc(self, paddr:int) -> int: return paddr
|
||||
|
||||
# * PTE format:
|
||||
# * 63:59 reserved
|
||||
|
||||
6
test/external/external_test_onnx_backend.py
vendored
@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
|
||||
# pip3 install tabulate
|
||||
pytest_plugins = 'onnx.backend.test.report',
|
||||
|
||||
from extra.onnx import get_run_onnx
|
||||
from extra.onnx import OnnxRunner
|
||||
|
||||
class TinygradModel(BackendRep):
|
||||
def __init__(self, run_onnx, input_names):
|
||||
@@ -20,7 +20,7 @@ class TinygradModel(BackendRep):
|
||||
|
||||
def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
|
||||
real_inputs = dict(zip(self.input_names, inputs))
|
||||
ret = self.fxn(real_inputs, debug=True)
|
||||
ret = self.fxn(real_inputs, debug=2)
|
||||
return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values())
|
||||
|
||||
class TinygradBackend(Backend):
|
||||
@@ -30,7 +30,7 @@ class TinygradBackend(Backend):
|
||||
input_initializer = [x.name for x in model.graph.initializer]
|
||||
net_feed_input = [x for x in input_all if x not in input_initializer]
|
||||
print("prepare", cls, device, net_feed_input)
|
||||
run_onnx = get_run_onnx(model)
|
||||
run_onnx = OnnxRunner(model)
|
||||
return TinygradModel(run_onnx, net_feed_input)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -7,7 +7,7 @@ try:
|
||||
import onnx
|
||||
except ModuleNotFoundError:
|
||||
raise unittest.SkipTest("onnx not installed, skipping onnx test")
|
||||
from extra.onnx import get_run_onnx
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI, fetch, temp
|
||||
|
||||
@@ -26,7 +26,7 @@ np.random.seed(1337)
|
||||
class TestOnnxModel(unittest.TestCase):
|
||||
def test_benchmark_openpilot_model(self):
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
def get_inputs():
|
||||
np_inputs = {
|
||||
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
|
||||
@@ -70,7 +70,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
|
||||
def test_openpilot_model(self):
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
print("got run_onnx")
|
||||
inputs = {
|
||||
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
|
||||
@@ -124,7 +124,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
onnx_model = onnx.load(fn)
|
||||
print("onnx loaded")
|
||||
from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
|
||||
def run(img):
|
||||
inputs = {input_name: preprocess(img, new=input_new)}
|
||||
|
||||
@@ -166,7 +166,7 @@ class TestIndexing(unittest.TestCase):
|
||||
GlobalCounters.reset()
|
||||
z = emb(x).realize()
|
||||
self.assertLessEqual(GlobalCounters.global_ops, op_limit)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 3)
|
||||
if getenv("CHECK", 1):
|
||||
import torch
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -220,7 +220,9 @@ class TestMultiConstFolding(unittest.TestCase):
|
||||
t = Tensor.arange(16).float().realize().to(ds)
|
||||
|
||||
# non const folding case creates one ast on each shard
|
||||
_check_ast_count(4, t + 1)
|
||||
# NOTE: there's extra contiguous kernels here since it's realizing both the CONTIGUOUS and its parent COPY
|
||||
# why does multi call contiguous on a COPY?
|
||||
_check_ast_count(7, t + 1)
|
||||
_check_ast_count(4, 1 + t)
|
||||
_check_ast_count(4, t * 2)
|
||||
_check_ast_count(4, 2 * t)
|
||||
|
||||
@@ -113,6 +113,8 @@ class TestImageDType(unittest.TestCase):
|
||||
assert it.lazydata.base.realized._buf != b1
|
||||
|
||||
# issue caused by: don't realize image to image casts. this is part of a larger problem
|
||||
#@unittest.expectedFailure
|
||||
# update: passing after tensor_map
|
||||
def test_lil_model(self):
|
||||
with Context(IMAGE=2):
|
||||
x = Tensor.zeros(1, 1)
|
||||
@@ -121,7 +123,10 @@ class TestImageDType(unittest.TestCase):
|
||||
loss = x.image_dot(w1).image_dot(w2).float().max()
|
||||
loss.backward()
|
||||
sched = unwrap(w1.grad).schedule()
|
||||
self.assertEqual(len(sched), 9)
|
||||
# NOTE: the w1 grad must realize to a separate kernel
|
||||
assert w1.grad.lazydata.is_realized, f"never realized {w1.grad}"
|
||||
self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32)
|
||||
self.assertEqual(len(sched), 10)
|
||||
for s,ei in zip(sched, lower_schedule(sched[:])):
|
||||
ei.run()
|
||||
if s.outputs[0].dtype == dtypes.float:
|
||||
|
||||
@@ -318,6 +318,7 @@ class TestJit(unittest.TestCase):
|
||||
assert len(res3) == 10, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
@unittest.expectedFailure # requires contiguous folding
|
||||
def test_jit_random_after_unrealized_random(self):
|
||||
@TinyJit
|
||||
def f(): return Tensor.rand()
|
||||
|
||||
@@ -63,7 +63,11 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
|
||||
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
def test_arg_dedup(self):
|
||||
a, b = Tensor.randn(4), Tensor.randn(4)
|
||||
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
|
||||
# without contiguous folding, rand.to("CLANG") and rand.contiguous().to("CLANG") are different UOps.
|
||||
# this test asserts they are the identical Buffer
|
||||
# having different buffers is fine for correctness, because the outputs match.
|
||||
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
|
||||
lowered = list(lower_schedule(c.schedule()))
|
||||
@@ -1690,6 +1694,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
# should upcast the two Tensor.stacks
|
||||
assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
|
||||
|
||||
@unittest.expectedFailure # requires contiguous folding
|
||||
def test_masked_upcast_wino_full(self):
|
||||
with Context(WINO=1):
|
||||
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
|
||||
|
||||
@@ -734,7 +734,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
zeros = Tensor.zeros(3).realize()
|
||||
b = a.to(devices_2)*zeros.to(devices_2)
|
||||
sched = b.schedule()
|
||||
self.assertEqual(len(sched), 6)
|
||||
self.assertEqual(len(sched), 8)
|
||||
# notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy
|
||||
self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2)
|
||||
# all these kernels are just because multi calls contiguous on every single shard
|
||||
|
||||
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
def test_basic_binop_fusion(self):
|
||||
@@ -220,7 +220,7 @@ class TestSchedule(unittest.TestCase):
|
||||
GlobalCounters.reset()
|
||||
expr = (a*b)/b
|
||||
expr.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now!
|
||||
self.assertEqual(GlobalCounters.global_ops, 0)
|
||||
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
|
||||
|
||||
@@ -229,7 +229,7 @@ class TestSchedule(unittest.TestCase):
|
||||
GlobalCounters.reset()
|
||||
expr = a/a
|
||||
expr.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
self.assertEqual(GlobalCounters.global_ops, 0)
|
||||
np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
|
||||
|
||||
@@ -972,6 +972,26 @@ class TestSchedule(unittest.TestCase):
|
||||
expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_softmax_upcast(self):
|
||||
# input half, softmax in float
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
|
||||
out = x.softmax(dtype=dtypes.float)
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 3)
|
||||
self.assertEqual(len(sched[0].outputs), 1)
|
||||
self.assertEqual(sched[0].outputs[0].dtype, dtypes.half)
|
||||
|
||||
# input float, softmax in float
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize()
|
||||
out = x.softmax(dtype=dtypes.float)
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 3)
|
||||
self.assertEqual(len(sched[0].outputs), 1)
|
||||
self.assertEqual(sched[0].outputs[0].dtype, dtypes.float)
|
||||
|
||||
def test_softmax_backward(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize()
|
||||
@@ -1804,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.
|
||||
|
||||
# these pattern matchers should move to engine/schedule.py
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
ops_folding = symbolic_simple+PatternMatcher([
|
||||
(UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
|
||||
])
|
||||
|
||||
@@ -1822,8 +1842,8 @@ def run_tensor_ast(r:Tensor):
|
||||
output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
|
||||
sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
|
||||
sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
|
||||
sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
|
||||
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
|
||||
run_schedule([si])
|
||||
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
|
||||
@@ -2184,7 +2204,7 @@ class TestConst(unittest.TestCase):
|
||||
sched = add.schedule()
|
||||
self.assertEqual(len(sched), 0)
|
||||
# b+0 and b share the same underlying device memory
|
||||
self.assertIs(add.lazydata.realized, b.lazydata.realized)
|
||||
self.assertIs(add.lazydata.buffer, b.lazydata.buffer)
|
||||
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
|
||||
|
||||
def test_src_masked_const_folding(self):
|
||||
@@ -2238,6 +2258,17 @@ class TestCopyFolding(unittest.TestCase):
|
||||
add = schedule_graph_rewrite(add)
|
||||
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
|
||||
|
||||
def test_copy_to_same_device(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.copy_to_device(a.device)
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a)
|
||||
|
||||
def test_clone(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
check_schedule(a.clone(), 1, filter_sink=False)
|
||||
|
||||
class TestTensorUOpSpec(unittest.TestCase):
|
||||
def test_const_must_be_unmasked(self):
|
||||
a = Tensor.ones((4, 4)).pad((2, 2))
|
||||
@@ -2253,6 +2284,12 @@ class TestTensorUOpSpec(unittest.TestCase):
|
||||
t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views)
|
||||
create_schedule_with_vars(t)
|
||||
|
||||
def test_symbolic_shape_ok(self):
|
||||
a = Tensor.ones(4)
|
||||
vi = UOp.variable("i", 1, 10).bind(4)
|
||||
t = graph_rewrite(a.reshape(vi).sum().lazydata, remove_movement_ops+merge_views)
|
||||
create_schedule_with_vars(t)
|
||||
|
||||
class TestBufferUOp(unittest.TestCase):
|
||||
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
|
||||
def test_buffer_has_buffer(self):
|
||||
@@ -2316,34 +2353,80 @@ class TestBufferUOp(unittest.TestCase):
|
||||
|
||||
class TestContiguous(unittest.TestCase):
|
||||
def test_contiguous_buffer(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a)
|
||||
a = Tensor.empty(4)
|
||||
b = a.contiguous()
|
||||
check_schedule(b, 0)
|
||||
|
||||
def test_contiguous_buffer_view(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
|
||||
a = Tensor.empty(4)
|
||||
b = a.reshape((2, 2)).contiguous()
|
||||
check_schedule(b, 0)
|
||||
|
||||
def test_non_contiguous_buffer_view(self):
|
||||
a = Tensor.empty(4, 1).lazydata
|
||||
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand((4, 4)).contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_size_change_buffer_view(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4)
|
||||
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
def test_double_contiguous_realizes_once(self):
|
||||
a = Tensor.empty(4, 1).lazydata
|
||||
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS)
|
||||
b = schedule_graph_rewrite(b)
|
||||
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand((4, 4)).contiguous().contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
|
||||
class TestUOpBecome(unittest.TestCase):
|
||||
# the simplest case, if we create a new BUFFER for this UOp
|
||||
def test_new_buffer(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
add = a+b
|
||||
check_schedule(add, 1)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
|
||||
|
||||
def test_new_buffer_view(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
add = (a+b).reshape(8, 2)
|
||||
check_schedule(add, 1)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
|
||||
# VIEW is preserved after the becomes rewrite.
|
||||
self.assertEqual(add.lazydata.shape, (8, 2))
|
||||
assert add.lazydata is not add.lazydata.base
|
||||
|
||||
def test_become_existing_buffer(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = a*1
|
||||
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
|
||||
check_schedule(b, 0)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
|
||||
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
def test_become_const_in_base(self):
|
||||
a = Tensor.empty(4)
|
||||
b = a*0
|
||||
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
|
||||
check_schedule(b, 0)
|
||||
assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
|
||||
|
||||
def test_become_const_in_view(self):
|
||||
# if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged.
|
||||
add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
|
||||
b = add.shrink(((0, 1), (0, 0)))
|
||||
check_schedule(b, 0)
|
||||
assert UPat(Ops.CONST, arg=0).match(b.lazydata, {})
|
||||
self.assertEqual(b.shape, (1, 0))
|
||||
# the base is untouched.
|
||||
assert UPat(Ops.ADD).match(add.lazydata, {})
|
||||
|
||||
def test_become_const_from_const(self):
|
||||
const_add = Tensor(1)+Tensor(2)
|
||||
assert UPat(Ops.ADD).match(const_add.lazydata, {})
|
||||
check_schedule(const_add, 0)
|
||||
assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -69,7 +69,8 @@ class TestSetitem(unittest.TestCase):
|
||||
t[1] ^= 5
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
#@unittest.expectedFailure
|
||||
# update: passing after delete_forced_realize
|
||||
def test_setitem_consecutive_inplace_operator(self):
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] += 2
|
||||
|
||||
@@ -2,6 +2,7 @@ import unittest
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.ops import view_supported_devices
|
||||
from tinygrad.helpers import Context
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
|
||||
class TestSubBuffer(unittest.TestCase):
|
||||
@@ -47,5 +48,22 @@ class TestSubBuffer(unittest.TestCase):
|
||||
out = vt.to(f"{Device.DEFAULT}:1").realize().tolist()
|
||||
assert out == [2, 3, 4]
|
||||
|
||||
def test_subbuffer_deallocate(self):
|
||||
with Context(LRU=0):
|
||||
vbuf = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated()
|
||||
self.buf.deallocate()
|
||||
vbuf.deallocate()
|
||||
|
||||
# Allocate a fake one on the same place
|
||||
_ = Buffer(Device.DEFAULT, 10, dtypes.uint8).ensure_allocated()
|
||||
|
||||
self.buf.ensure_allocated()
|
||||
self.buf.copyin(memoryview(bytearray(range(10, 20))))
|
||||
|
||||
vbuf.ensure_allocated()
|
||||
|
||||
tst = vbuf.as_buffer().tolist()
|
||||
assert tst == [13, 14]
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -14,6 +14,7 @@ from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_ker
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.rewriter import full_graph_rewrite, sym
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
|
||||
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
|
||||
|
||||
@@ -365,6 +366,17 @@ class TestAssembly(unittest.TestCase):
|
||||
self.assertIn(Ops.SHR, ops)
|
||||
self.assertIn(Ops.IDIV, ops)
|
||||
|
||||
def test_mulacc_unrolled(self):
|
||||
# test that acc = acc + a0*b0 + a1*b1 + a2*b2 + a3*b3
|
||||
# is not acc = acc + (a0*b0 + a1*b1 + a2*b2 + a3*b3)
|
||||
a = Tensor.empty(1024)
|
||||
b = Tensor.empty(1024)
|
||||
c = (a*b).sum()
|
||||
k = Kernel(c.schedule()[-1].ast)
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
uops = k.linearize().uops
|
||||
self.assertEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4)
|
||||
|
||||
class TestUOpMethod(unittest.TestCase):
|
||||
@unittest.skip("uops lt no longer ordered")
|
||||
def test_compare_alu_same_src_different_arg(self):
|
||||
|
||||
@@ -164,6 +164,7 @@ class TestSafetensors(unittest.TestCase):
|
||||
def test_save_all_dtypes(self):
|
||||
for dtype in dtypes.fields().values():
|
||||
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
|
||||
if dtype in [dtypes.double] and Device.DEFAULT == "METAL": continue # not supported on METAL
path = temp(f"ones.{dtype}.safetensors")
ones = Tensor(np.random.rand(10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)

@@ -104,7 +104,8 @@ class TestRealizeMeansRealize(unittest.TestCase):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
self.assertEqual(x.lazydata.op, Ops.VIEW)

@unittest.expectedFailure
#@unittest.expectedFailure
# update: passing after delete_forced_realize
def test_uniform_realizes(self):
x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
print(x.lazydata)

@@ -2,7 +2,7 @@ import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.rewriter import full_graph_rewrite
from tinygrad.codegen.rewriter import full_graph_rewrite, mulacc_unrolled

# Helper function to apply the graph rewrite
def apply_rewrite(expr):
@@ -274,5 +274,41 @@ class TestSubstitute(unittest.TestCase):
ret = substitute(ret, {a.sin():a.sqrt(), n1.sin():n1.sqrt()})
self.assertIs(ret, a.sqrt().sqrt())

class TestMulaccUnrolledAcc(unittest.TestCase):
def test_unrolled2(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))
acc = UOp(Ops.DEFINE_ACC, dtypes.int, (UOp.const(dtypes.int, 0),) + acc_range, (0,))
a = UOp.variable('a', 0, 10)
b = UOp.variable('b', 0, 10)
expr = acc.assign(acc + (a*2 + b*3))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)
self.assertIs(expr_with_mulacc, acc.assign(acc + a*2 + b*3))

def test_unrolled4_float(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3))
acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,))

a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]
b = [UOp.variable(f'b{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]

expr = acc.assign(acc + (a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)

# Verify it unrolls into individual multiply-accumulate operations
expected = acc.assign(acc + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3])
self.assertIs(expr_with_mulacc, expected)

def test_unrolled4_float_const(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3))
acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,))

a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]
expr = acc.assign(acc + (a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)

# Verify it unrolls into individual multiply-accumulate operations
expected = acc.assign(acc + a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0)
self.assertIs(expr_with_mulacc, expected)

if __name__ == '__main__':
unittest.main()

@@ -53,7 +53,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
b = Tensor([4.,5,6]).realize()
c = a+b
print(c.lazydata)
is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
#is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))

def test_const_pattern(self):
a = Tensor(1)
@@ -71,9 +72,9 @@ class TestTensorUopRepresentation(unittest.TestCase):
def test_viewed_consts_do_not_realize(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
pre_realize = a.lazydata
a.realize()
assert a.lazydata is pre_realize
is_pattern(a, const_pattern)
self.assertEqual(a.lazydata.shape, (10, 10))

# currently, CONSTs have a "fake" BUFFER. this should be fixed
# current:
@@ -111,7 +112,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
c = a.to("TEST") # NOTE: this isn't checked
print(c.lazydata)
# TODO: COPY on a Tensor becomes a VIEW(COPY), this should be done in the scheduler not in ops
is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
#is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))

if __name__ == '__main__':
unittest.main()

@@ -239,6 +239,9 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld")
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))

# this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])
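# A minimal sketch of this rewrite's effect (illustrative only, mirroring the tests above):
#   before: acc.assign(acc + (a0*b0 + a1*b1 + a2*b2 + a3*b3))
#   after:  acc.assign(acc + a0*b0 + a1*b1 + a2*b2 + a3*b3)   # i.e. ((acc + a0*b0) + a1*b1) + ...
# each add now has the running accumulator as one operand, so the backend can render it as a mulacc.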

# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
@@ -504,8 +507,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# expand
sink = graph_rewrite(sink, sym+expander)

# devectorize + load_store_indexing
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
# devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
mulacc_unrolled)

# final rules for the renderer (without sym)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)

@@ -4,7 +4,7 @@ from collections import defaultdict
from typing import Optional, Any, Iterator, Generator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
cpu_time_execution
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer
@@ -129,7 +129,7 @@ class Buffer:
if self._base is None and (self.options is None or self.options.external_ptr is None):
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
del self._buf
del self._buf
def __reduce__(self):
buf = None
if self._base is not None:
@@ -202,7 +202,7 @@ class LRUAllocator(Allocator):
for opaque in opaques: super().free(opaque, sz, options)
opaques.clear()
def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
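# Hypothetical usage sketch (not part of this diff): with LRU as a ContextVar rather than a raw
# getenv, buffer caching can be toggled per-scope as well as via the environment, e.g.
#   from tinygrad.helpers import Context
#   with Context(LRU=0): ...   # frees inside this scope bypass the allocator cache
# while LRU=0 in the environment keeps behaving like the old getenv("LRU", 1) check.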
else: super().free(opaque, size, options)

class _MallocAllocator(LRUAllocator):
@@ -310,7 +310,7 @@ if PROFILE:
for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize()

with open(temp("profile.pkl"), "wb") as f: pickle.dump(Compiled.profile_events, f)
with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f)

from tinygrad.ops import launch_viz
launch_viz("PROFILE", temp("profile.pkl"))
launch_viz("PROFILE", fn)

@@ -2,7 +2,7 @@ import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
from tinygrad.dtype import DType, ImageDType, dtypes
@@ -31,9 +31,9 @@ tensor_uop_spec = PatternMatcher([
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
(UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, arg=ShapeTracker.from_shape(()))), arg=None), lambda: True),

# Tensor const has an unmasked ShapeTracker of stride 0 and a device
# Tensor const has a device and an unmasked ShapeTracker of stride 0 or a ShapeTracker with symbolic shape
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
lambda st: len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides) and st.st.views[0].mask is None),
lambda st: st.st.views[0].mask is None and ((len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)) or not all_int(st.shape))),

# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
@@ -88,15 +88,15 @@ class ScheduleContext:

# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# SINK is passthrough
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
# VIEW is passthrough
if buf is not buf.base:
cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(unwrap(buf.st))
cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
return ret
# make things that can't be images not images
dtype = buf.dtype
@@ -105,11 +105,11 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
dtype = buf.dtype.base
# ASSIGN already has a target buffer, otherwise we create a new one
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# track the underlying tensor uop for this buffer
ctx.tensor_uops[buf_uop] = [buf]
ctx.tensor_uops[buf_uop] = tensor_map[buf]
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret

# **** AST graph rewrite
@@ -242,7 +242,7 @@ if CAPTURE_PROCESS_REPLAY:
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
def uval(u:UOp) -> UOp:
assert is_scheduled(u), f"must be a scheduled op {u}"
return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r
return u.src[1]

def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
@@ -329,7 +329,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
# maybe fuse arange with its children
for rbuf in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue
if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
if len(kernel_children) == 0: continue
for tr in group: del ctx.realizes[tr]
@@ -340,10 +340,6 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:

# **** Schedule creation and BFS toposort

class UPatScheduled(UPat):
def __init__(self, *args, **kwargs):
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))

# ** this is schedule level const folding

def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
@@ -358,22 +354,18 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
case _: return None
return reduce.const_like(ret)

def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
old_base = contig.src[0].src[0]
if old_base.op is Ops.VIEW and (sti:=unwrap(contig.src[0].st).invert(old_base.shape)) is not None: ctx.contiguous[old_base] = base.view(sti)
def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))

ops_folding = symbolic_simple+PatternMatcher([
# op with size 0 is zero
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# if the uop folded to a CONST we can delete the BUFFER
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
# DETACH is a NOOP here
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
# reduce of size 0 is the identity element
@@ -386,13 +378,16 @@ ops_folding = symbolic_simple+PatternMatcher([
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# double contiguous is one contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER/VIEW from SINK
(UPat(Ops.SINK, name="root"),
@@ -400,36 +395,12 @@ ops_folding = symbolic_simple+PatternMatcher([
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])

# ** buffer merging

def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}"
# if b2 is realized also realize b1
if b2 in ctx.realizes:
ctx.realizes[b1] = b1
del ctx.realizes[b2]
# ops referring to b2 now ref to b1
ctx.tensor_uops[b1] += ctx.tensor_uops[b2]
del ctx.tensor_uops[b2]
# merge
return v1

def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
# early become
for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): ctx.becomes_map[luop] = b1.view(unwrap(luop.st))
return v1

merge_bufs = PatternMatcher([
# merge base
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))), merge),
(UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))), merge_realized),
# collapse view
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))), lambda mv:mv),
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))), lambda mv:mv),
])

# ** this decides which ops get realized

class UPatScheduled(UPat):
def __init__(self, *args, **kwargs):
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))

def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store

def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
@@ -448,8 +419,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs)
return x.view(unwrap(view.st))

def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
if not root.device.startswith("DISK"): return None
if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone
if not b.device.startswith("DISK"): return None
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))

@@ -482,7 +452,7 @@ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))

def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
if b not in ctx.realizes: return x # collapse BUFFER
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
@@ -523,28 +493,36 @@ remove_movement_ops = PatternMatcher([

@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
# if using VIZ, do a graph rewrite to vizualize the Tensor graph
if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
# to_uop is removing (many) of the movement ops
sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={})
# const folding and fusion
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx)
sink = graph_rewrite(sink, merge_bufs, ctx)
# create the scheduler context
graph_rewrite(sink, create_ctx, ctx)
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
rev_tensor_map: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
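# Sketch of what these two maps hold (names as in the code above, shapes of the values illustrative):
#   tensor_map     : {original_tensor_uop: rewritten_uop}         from graph_rewrite_map
#   rev_tensor_map : {rewritten_uop: [original_tensor_uop, ...]}  several tensors may simplify to one node
# rev_tensor_map is what add_buffers uses to keep ctx.tensor_uops pointing at all of those tensors.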
# add BUFFER uops
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
# add realizes
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
# group realizes into kernels
store_groups = group_realizes(ctx)
graph_rewrite(sink, break_sched, ctx)
# preschedule realize groups
prescheduled: list[ScheduleItem] = []
for store_uops in store_groups:
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
prescheduled.append(schedule_uop(small_sink, ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))

# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
for k,v in tensor_map.items():
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: ctx.becomes_map[k] = v

# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)

@@ -78,7 +78,8 @@ def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+
def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
@functools.lru_cache(maxsize=None)
def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
def temp(x:str, append_user:bool=False) -> str:
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{os.getenv('USERNAME', os.getlogin())}" if append_user else x)).as_posix()
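# Illustrative sketch of the new helper (paths and username below are assumptions, not from the diff):
#   temp("profile.pkl")                    -> /tmp/profile.pkl           (shared, as before)
#   temp("profile.pkl", append_user=True)  -> /tmp/profile.pkl.alice     (one file per user)
# suffixing the username keeps the profile and rewrite pickles from colliding on shared machines.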

class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs
@@ -107,7 +108,7 @@ WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1),
USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ"))
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)

@dataclass(frozen=True)
class Metadata:

@@ -233,7 +233,6 @@ class UOpMetaClass(type):
# some uops map to other stuff
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()

# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
@@ -409,11 +408,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self):
if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST:
return self.alu(Ops.CONTIGUOUS)
forced_realize.add(self.base)
return self
def contiguous(self): return self.alu(Ops.CONTIGUOUS)

# *** from LazyBuffer ***

@@ -432,19 +427,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# otherwise it's just a VIEW(BUFFER)
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
def copy_to_device(self, device:str, clone:bool=False) -> UOp:
# no COPY
if self.device == device and not clone: return self
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
# COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
return UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone).view(unwrap(self.st))
ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
op_arg = []
mop = self
while mop is not self.base:
op_arg.append((mop.op, mop.arg))
mop = mop.src[0]
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
return ret
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
@property
def lbs(self): return [self]
@property
def metadata(self): return all_metadata.get(self, None)
@property
def forced_realize(self): return self in forced_realize

# *** uop movement ops ***

@@ -822,10 +820,10 @@ if TRACK_MATCH_STATS:
@atexit.register
def print_match_stats():
if TRACK_MATCH_STATS >= 2:
with open(fn:=temp("rewrites.pkl"), "wb") as f:
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl"))
if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if getenv("PRINT_MATCH_STATS", 1):
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):

@@ -1,5 +1,5 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
from tinygrad.runtime.support.allocator import TLSFAllocator
@@ -98,20 +98,14 @@ class AMFirmware:

def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size])

class AMPhysicalMemoryBlock:
def __init__(self, adev:AMDev, paddr:int, size:int): self.adev, self.paddr, self.size = adev, paddr, size
def mc_addr(self): return self.adev.gmc.mc_base + self.paddr
def cpu_addr(self): return mv_address(self.adev.vram) + self.paddr
def cpu_view(self): return to_mv(self.cpu_addr(), self.size)

@dataclasses.dataclass(frozen=True)
class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702

class AMPageTableEntry:
def __init__(self, pm, lv): self.pm, self.view, self.lv = pm, pm.cpu_view().cast('Q'), lv
def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv

def set_table(self, entry_id, pte:AMPageTableEntry, valid=True):
self.view[entry_id] = (pte.pm.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0)
self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0)

def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True):
f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \
@@ -133,11 +127,11 @@ class AMPageTableTraverseContext:
def level_down(self):
pt, pte_idx, _ = self.pt_stack[-1]
if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID:
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.pm.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(AMPhysicalMemoryBlock(pt.pm.adev, entry & 0x0000FFFFFFFFF000, 0x1000), lv=pt.lv+1)
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
else:
assert self.create_pts, "Not allowed to create new page table"
pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1))
pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1))

self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
return self.pt_stack[-1]
@@ -145,7 +139,7 @@
def _try_free_pt(self) -> bool:
pt, _, _ = self.pt_stack[-1]
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
self.adev.mm.pfree(AMPhysicalMemoryBlock(self.adev, pt.pm.paddr, 0x1000))
self.adev.mm.pfree(pt.paddr)
parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
parent_pt.set_page(parent_pte_idx, 0x0, valid=False)
return True
@@ -179,7 +173,7 @@ class AMMemoryManager:
self.adev, self.vram_size = adev, vram_size
self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
self.root_page_table = AMPageTableEntry(self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)

def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
@@ -213,12 +207,12 @@ class AMMemoryManager:
# Alloc physical memory and map it to the virtual address
va = self.alloc_vaddr(size, align)

if contigous: paddrs = [(self.palloc(size, zero=True).paddr, size)]
if contigous: paddrs = [(self.palloc(size, zero=True), size)]
else:
paddrs = []
try:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False).paddr, seg_size) for _ in range(seg_cnt)]
for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)]
except MemoryError:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise
@@ -230,13 +224,13 @@ class AMMemoryManager:
self.va_allocator.free(vm.va_addr)
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)

def palloc(self, size, align=0x1000, zero=True, boot=False) -> AMPhysicalMemoryBlock:
def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
pm = AMPhysicalMemoryBlock(self.adev, (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align), size)
if zero: ctypes.memset(pm.cpu_addr(), 0, pm.size)
return pm
paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size)
return paddr

def pfree(self, pm:AMPhysicalMemoryBlock): self.pa_allocator.free(pm.paddr)
def pfree(self, paddr:int): self.pa_allocator.free(paddr)

class AMDev:
def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
@@ -285,13 +279,10 @@ class AMDev:
self.partial_boot = False

if not self.partial_boot:
try: # do not interrupt the boot process
signal.signal(signal.SIGINT, signal.SIG_IGN)
if self.psp.is_sos_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
finally: signal.signal(signal.SIGINT, signal.default_int_handler)
if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")

# Booting done
self.is_booting = False
@@ -309,6 +300,7 @@ class AMDev:
for ip in [self.sdma, self.gfx]: ip.fini()

def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
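# Rough old-to-new mapping now that AMPhysicalMemoryBlock is gone (illustrative; pm is the old-style block):
#   pm.cpu_addr()  ->  adev.paddr2cpu(paddr)              # CPU pointer into the mapped VRAM
#   pm.mc_addr()   ->  adev.paddr2mc(paddr)               # GPU-side (memory controller) address
#   pm.cpu_view()  ->  to_mv(adev.paddr2cpu(paddr), size)
# callers now carry plain paddr ints plus the sizes they allocated instead of block objects.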

def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]

@@ -337,8 +329,8 @@ class AMDev:
self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)

def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff) -> int:
for _ in range(10000):
def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
for _ in range(timeout):
if ((rval:=reg.read()) & mask) == value: return rval
time.sleep(0.001)
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
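# timeout counts 1 ms polling iterations, so the default of 10000 keeps the previous ~10 s bound.
# Hedged usage sketch, mirroring the is_smu_alive() probe added further down:
#   adev.wait_reg(adev.mmMP1_SMN_C2PMSG_90, value=1, timeout=100)   # give up after ~100 ms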
@@ -348,9 +340,8 @@ class AMDev:
# The table is located at the end of VRAM - 64KB and is 10KB in size.
mmRCC_CONFIG_MEMSIZE = 0xde3
self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20
self.discovery_pm = AMPhysicalMemoryBlock(self, self.vram_size - (64 << 10), 10 << 10)

bhdr = am.struct_binary_header.from_address(self.discovery_pm.cpu_addr())
bhdr = am.struct_binary_header.from_address(self.paddr2cpu(self.vram_size - (64 << 10)))
ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset)
assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X}"

@@ -1,4 +1,4 @@
import ctypes, time
import ctypes, time, contextlib
from typing import Literal
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
@@ -25,8 +25,8 @@ class AM_GMC(AM_IP):
self.vm_base = self.adev.mm.va_allocator.base
self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1

self.memscratch_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.dummy_page_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.hub_initted = {"MM": False, "GC": False}

def init(self): self.init_hub("MM")
@@ -55,7 +55,7 @@ class AM_GMC(AM_IP):
def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid):
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.pm.paddr | 1)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1)
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv))

def init_hub(self, ip:Literal["MM", "GC"]):
@@ -66,8 +66,8 @@ class AM_GMC(AM_IP):

self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18)
self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18)
self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_pm.paddr >> 12)
self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_pm.paddr >> 12)
self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12)
self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12)

self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1)

@@ -106,22 +106,26 @@ class AM_SMU(AM_IP):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck, poll=True)
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck, poll=True)

def is_smu_alive(self):
with contextlib.suppress(RuntimeError): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0

def mode1_reset(self):
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
time.sleep(0.5) # 500ms

def _smu_cmn_poll_stat(self): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1)
def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
def _smu_cmn_send_msg(self, msg, param=0):
self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
self.adev.mmMP1_SMN_C2PMSG_82.write(param)
self.adev.mmMP1_SMN_C2PMSG_66.write(msg)

def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False):
if poll: self._smu_cmn_poll_stat()
def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s
if poll: self._smu_cmn_poll_stat(timeout=timeout)

self._smu_cmn_send_msg(msg, param)
self._smu_cmn_poll_stat()
self._smu_cmn_poll_stat(timeout=timeout)
return self.adev.rreg(self.adev.mmMP1_SMN_C2PMSG_82) if read_back_arg else None

class AM_GFX(AM_IP):
@@ -232,27 +236,28 @@ class AM_GFX(AM_IP):
class AM_IH(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.rings = [(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]
self.ring_size = 512 << 10
self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]

def interrupt_handler(self):
ring_vm, rwptr_vm, suf, _ = self.rings[0]
wptr = to_mv(rwptr_vm.cpu_addr(), 8).cast('Q')[0]
_, rwptr_vm, suf, _ = self.rings[0]
wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]

if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)
self.adev.regIH_RB_RPTR.write(wptr % ring_vm.size)
self.adev.regIH_RB_RPTR.write(wptr % self.ring_size)

def init(self):
for ring_vm, rwptr_vm, suf, ring_id in self.rings:
self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", ring_vm.mc_addr() >> 8)
self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", self.adev.paddr2mc(ring_vm) >> 8)

self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(ring_vm.size//4).bit_length(),
self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(self.ring_size//4).bit_length(),
mc_snoop=1, mc_ro=0, mc_vmid=0, **({'wptr_overflow_enable': 1, 'rptr_rearm': 1} if ring_id == 0 else {'rb_full_drain_enable': 1}))

if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", rwptr_vm.mc_addr())
if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", self.adev.paddr2mc(rwptr_vm))

self.adev.reg(f"regIH_RB_WPTR{suf}").write(0)
self.adev.reg(f"regIH_RB_RPTR{suf}").write(0)
@@ -303,10 +308,12 @@ class AM_PSP(AM_IP):
def __init__(self, adev):
super().__init__(adev)

self.msg1_pm = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True)
self.cmd_pm = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
self.fence_pm = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
self.ring_pm = self.adev.mm.palloc(0x10000, zero=not self.adev.partial_boot, boot=True)
self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True)
self.cmd_paddr = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)
self.fence_paddr = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True)

self.ring_size = 0x10000
self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True)

def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0
def init(self):
@@ -316,8 +323,9 @@ class AM_PSP(AM_IP):
(am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV),
(am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)]

for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
while not self.is_sos_alive(): time.sleep(0.01)
if not self.is_sos_alive():
for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid)
while not self.is_sos_alive(): time.sleep(0.01)

self._ring_create()
self._tmr_init()
@@ -332,8 +340,8 @@ class AM_PSP(AM_IP):
def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000)

def _prep_msg1(self, data):
ctypes.memset(self.msg1_pm.cpu_addr(), 0, self.msg1_pm.size)
self.msg1_pm.cpu_view()[:len(data)] = data
ctypes.memset(cpu_addr:=self.adev.paddr2cpu(self.msg1_paddr), 0, am.PSP_1_MEG)
to_mv(cpu_addr, len(data))[:] = data
self.adev.gmc.flush_hdp()

def _bootloader_load_component(self, fw, compid):
@@ -342,7 +350,7 @@ class AM_PSP(AM_IP):
self._wait_for_bootloader()

self._prep_msg1(self.adev.fw.sos_fw[fw])
self.adev.regMP0_SMN_C2PMSG_36.write(self.msg1_pm.mc_addr() >> 20)
self.adev.regMP0_SMN_C2PMSG_36.write(self.adev.paddr2mc(self.msg1_paddr) >> 20)
self.adev.regMP0_SMN_C2PMSG_35.write(compid)

return self._wait_for_bootloader()
@@ -350,16 +358,22 @@ class AM_PSP(AM_IP):
def _tmr_init(self):
# Load TOC and calculate TMR size
self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC])
resp = self._load_toc_cmd(len(fwm))

self.tmr_pm = self.adev.mm.palloc(resp.resp.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size
self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)

def _ring_create(self):
# If the ring is already created, destroy it
if self.adev.regMP0_SMN_C2PMSG_71.read() != 0:
self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS)

# There might be handshake issue with hardware which needs delay
time.sleep(0.02)

# Wait until the sOS is ready
self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000)

self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.ring_pm.mc_addr())
self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_pm.size)
self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.adev.paddr2mc(self.ring_paddr))
self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_size)
self.adev.regMP0_SMN_C2PMSG_64.write(am.PSP_RING_TYPE__KM << 16)

# There might be handshake issue with hardware which needs delay
@@ -369,28 +383,28 @@ class AM_PSP(AM_IP):

def _ring_submit(self):
prev_wptr = self.adev.regMP0_SMN_C2PMSG_67.read()
ring_entry_addr = self.ring_pm.cpu_addr() + prev_wptr * 4
ring_entry_addr = self.adev.paddr2cpu(self.ring_paddr) + prev_wptr * 4

ctypes.memset(ring_entry_addr, 0, ctypes.sizeof(am.struct_psp_gfx_rb_frame))
write_loc = am.struct_psp_gfx_rb_frame.from_address(ring_entry_addr)
write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.cmd_pm.mc_addr())
write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.fence_pm.mc_addr())
write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.adev.paddr2mc(self.cmd_paddr))
write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.adev.paddr2mc(self.fence_paddr))
write_loc.fence_value = prev_wptr

# Move the wptr
self.adev.regMP0_SMN_C2PMSG_67.write(prev_wptr + ctypes.sizeof(am.struct_psp_gfx_rb_frame) // 4)

while self.fence_pm.cpu_view().cast('I')[0] != prev_wptr: pass
while to_mv(self.adev.paddr2cpu(self.fence_paddr), 4).cast('I')[0] != prev_wptr: pass
time.sleep(0.005)

resp = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr())
resp = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
if resp.resp.status != 0: raise RuntimeError(f"PSP command failed {resp.cmd_id} {resp.resp.status}")

return resp

def _prep_ring_cmd(self, hdr):
ctypes.memset(self.cmd_pm.cpu_addr(), 0, 0x1000)
cmd = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr())
ctypes.memset(self.adev.paddr2cpu(self.cmd_paddr), 0, 0x1000)
cmd = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr))
cmd.cmd_id = hdr
return cmd

@@ -400,22 +414,22 @@ class AM_PSP(AM_IP):

self._prep_msg1(fw_bytes)
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW)
cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.msg1_pm.mc_addr())
cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes)
cmd.cmd.cmd_load_ip_fw.fw_type = fw_type
return self._ring_submit()

def _tmr_load_cmd(self):
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR)
cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.tmr_pm.mc_addr())
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_pm.paddr)
cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr))
cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr)
cmd.cmd.cmd_setup_tmr.bitfield.virt_phy_addr = 1
cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_pm.size
cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size
return self._ring_submit()

def _load_toc_cmd(self, toc_size):
cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_TOC)
cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.msg1_pm.mc_addr())
cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr))
cmd.cmd.cmd_load_toc.toc_size = toc_size
return self._ring_submit()

@@ -1856,8 +1856,8 @@ class Tensor(SimpleMathTrait):
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)

def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
x = self.cast(dtype) if dtype is not None else self
m = x - x.max(axis=axis, keepdim=True).detach()
m = self - self.max(axis=axis, keepdim=True).detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
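# Behavioral sketch of this change (illustrative restatement, since the +/- markers are stripped above):
#   old: m = self.cast(dtype) - self.cast(dtype).max(axis=axis, keepdim=True).detach()
#   new: m = (self - self.max(axis=axis, keepdim=True).detach()).cast(dtype)
# the row max is now subtracted in the tensor's own dtype, and only the shifted values are cast.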