diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 47cffd0090..be82930b62 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -298,7 +298,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model kernel count and gate usage run: | - PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx + PYTHONPATH="." ALLOWED_KERNEL_COUNT=209 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - if: ${{ matrix.task == 'optimage' }} name: Test openpilot alt model correctness (float32) run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx diff --git a/docs/developer/am.md b/docs/developer/am.md index 9435bc9537..67699422fb 100644 --- a/docs/developer/am.md +++ b/docs/developer/am.md @@ -27,7 +27,7 @@ AM binds compute queues directly to MEC (bypassing MES). Tinygrad uses only one The GPU being passed can be in one of several states: 1. Not initialized -2. Initialized by AMDGPU +2. Initialized by amdgpu 3. Initialized by AM The first and second states require a full GPU setup since their states are unknown. The second state also requires a mode1 reset to reinitialize all components. @@ -36,4 +36,4 @@ The third state can be set up partially to optimize boot time. In this case, onl ### VM Management -Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512TB of virtual addresses. All AM devices are located in one virtual address space. \ No newline at end of file +Each AM device sets up only a single `VMID=0` and one page directory. The page directory used is 3-level and thus supports up to 512GB of virtual addresses. All AM devices are located in one virtual address space. \ No newline at end of file diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py index 092a03cd13..498e626aa6 100644 --- a/examples/benchmark_onnx.py +++ b/examples/benchmark_onnx.py @@ -1,14 +1,12 @@ import sys, onnx, time from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch from tinygrad.tensor import _from_np_dtype -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner def load_onnx_model(fn): onnx_file = fetch(fn) onnx_model = onnx.load(onnx_file) - Tensor.no_grad = True - Tensor.training = False - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) # find preinitted tensors and ignore them initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer} diff --git a/examples/compile_tensorflow.py b/examples/compile_tensorflow.py index e3def3bbdb..7733934880 100644 --- a/examples/compile_tensorflow.py +++ b/examples/compile_tensorflow.py @@ -8,7 +8,7 @@ import numpy as np import subprocess import tensorflow as tf import tf2onnx -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.tensor import Tensor from extra.export_model import export_model_clang, compile_net, jit_model @@ -25,7 +25,7 @@ class TinyOnnx: def __init__(self, keras_model): input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')] onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13) - self.run_onnx = get_run_onnx(onnx_model) + self.run_onnx = OnnxRunner(onnx_model) def forward(self, x): return self.run_onnx({"x": x}, debug=False)['predictions'] diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index d7d9c2feb7..3bf5541b56 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -848,9 +848,9 @@ def train_bert(): model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None) - for _, x in get_state_dict(model).items(): - x.realize().to_(GPUS) parameters = get_parameters(model) + for p in parameters: + p.to_(GPUS) # ** Log run config ** for key, value in config.items(): print(f'HParam: "{key}": {value}') diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index f168a36b45..3a0f4c8628 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -12,17 +12,14 @@ from tinygrad.engine.realize import CompiledRunner import onnx from onnx.helper import tensor_dtype_to_np_dtype -from extra.onnx import get_run_onnx # TODO: port to main tinygrad +from extra.onnx import OnnxRunner # TODO: port to main tinygrad OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx" OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl" def compile(onnx_file): onnx_model = onnx.load(onnx_file) - Tensor.no_grad = True - Tensor.training = False - - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) print("loaded model") input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} diff --git a/examples/yolov8-onnx.py b/examples/yolov8-onnx.py index f75b5cb333..9c440b5afa 100644 --- a/examples/yolov8-onnx.py +++ b/examples/yolov8-onnx.py @@ -3,7 +3,7 @@ import os from ultralytics import YOLO import onnx from pathlib import Path -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.tensor import Tensor os.chdir("/tmp") @@ -14,5 +14,5 @@ onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb")) # TODO: move get example inputs to onnx input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} print(input_shapes) -run_onnx = get_run_onnx(onnx_model) +run_onnx = OnnxRunner(onnx_model) run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True) diff --git a/extra/onnx.py b/extra/onnx.py index 539151fb6e..fbf8e69904 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,53 +1,42 @@ from typing import Callable, Any, Sequence -import importlib, functools -import numpy as np -from tinygrad import Tensor, dtypes +import importlib, functools, dataclasses +from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, DEBUG, all_same -from tinygrad.dtype import DType, ConstType +from tinygrad.dtype import DType, ConstType, dtypes from tinygrad.device import is_dtype_supported -from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto, helper -from google.protobuf.json_format import MessageToDict -cache_misses = 0 -@functools.lru_cache(None) -def _cached_to_python_const(t:Tensor): - if t.dtype is dtypes.uint8: return t.data().tobytes() - if 0 in t.shape: return [] - return t.tolist() +# ***** protobuf parsing ****** +from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper +import numpy as np -# Tensor -> python value cache for parameters -def to_python_const(t) -> list[ConstType]|ConstType|bytes: - if not isinstance(t, Tensor): return t - global cache_misses - ret = _cached_to_python_const(t) - if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3: - print(f"Cache miss for {t}") - cache_misses = info.misses - return ret - -# TODO: use real float16 -# src: onnx/mapping.py -DTYPE_MAP: dict[int, DType] = { - TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, - TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, - TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, - TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, -} def dtype_parse(onnx_dtype: int) -> DType: - if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported") - return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float + supported: dict[int, DType] = { + TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, + TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, + TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, + TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, + } + unsupported = { + TensorProto.UNDEFINED, TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128, TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E4M3FNUZ, + TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4 + } + if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported") + return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float -# src: onnx/onnx_ml_pb2.pyi -ATTRIBUTE_MAP: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = { - AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i), - AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t), - AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints), - AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings) -} def attribute_parse(onnx_attribute: AttributeProto): - if onnx_attribute.type not in ATTRIBUTE_MAP: + supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = { + AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i), + AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t), + AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints), + AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings) + } + unsupported = { + AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS, + AttributeProto.GRAPHS, AttributeProto.SPARSE_TENSORS, AttributeProto.TYPE_PROTOS + } + if onnx_attribute.type in unsupported: raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported") - return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute) + return supported[onnx_attribute.type](onnx_attribute) def buffer_parse(onnx_tensor: TensorProto) -> Tensor: if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.") @@ -62,116 +51,137 @@ def buffer_parse(onnx_tensor: TensorProto) -> Tensor: return Tensor(np_buffer, dtype=dtype) return Tensor(None) -onnx_ops = importlib.import_module('extra.onnx_ops') -ONNXLIMIT = getenv("ONNXLIMIT", -1) -def get_run_onnx(onnx_model: ModelProto): - # model initialization data - model_tensors = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer} - model_expected_inputs = {inp.name:inp for inp in onnx_model.graph.input if inp.name not in model_tensors} - model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)} +def type_parse(onnx_type: TypeProto): + elem_type = onnx_type + if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"): + raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented") + if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type + if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type + if elem_type.HasField("tensor_type"): + shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim) + dtype = dtype_parse(elem_type.tensor_type.elem_type) + return OnnxValue(shape, dtype, is_optional, is_sequence) + raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}") - # model descriptions - # TODO: need a better way of controlling training vs non-training - is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node) - onnx_model_version = onnx_model.opset_import[0].version +# ***** onnx spec ***** +@dataclasses.dataclass(frozen=True) +class OnnxValue: + shape: tuple[str|int] + dtype: DType + is_optional: bool + is_sequence: bool - # used to check validity of user_input according to their dimension variables - variable_dims = {} +@dataclasses.dataclass(frozen=True) +class OnnxNode: + num: int + op: str + inputs: tuple[str] + outputs: tuple[str] + opts: dict[str, Any] - # mapping from onnx ops to tensor.py ops - tensor_methods = { - op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan", - "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", - "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod") - } +# ***** python const ***** +required_input_python_consts: dict[str, tuple[int, ...]] = { + "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,), + "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), + "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4), + **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")}, + **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")} +} - # these values are expected to be python consts - required_input_python_consts: dict[str, tuple[int, ...]] = { - "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,), - "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), - "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4), - **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")}, - **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")} - } +cache_misses = 0 +@functools.lru_cache(None) +def _cached_to_python_const(t:Tensor): + if t.dtype is dtypes.uint8: return t.data().tobytes() + if 0 in t.shape: return [] + return t.tolist() - # src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types - # parses and validates inputs based on their shape and dtype specified by model - def prepare_input(user_input:Any, model_input:ValueInfoProto): - type_proto = model_input.type - if type_proto.HasField("optional_type"): - if user_input is None: return None - type_proto = type_proto.optional_type.elem_type - if type_proto.HasField("sequence_type"): - if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type") - dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type) - sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input] - if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous") - # TODO: need true float16 for dtype checking - # if not all(t.dtype is dtype for t in sequence): - # raise RuntimeError(f"{model_input.name} has dtype mismatch for sequence type. Expected {dtype}, received {tensor.dtype}.") +# Tensor -> python value cache for parameters +def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes: + if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t + global cache_misses + ret = _cached_to_python_const(t) + if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3: + print(f"Cache miss for {t}") + cache_misses = info.misses + return ret + +# ***** runner ****** +debug = int(getenv("DEBUGONNX", "0")) +limit = int(getenv("ONNXLIMIT", "-1")) +class OnnxRunner: + def __init__(self, model: ModelProto): + # parse model protobuf + self.is_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in model.graph.node) + self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad + Tensor.training = True if self.is_training else False + Tensor.no_grad = False if self.is_training else True + self.graph_values = {x.name:buffer_parse(x) for x in model.graph.initializer} + self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values} + self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output} + self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute}) + for num,n in enumerate(model.graph.node)) + self.opset_version = model.opset_import[0].version + self.variable_dims: dict[str, int] = {} + + # TODO: move extra.onnx_ops here so we don't have to deal with annoying circular import + # TODO: clean up opset stuff after moving extra.onnx_ops here + self.onnx_ops_module = importlib.import_module('extra.onnx_ops') + self.onnx_ops = { + **{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", + "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", + "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")}, + } + + def _parse_input(self, name: str, value: Any, spec: OnnxValue): + if spec.is_optional and value is None: return None + # TODO: need true float16 for dtype checking + if spec.is_sequence: + if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type") + sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value] + if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous") return sequence - if type_proto.HasField("tensor_type"): - dtype = dtype_parse(type_proto.tensor_type.elem_type) - tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input - # TODO: need true float16 for dtype checking - # if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} has mismatch for dtype. Expected {dtype}, received {tensor.dtype}.") - for dim, onnx_dim in enumerate(type_proto.tensor_type.shape.dim): - dim_param, dim_value = onnx_dim.dim_param, onnx_dim.dim_value - user_dim_input = tensor.shape[dim] - if dim_param: dim_value = variable_dims[dim_param] if dim_param in variable_dims else variable_dims.setdefault(dim_param, user_dim_input) - if user_dim_input != dim_value: - raise RuntimeError(f"{model_input.name} has mismatch for dim={dim_param or dim}. Expected {dim_value}, received {user_dim_input}.") - return tensor - type_field_names = [field.name for field,_ in type_proto.ListFields()] - raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported") + tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value + for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)): + if isinstance(onnx_dim, str): + onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input)) + if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.") + return tensor - def run_onnx(inputs={}, debug=0): - debug = getenv("DEBUGONNX") or debug - if debug >= 3: print("Model initialization data:\n" + "\n".join(f"\t{i.name} - {model_tensors[i.name]}" for i in onnx_model.graph.initializer)) + def _dispatch_op(self, op, inps, opts): + if op in self.onnx_ops: return self.onnx_ops[op](*inps, **opts) + if hasattr(self.onnx_ops_module, op): + fxn = getattr(self.onnx_ops_module, op) + if isinstance(fxn, dict): + for k in sorted(fxn.keys()): + if k <= self.opset_version: + real_fxn = fxn[k] + else: real_fxn = fxn + return real_fxn(*inps, **opts) + raise NotImplementedError(f"{op=} not supported") - if debug >= 1: print("Model input:") - for name, value_info in model_expected_inputs.items(): + def __call__(self, inputs:dict[str, Any], debug=debug): + for name, input_spec in self.graph_inputs.items(): if name not in inputs: raise RuntimeError(f"Please provide input data for {name}") - model_tensors[name] = prepare_input(inputs[name], value_info) - if debug >= 1: print(f"\t{name} - {model_tensors[name]}") - if debug >= 2: print(f"\t\t{MessageToDict(value_info.type)}") + self.graph_values[name] = self._parse_input(name, inputs[name], input_spec) - for num,n in enumerate(onnx_model.graph.node): - inp_tensors = [model_tensors.get(x) for x in n.input] - required_consts = required_input_python_consts.get(n.op_type, ()) - inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)] - opt = model_attributes[num] + for node in self.graph_nodes: + inps = [to_python_const(self.graph_values.get(name), node.op, i) for i,name in enumerate(node.inputs)] + opts = node.opts - if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}") - if debug >= 3: - print("\tinputs:") - print("\n".join(f"\t\t{x} - {t!r}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp)))) + # provide additional opts + if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs) + if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values - # provide additional arguments - if n.op_type == "Split" and 'num_outputs' not in opt: opt['num_outputs'] = len(n.output) - if n.op_type == "Gradient": opt['intermediate_tensors'] = model_tensors + if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}") + if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps))) + ret = self._dispatch_op(node.op, inps, opts) + ret = ret if isinstance(ret, tuple) else (ret,) + if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret))) - # run op - if n.op_type in tensor_methods: ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt) - elif hasattr(onnx_ops, n.op_type): - fxn = getattr(onnx_ops, n.op_type) - if isinstance(fxn, dict): - for k in sorted(fxn.keys()): - if k <= onnx_model_version: - real_fxn = fxn[k] - else: - real_fxn = fxn - ret = real_fxn(*inp, **opt) - else: - print("UNSUPPORTED", n.op_type, n.input, n.output) - raise NotImplementedError(f"op_type {n.op_type} not supported") + self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True))) - # finalization after running the op - if not isinstance(ret, tuple): ret = (ret, ) - if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}") - for i in range(len(n.output)): model_tensors[n.output[i]] = ret[i] - if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output)))) - - if num == ONNXLIMIT: return {name:model_tensors[name] for name in n.output} - return {x.name:model_tensors[x.name] for x in onnx_model.graph.output} - return run_onnx + if node.num == limit: + Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad + return {name:self.graph_values[name] for name in node.outputs} + Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad + return {name:self.graph_values[name] for name in self.graph_outputs} \ No newline at end of file diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 4f01b3cb03..4f745e680b 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -3,7 +3,7 @@ from typing import cast, Literal from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr from tinygrad.dtype import ImageDType, dtypes from tinygrad.helpers import prod, flatten, make_tuple -from extra.onnx import dtype_parse, to_python_const +from extra.onnx import dtype_parse, _cached_to_python_const import numpy as np # **************** Free Ops **************** @@ -282,7 +282,7 @@ def Gather(x:Tensor, indices:Tensor, axis:int=0): x_sh = list(x.shape) ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] if indices.ndim > 1: indices = indices.flatten() - indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)] # type: ignore + indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot @@ -575,12 +575,9 @@ from tinygrad.nn.optim import SGD def onnx_training(input_group_size): def _decorator(func): def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): - old_training = Tensor.training - Tensor.training = True R = R.detach() groups = len(inputs) // input_group_size ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] - Tensor.training = old_training return tuple(flatten(zip(*ret))) return __wrapper return _decorator diff --git a/test/external/external_benchmark_openpilot.py b/test/external/external_benchmark_openpilot.py index 2811ec891c..780539b1e7 100644 --- a/test/external/external_benchmark_openpilot.py +++ b/test/external/external_benchmark_openpilot.py @@ -2,7 +2,7 @@ import time, sys, hashlib from pathlib import Path import onnx from onnx.helper import tensor_dtype_to_np_dtype -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad import Tensor, dtypes, TinyJit from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange from tinygrad.tensor import _from_np_dtype @@ -11,11 +11,8 @@ import numpy as np OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" if __name__ == "__main__": - Tensor.no_grad = True - Tensor.training = False - onnx_model = onnx.load(onnx_path := fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) Tensor.manual_seed(100) input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} diff --git a/test/external/external_fuzz_am_interrupts.py b/test/external/external_fuzz_am_interrupts.py new file mode 100644 index 0000000000..2ed5724288 --- /dev/null +++ b/test/external/external_fuzz_am_interrupts.py @@ -0,0 +1,39 @@ +import subprocess +import random +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +def run_test(i, full_run=False): + print(f"\rRunning iteration {i}...", end=" ", flush=True) + + p = subprocess.Popen(['python3', 'test/test_tiny.py', 'TestTiny.test_plus'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + if not full_run: + time.sleep(random.uniform(0, 1200) / 1000) + p.kill() + _, stderr = p.communicate() + else: + _, stderr = p.communicate() + + if full_run: + stderr_text = stderr.decode() + print(stderr_text) + assert "Ran 1 test in" in stderr_text and "OK" in stderr_text + +max_workers = 4 +with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for i in range(1000000): + if i % 100 == 0: + for future in as_completed(futures): + try: future.result() + except Exception as e: + print(f"\nError in iteration: {e}") + futures = [] + + run_test(i, True) + else: + future = executor.submit(run_test, i, False) + futures.append(future) + + if len(futures) > max_workers * 2: futures = [f for f in futures if not f.done()] \ No newline at end of file diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py index 3ed7b82746..4c0b720df2 100644 --- a/test/external/external_model_benchmark.py +++ b/test/external/external_model_benchmark.py @@ -6,7 +6,7 @@ import onnx from onnx.helper import tensor_dtype_to_np_dtype import onnxruntime as ort from onnx2torch import convert -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.helpers import OSX, DEBUG, fetch from tinygrad import Tensor, Device from tinygrad.device import CompileError @@ -65,7 +65,7 @@ def benchmark_model(m, devices, validate_outs=False): try: Device.DEFAULT = device inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = get_run_onnx(onnx_model) + tinygrad_model = OnnxRunner(onnx_model) benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()}) from tinygrad.engine.jit import TinyJit @@ -115,7 +115,7 @@ def benchmark_model(m, devices, validate_outs=False): rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models Device.DEFAULT = device inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = get_run_onnx(onnx_model) + tinygrad_model = OnnxRunner(onnx_model) tinygrad_out = tinygrad_model(inputs) ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"]) diff --git a/test/external/external_test_am.py b/test/external/external_test_am.py index ad5159d4ef..985554623e 100644 --- a/test/external/external_test_am.py +++ b/test/external/external_test_am.py @@ -1,5 +1,6 @@ import unittest from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext +from tinygrad.helpers import mv_address class FakeGMC: def __init__(self): self.vm_base = 0x0 @@ -19,6 +20,8 @@ class FakeAM: self.gmc = FakeGMC() self.mm = AMMemoryManager(self, vram_size=4 << 30) self.is_booting = False + def paddr2cpu(self, paddr:int) -> int: return paddr + mv_address(self.vram) + def paddr2mc(self, paddr:int) -> int: return paddr # * PTE format: # * 63:59 reserved diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 7e93a3984e..b9a61b40f1 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported # pip3 install tabulate pytest_plugins = 'onnx.backend.test.report', -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner class TinygradModel(BackendRep): def __init__(self, run_onnx, input_names): @@ -20,7 +20,7 @@ class TinygradModel(BackendRep): def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]: real_inputs = dict(zip(self.input_names, inputs)) - ret = self.fxn(real_inputs, debug=True) + ret = self.fxn(real_inputs, debug=2) return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values()) class TinygradBackend(Backend): @@ -30,7 +30,7 @@ class TinygradBackend(Backend): input_initializer = [x.name for x in model.graph.initializer] net_feed_input = [x for x in input_all if x not in input_initializer] print("prepare", cls, device, net_feed_input) - run_onnx = get_run_onnx(model) + run_onnx = OnnxRunner(model) return TinygradModel(run_onnx, net_feed_input) @classmethod diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index 13252040e0..e3d1868aed 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -7,7 +7,7 @@ try: import onnx except ModuleNotFoundError: raise unittest.SkipTest("onnx not installed, skipping onnx test") -from extra.onnx import get_run_onnx +from extra.onnx import OnnxRunner from tinygrad.tensor import Tensor from tinygrad.helpers import CI, fetch, temp @@ -26,7 +26,7 @@ np.random.seed(1337) class TestOnnxModel(unittest.TestCase): def test_benchmark_openpilot_model(self): onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) def get_inputs(): np_inputs = { "input_imgs": np.random.randn(*(1, 12, 128, 256)), @@ -70,7 +70,7 @@ class TestOnnxModel(unittest.TestCase): def test_openpilot_model(self): onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) print("got run_onnx") inputs = { "input_imgs": np.random.randn(*(1, 12, 128, 256)), @@ -124,7 +124,7 @@ class TestOnnxModel(unittest.TestCase): onnx_model = onnx.load(fn) print("onnx loaded") from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS - run_onnx = get_run_onnx(onnx_model) + run_onnx = OnnxRunner(onnx_model) def run(img): inputs = {input_name: preprocess(img, new=input_new)} diff --git a/test/test_arange.py b/test/test_arange.py index a5c8b535bb..07512ae1b6 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -166,7 +166,7 @@ class TestIndexing(unittest.TestCase): GlobalCounters.reset() z = emb(x).realize() self.assertLessEqual(GlobalCounters.global_ops, op_limit) - self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.kernel_count, 3) if getenv("CHECK", 1): import torch with torch.no_grad(): diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 4ca2359912..dfffca8989 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -220,7 +220,9 @@ class TestMultiConstFolding(unittest.TestCase): t = Tensor.arange(16).float().realize().to(ds) # non const folding case creates one ast on each shard - _check_ast_count(4, t + 1) + # NOTE: there's extra contiguous kernels here since it's realizing both the CONTIGUOUS and its parent COPY + # why does multi call contiguous on a COPY? + _check_ast_count(7, t + 1) _check_ast_count(4, 1 + t) _check_ast_count(4, t * 2) _check_ast_count(4, 2 * t) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 03ce8dac44..62fcb4a443 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -113,6 +113,8 @@ class TestImageDType(unittest.TestCase): assert it.lazydata.base.realized._buf != b1 # issue caused by: don't realize image to image casts. this is part of a larger problem + #@unittest.expectedFailure + # update: passing after tensor_map def test_lil_model(self): with Context(IMAGE=2): x = Tensor.zeros(1, 1) @@ -121,7 +123,10 @@ class TestImageDType(unittest.TestCase): loss = x.image_dot(w1).image_dot(w2).float().max() loss.backward() sched = unwrap(w1.grad).schedule() - self.assertEqual(len(sched), 9) + # NOTE: the w1 grad must realize to a seperate kernel + assert w1.grad.lazydata.is_realized, f"never realized {w1.grad}" + self.assertEqual(w1.grad.lazydata.base.buffer.dtype, dtypes.float32) + self.assertEqual(len(sched), 10) for s,ei in zip(sched, lower_schedule(sched[:])): ei.run() if s.outputs[0].dtype == dtypes.float: diff --git a/test/test_jit.py b/test/test_jit.py index 382c83a52a..7abb13100f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -318,6 +318,7 @@ class TestJit(unittest.TestCase): assert len(res3) == 10, "All values should be different, rand works in jit." assert res3 != res2, "Jit rand is diff with diff seeds" + @unittest.expectedFailure # requires contiguous folding def test_jit_random_after_unrealized_random(self): @TinyJit def f(): return Tensor.rand() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 58bf3d4e13..da2995d37d 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -63,7 +63,11 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): - a, b = Tensor.randn(4), Tensor.randn(4) + # NOTE: this realize exists because Tensor.numpy calls .contiguous() internally + # without contiguous folding, rand.to("CLANG") and rand.contiguous().to("CLANG") are different UOps. + # this test asserts they are the identical Buffer + # having different buffers is fine for correctness, because the outputs match. + a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize() np_a, np_b = a.numpy(), b.numpy() c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))) lowered = list(lower_schedule(c.schedule())) @@ -1690,6 +1694,7 @@ class TestHandCodedOpts(unittest.TestCase): # should upcast the two Tensor.stacks assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 + @unittest.expectedFailure # requires contiguous folding def test_masked_upcast_wino_full(self): with Context(WINO=1): x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 34a3480c0d..b34baced75 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -734,7 +734,7 @@ class TestMultiTensor(unittest.TestCase): zeros = Tensor.zeros(3).realize() b = a.to(devices_2)*zeros.to(devices_2) sched = b.schedule() - self.assertEqual(len(sched), 6) + self.assertEqual(len(sched), 8) # notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2) # all these kernels are just because multi calls contiguous on every single shard diff --git a/test/test_schedule.py b/test/test_schedule.py index f4c9240e04..ccf423d0a5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -16,7 +16,7 @@ from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same from tinygrad.codegen.kernel import verify_ast -from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding +from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis @@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs): np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2) @track_rewrites(named=True) -def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext()) +def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext()) class TestSchedule(unittest.TestCase): def test_basic_binop_fusion(self): @@ -220,7 +220,7 @@ class TestSchedule(unittest.TestCase): GlobalCounters.reset() expr = (a*b)/b expr.realize() - self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now! self.assertEqual(GlobalCounters.global_ops, 0) np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0)) @@ -229,7 +229,7 @@ class TestSchedule(unittest.TestCase): GlobalCounters.reset() expr = a/a expr.realize() - self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.kernel_count, 0) self.assertEqual(GlobalCounters.global_ops, 0) np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0)) @@ -972,6 +972,26 @@ class TestSchedule(unittest.TestCase): expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True) np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4) + @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") + def test_softmax_upcast(self): + # input half, softmax in float + Tensor.manual_seed(0) + x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize() + out = x.softmax(dtype=dtypes.float) + sched = out.schedule() + self.assertEqual(len(sched), 3) + self.assertEqual(len(sched[0].outputs), 1) + self.assertEqual(sched[0].outputs[0].dtype, dtypes.half) + + # input float, softmax in float + Tensor.manual_seed(0) + x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize() + out = x.softmax(dtype=dtypes.float) + sched = out.schedule() + self.assertEqual(len(sched), 3) + self.assertEqual(len(sched[0].outputs), 1) + self.assertEqual(sched[0].outputs[0].dtype, dtypes.float) + def test_softmax_backward(self): Tensor.manual_seed(0) x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize() @@ -1804,7 +1824,7 @@ def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops. # these pattern matchers should move to engine/schedule.py -sym = symbolic_simple+PatternMatcher([ +ops_folding = symbolic_simple+PatternMatcher([ (UPat(Ops.DETACH, name="x"), lambda x:x.src[0]), ]) @@ -1822,8 +1842,8 @@ def run_tensor_ast(r:Tensor): output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype) glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0) sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink() - sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output]) - sink = graph_rewrite(sink, remove_movement_ops+sym+view_right) + sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output]) + sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right) si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ()) run_schedule([si]) return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist() @@ -2184,7 +2204,7 @@ class TestConst(unittest.TestCase): sched = add.schedule() self.assertEqual(len(sched), 0) # b+0 and b share the same underlying device memory - self.assertIs(add.lazydata.realized, b.lazydata.realized) + self.assertIs(add.lazydata.buffer, b.lazydata.buffer) self.assertListEqual(add.tolist(), [2, 2, 2, 2]) def test_src_masked_const_folding(self): @@ -2238,6 +2258,17 @@ class TestCopyFolding(unittest.TestCase): add = schedule_graph_rewrite(add) assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}" + def test_copy_to_same_device(self): + a = Tensor.empty(4).lazydata + b = a.copy_to_device(a.device) + check_schedule(b, 0, filter_sink=False) + b = schedule_graph_rewrite(b) + self.assertIs(b, a) + + def test_clone(self): + a = Tensor.empty(4).lazydata + check_schedule(a.clone(), 1, filter_sink=False) + class TestTensorUOpSpec(unittest.TestCase): def test_const_must_be_unmasked(self): a = Tensor.ones((4, 4)).pad((2, 2)) @@ -2253,6 +2284,12 @@ class TestTensorUOpSpec(unittest.TestCase): t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views) create_schedule_with_vars(t) + def test_symbolic_shape_ok(self): + a = Tensor.ones(4) + vi = UOp.variable("i", 1, 10).bind(4) + t = graph_rewrite(a.reshape(vi).sum().lazydata, remove_movement_ops+merge_views) + create_schedule_with_vars(t) + class TestBufferUOp(unittest.TestCase): # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,) def test_buffer_has_buffer(self): @@ -2316,34 +2353,80 @@ class TestBufferUOp(unittest.TestCase): class TestContiguous(unittest.TestCase): def test_contiguous_buffer(self): - a = Tensor.empty(4).lazydata - b = a.alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - self.assertIs(b, a) + a = Tensor.empty(4) + b = a.contiguous() + check_schedule(b, 0) def test_contiguous_buffer_view(self): - a = Tensor.empty(4).lazydata - b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - self.assertIs(b, a.buf_uop.view(unwrap(b.st))) + a = Tensor.empty(4) + b = a.reshape((2, 2)).contiguous() + check_schedule(b, 0) def test_non_contiguous_buffer_view(self): - a = Tensor.empty(4, 1).lazydata - b = a.expand((4, 4)).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + a = Tensor.empty(4, 1) + b = a.expand((4, 4)).contiguous() + check_schedule(b, 1) def test_size_change_buffer_view(self): - a = Tensor.empty(4).lazydata - b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + a = Tensor.empty(4) + b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous() + check_schedule(b, 1) def test_double_contiguous_realizes_once(self): - a = Tensor.empty(4, 1).lazydata - b = a.expand((4, 4)).alu(Ops.CONTIGUOUS).alu(Ops.CONTIGUOUS) - b = schedule_graph_rewrite(b) - assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + a = Tensor.empty(4, 1) + b = a.expand((4, 4)).contiguous().contiguous() + check_schedule(b, 1) + + +class TestUOpBecome(unittest.TestCase): + # the simplest case, if we create a new BUFFER for this UOp + def test_new_buffer(self): + a = Tensor.empty(4, 4) + b = Tensor.empty(4, 4) + add = a+b + check_schedule(add, 1) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {}) + + def test_new_buffer_view(self): + a = Tensor.empty(4, 4) + b = Tensor.empty(4, 4) + add = (a+b).reshape(8, 2) + check_schedule(add, 1) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {}) + # VIEW is preserverd after the becomes rewrite. + self.assertEqual(add.lazydata.shape, (8, 2)) + assert add.lazydata is not add.lazydata.base + + def test_become_existing_buffer(self): + a = Tensor.empty(4, 4) + b = a*1 + assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + check_schedule(b, 0) + assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER) + self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer) + + def test_become_const_in_base(self): + a = Tensor.empty(4) + b = a*0 + assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul + check_schedule(b, 0) + assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER) + + def test_become_const_in_view(self): + # if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged. + add = Tensor.empty(2, 2)+Tensor.empty(2, 2) + b = add.shrink(((0, 1), (0, 0))) + check_schedule(b, 0) + assert UPat(Ops.CONST, arg=0).match(b.lazydata, {}) + self.assertEqual(b.shape, (1, 0)) + # the base is untouched. + assert UPat(Ops.ADD).match(add.lazydata, {}) + + def test_become_const_from_const(self): + const_add = Tensor(1)+Tensor(2) + assert UPat(Ops.ADD).match(const_add.lazydata, {}) + check_schedule(const_add, 0) + assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {}) if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_setitem.py b/test/test_setitem.py index f1bb595ef2..5c7c14fb57 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -69,7 +69,8 @@ class TestSetitem(unittest.TestCase): t[1] ^= 5 np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]]) - @unittest.expectedFailure + #@unittest.expectedFailure + # update: passing after delete_forced_realize def test_setitem_consecutive_inplace_operator(self): t = Tensor.arange(4).reshape(2, 2).contiguous() t[1] += 2 diff --git a/test/test_subbuffer.py b/test/test_subbuffer.py index 8b6e2043f4..40fb7ad3a3 100644 --- a/test/test_subbuffer.py +++ b/test/test_subbuffer.py @@ -2,6 +2,7 @@ import unittest from tinygrad import Device, dtypes, Tensor from tinygrad.device import Buffer from tinygrad.ops import view_supported_devices +from tinygrad.helpers import Context @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported") class TestSubBuffer(unittest.TestCase): @@ -47,5 +48,22 @@ class TestSubBuffer(unittest.TestCase): out = vt.to(f"{Device.DEFAULT}:1").realize().tolist() assert out == [2, 3, 4] + def test_subbuffer_deallocate(self): + with Context(LRU=0): + vbuf = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated() + self.buf.deallocate() + vbuf.deallocate() + + # Allocate a fake one on the same place + _ = Buffer(Device.DEFAULT, 10, dtypes.uint8).ensure_allocated() + + self.buf.ensure_allocated() + self.buf.copyin(memoryview(bytearray(range(10, 20)))) + + vbuf.ensure_allocated() + + tst = vbuf.as_buffer().tolist() + assert tst == [13, 14] + if __name__ == '__main__': unittest.main() diff --git a/test/test_uops.py b/test/test_uops.py index 99c75e68a3..b7392e46aa 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -14,6 +14,7 @@ from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_ker from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.rewriter import full_graph_rewrite, sym from tinygrad.device import is_dtype_supported +from tinygrad.codegen.kernel import Kernel, Opt, OptOps def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check) @@ -365,6 +366,17 @@ class TestAssembly(unittest.TestCase): self.assertIn(Ops.SHR, ops) self.assertIn(Ops.IDIV, ops) + def test_mulacc_unrolled(self): + # test that acc = acc + a0*b0 + a1*b1 + a2*b2 + a3*b3 + # is not acc = acc + (a0*b0 + a1*b1 + a2*b2 + a3*b3) + a = Tensor.empty(1024) + b = Tensor.empty(1024) + c = (a*b).sum() + k = Kernel(c.schedule()[-1].ast) + k.apply_opt(Opt(OptOps.UNROLL, 0, 4)) + uops = k.linearize().uops + self.assertEqual(len([x.op for x in uops if x.op is Ops.MULACC]), 4) + class TestUOpMethod(unittest.TestCase): @unittest.skip("uops lt no longer ordered") def test_compare_alu_same_src_different_arg(self): diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 7078a994f3..a76c194076 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -164,6 +164,7 @@ class TestSafetensors(unittest.TestCase): def test_save_all_dtypes(self): for dtype in dtypes.fields().values(): if dtype in [dtypes.bfloat16]: continue # not supported in numpy + if dtype in [dtypes.double] and Device.DEFAULT == "METAL": continue # not supported on METAL path = temp(f"ones.{dtype}.safetensors") ones = Tensor(np.random.rand(10,10), dtype=dtype) safe_save(get_state_dict(ones), path) diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index b36b81f243..a9a41eace0 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -104,7 +104,8 @@ class TestRealizeMeansRealize(unittest.TestCase): x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize() self.assertEqual(x.lazydata.op, Ops.VIEW) - @unittest.expectedFailure + #@unittest.expectedFailure + # update: passing after delete_forced_realize def test_uniform_realizes(self): x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize() print(x.lazydata) diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index 669d2e3319..86364e87ec 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -2,7 +2,7 @@ import unittest, math from tinygrad import dtypes from tinygrad.helpers import all_same from tinygrad.ops import GroupOp, UOp, Ops, exec_alu -from tinygrad.codegen.rewriter import full_graph_rewrite +from tinygrad.codegen.rewriter import full_graph_rewrite, mulacc_unrolled # Helper function to apply the graph rewrite def apply_rewrite(expr): @@ -274,5 +274,41 @@ class TestSubstitute(unittest.TestCase): ret = substitute(ret, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}) self.assertIs(ret, a.sqrt().sqrt()) +class TestMulaccUnrolledAcc(unittest.TestCase): + def test_unrolled2(self): + acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)) + acc = UOp(Ops.DEFINE_ACC, dtypes.int, (UOp.const(dtypes.int, 0),) + acc_range, (0,)) + a = UOp.variable('a', 0, 10) + b = UOp.variable('b', 0, 10) + expr = acc.assign(acc + (a*2 + b*3)) + expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled) + self.assertIs(expr_with_mulacc, acc.assign(acc + a*2 + b*3)) + + def test_unrolled4_float(self): + acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3)) + acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,)) + + a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)] + b = [UOp.variable(f'b{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)] + + expr = acc.assign(acc + (a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3])) + expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled) + + # Verify it unrolls into individual multiply-accumulate operations + expected = acc.assign(acc + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]) + self.assertIs(expr_with_mulacc, expected) + + def test_unrolled4_float_const(self): + acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3)) + acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,)) + + a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)] + expr = acc.assign(acc + (a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0)) + expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled) + + # Verify it unrolls into individual multiply-accumulate operations + expected = acc.assign(acc + a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0) + self.assertIs(expr_with_mulacc, expected) + if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_tensor_uop_representation.py b/test/unit/test_tensor_uop_representation.py index dc8d0b64aa..b4d391c2af 100644 --- a/test/unit/test_tensor_uop_representation.py +++ b/test/unit/test_tensor_uop_representation.py @@ -53,7 +53,8 @@ class TestTensorUopRepresentation(unittest.TestCase): b = Tensor([4.,5,6]).realize() c = a+b print(c.lazydata) - is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,))))) + is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern))) + #is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,))))) def test_const_pattern(self): a = Tensor(1) @@ -71,9 +72,9 @@ class TestTensorUopRepresentation(unittest.TestCase): def test_viewed_consts_do_not_realize(self): a = Tensor.ones(10, 10) print(a.lazydata) - pre_realize = a.lazydata a.realize() - assert a.lazydata is pre_realize + is_pattern(a, const_pattern) + self.assertEqual(a.lazydata.shape, (10, 10)) # currently, CONSTs have a "fake" BUFFER. this should be fixed # current: @@ -111,7 +112,8 @@ class TestTensorUopRepresentation(unittest.TestCase): c = a.to("TEST") # NOTE: this isn't checked print(c.lazydata) # TODO: COPY on a Tensor becomes a VIEW(COPY), this should be done in the scheduler not in ops - is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),))) + is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,))) + #is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),))) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 3d45187cc8..b61de63329 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -239,6 +239,9 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld") arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug)) arange_m = ((arange_augrng UOp: # expand sink = graph_rewrite(sink, sym+expander) - # devectorize + load_store_indexing - sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing) + # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse + sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+ + mulacc_unrolled) # final rules for the renderer (without sym) sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher) diff --git a/tinygrad/device.py b/tinygrad/device.py index d4782120c1..2a20992f80 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -4,7 +4,7 @@ from collections import defaultdict from typing import Optional, Any, Iterator, Generator import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE -from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ +from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ cpu_time_execution from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes from tinygrad.renderer import Renderer @@ -129,7 +129,7 @@ class Buffer: if self._base is None and (self.options is None or self.options.external_ptr is None): if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes self.allocator.free(self._buf, self.nbytes, self.options) - del self._buf + del self._buf def __reduce__(self): buf = None if self._base is not None: @@ -202,7 +202,7 @@ class LRUAllocator(Allocator): for opaque in opaques: super().free(opaque, sz, options) opaques.clear() def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None): - if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque) + if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque) else: super().free(opaque, size, options) class _MallocAllocator(LRUAllocator): @@ -310,7 +310,7 @@ if PROFILE: for dev in devs: dev.synchronize() for dev in devs: dev._at_profile_finalize() - with open(temp("profile.pkl"), "wb") as f: pickle.dump(Compiled.profile_events, f) + with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f) from tinygrad.ops import launch_viz - launch_viz("PROFILE", temp("profile.pkl")) + launch_viz("PROFILE", fn) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2375408dd2..71cc12492c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,7 +2,7 @@ import sys, atexit, functools, pickle from collections import defaultdict, deque from dataclasses import dataclass, field from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views -from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify +from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar from tinygrad.dtype import DType, ImageDType, dtypes @@ -31,9 +31,9 @@ tensor_uop_spec = PatternMatcher([ (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True), (UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, arg=ShapeTracker.from_shape(()))), arg=None), lambda: True), - # Tensor const has an unmasked ShapeTracker of stride 0 and a device + # Tensor const has a device and an unmasked ShapeTracker of stride 0 or a ShapeTracker with symbolic shape (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)), - lambda st: len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides) and st.st.views[0].mask is None), + lambda st: st.st.views[0].mask is None and ((len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)) or not all_int(st.shape))), # DETACH and CONTIGUOUS change how we interpret the source UOp # CONTIGUOUS ensures the source UOp realizes @@ -88,15 +88,15 @@ class ScheduleContext: # wrap tensor uops around a VIEW(BUFFER, ) # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it. -def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: +def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r # SINK is passthrough - if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src)) + if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) # skip creating buffers for CONST/BIND/DEVICE/BUFFER if buf.base.is_realized or buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf # VIEW is passthrough if buf is not buf.base: - cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(unwrap(buf.st)) + cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st)) return ret # make things that can't be images not images dtype = buf.dtype @@ -105,11 +105,11 @@ def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: dtype = buf.dtype.base # ASSIGN already has a target buffer, otherwise we create a new one buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) - op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, ctx, cache) for x in buf.src)) + op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) # track the underlying tensor uop for this buffer - ctx.tensor_uops[buf_uop] = [buf] + ctx.tensor_uops[buf_uop] = tensor_map[buf] # (early) bufferize - cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st) + cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret # **** AST graph rewrite @@ -242,7 +242,7 @@ if CAPTURE_PROCESS_REPLAY: def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER def uval(u:UOp) -> UOp: assert is_scheduled(u), f"must be a scheduled op {u}" - return r.src[0] if (r:=u.src[1]).op is Ops.CONTIGUOUS and not (r.src[0].base.op is Ops.VIEW and len(r.src[0].base.src) == 2) else r + return u.src[1] def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp], reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None: @@ -329,7 +329,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: # maybe fuse arange with its children for rbuf in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf} - if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue + if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}} if len(kernel_children) == 0: continue for tr in group: del ctx.realizes[tr] @@ -340,10 +340,6 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: # **** Schedule creation and BFS toposort -class UPatScheduled(UPat): - def __init__(self, *args, **kwargs): - super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs}))) - # ** this is schedule level const folding def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: @@ -358,22 +354,18 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: case _: return None return reduce.const_like(ret) -def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp): - if contig.src[0].op is Ops.VIEW and len(contig.src[0].src): - old_base = contig.src[0].src[0] - if old_base.op is Ops.VIEW and (sti:=unwrap(contig.src[0].st).invert(old_base.shape)) is not None: ctx.contiguous[old_base] = base.view(sti) +def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp): + if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti) def replace_contiguous(ctx:ScheduleContext, alu:UOp): new_src = list(alu.src) for i,s in enumerate(alu.src): if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) -ops_folding = symbolic_simple+PatternMatcher([ - # op with size 0 is zero +sym = symbolic_simple+PatternMatcher([ + # UOp with size 0 is zero (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), - # if the uop folded to a CONST we can delete the BUFFER - (UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)), # DETACH is a NOOP here (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]), # reduce of size 0 is the identity element @@ -386,13 +378,16 @@ ops_folding = symbolic_simple+PatternMatcher([ # no COPY to same device, except clone (arg is True) (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), + # remove cast to image when it's already a contiguous image + (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)), + lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), # remove contiguous if we can just view the buffer (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None), # double contiguous is one contiguous (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.CONTIGUOUS),)), lambda root: root.src[0]), # support for using a contiguous permuted view instead of the parent view if one exists - (UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous), + (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), (UPat(GroupOp.ALU, name="alu"), replace_contiguous), # remove CONST/BIND/BUFFER/VIEW from SINK (UPat(Ops.SINK, name="root"), @@ -400,36 +395,12 @@ ops_folding = symbolic_simple+PatternMatcher([ if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), ]) -# ** buffer merging - -def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp: - assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}" - # if b2 is realized also realize b1 - if b2 in ctx.realizes: - ctx.realizes[b1] = b1 - del ctx.realizes[b2] - # ops referring to b2 now ref to b1 - ctx.tensor_uops[b1] += ctx.tensor_uops[b2] - del ctx.tensor_uops[b2] - # merge - return v1 - -def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp): - # early become - for luop in ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, []): ctx.becomes_map[luop] = b1.view(unwrap(luop.st)) - return v1 - -merge_bufs = PatternMatcher([ - # merge base - (UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))), merge), - (UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))), merge_realized), - # collapse view - (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))), lambda mv:mv), - (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))), lambda mv:mv), -]) - # ** this decides which ops get realized +class UPatScheduled(UPat): + def __init__(self, *args, **kwargs): + super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs}))) + def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None: @@ -448,8 +419,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) return x.view(unwrap(view.st)) def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): - if not root.device.startswith("DISK"): return None - if x.op is not Ops.VIEW: x = x.src[-1] # TODO: remove this once forced_realize is gone + if not b.device.startswith("DISK"): return None buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) @@ -482,7 +452,7 @@ def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop())) def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp): - if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m + if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m if b not in ctx.realizes: return x # collapse BUFFER ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x) return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop())) @@ -523,28 +493,36 @@ remove_movement_ops = PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: - # if using VIZ, do a graph rewrite to vizualize the Tensor graph - if getenv("VIZ"): graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext()) if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) - # to_uop is removing (many) of the movement ops - sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={}) - # const folding and fusion - sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx) - sink = graph_rewrite(sink, merge_bufs, ctx) - # create the scheduler context - graph_rewrite(sink, create_ctx, ctx) + tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext()) + rev_tensor_map: dict[UOp, list[UOp]] = {} + for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) + # add BUFFER uops + sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={}) + # add realizes + sink = graph_rewrite(sink, do_realize+create_ctx, ctx) # group realizes into kernels store_groups = group_realizes(ctx) graph_rewrite(sink, break_sched, ctx) # preschedule realize groups prescheduled: list[ScheduleItem] = [] for store_uops in store_groups: - if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue - prescheduled.append(schedule_uop(UOp.sink(*stores), ctx)) + small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops]) + if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}") + prescheduled.append(schedule_uop(small_sink, ctx)) # can only schedule once for buf_uop in store_uops: for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st)) + # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed + for k,v in tensor_map.items(): + # NOOP + if k.base is v.base: continue + # NOTE: only the base tensors get a BUFFER UOp + if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + # otherwise if it simplified to a CONST the UOp just becomes that CONST + elif v.op is Ops.CONST: ctx.becomes_map[k] = v + # add kernel children schedule_targets = {out:si for si in prescheduled for out in si.outputs} graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index d47f62d7d7..090b9178cc 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -78,7 +78,8 @@ def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)]) @functools.lru_cache(maxsize=None) def getenv(key:str, default=0): return type(default)(os.getenv(key, default)) -def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix() +def temp(x:str, append_user:bool=False) -> str: + return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{os.getenv('USERNAME', os.getlogin())}" if append_user else x)).as_posix() class Context(contextlib.ContextDecorator): def __init__(self, **kwargs): self.kwargs = kwargs @@ -107,7 +108,7 @@ WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1) FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) -PICKLE_BUFFERS, PROFILE = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")) +PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) @dataclass(frozen=True) class Metadata: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 29fe063540..0225603cc4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -233,7 +233,6 @@ class UOpMetaClass(type): # some uops map to other stuff buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary() -forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet() # NOTE: this should be frozen, but frozen is slower @dataclass(eq=False, slots=True) @@ -409,11 +408,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}") return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) - def contiguous(self): - if not unwrap(self.st).contiguous or self.size != self.base.size or self.base.op is Ops.CONST: - return self.alu(Ops.CONTIGUOUS) - forced_realize.add(self.base) - return self + def contiguous(self): return self.alu(Ops.CONTIGUOUS) # *** from LazyBuffer *** @@ -432,19 +427,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # otherwise it's just a VIEW(BUFFER) return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st) def copy_to_device(self, device:str, clone:bool=False) -> UOp: - # no COPY - if self.device == device and not clone: return self # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st) - return UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone).view(unwrap(self.st)) + ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone) + op_arg = [] + mop = self + while mop is not self.base: + op_arg.append((mop.op, mop.arg)) + mop = mop.src[0] + for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg) + return ret def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) @property def lbs(self): return [self] @property def metadata(self): return all_metadata.get(self, None) - @property - def forced_realize(self): return self in forced_realize # *** uop movement ops *** @@ -822,10 +820,10 @@ if TRACK_MATCH_STATS: @atexit.register def print_match_stats(): if TRACK_MATCH_STATS >= 2: - with open(fn:=temp("rewrites.pkl"), "wb") as f: + with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f) - if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl")) + if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True)) if getenv("PRINT_MATCH_STATS", 1): ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]): diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index b87ff02666..37b43d9c7b 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -1,5 +1,5 @@ from __future__ import annotations -import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal +import ctypes, collections, time, dataclasses, pathlib, fcntl, os from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0 from tinygrad.runtime.support.allocator import TLSFAllocator @@ -98,20 +98,14 @@ class AMFirmware: def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size]) -class AMPhysicalMemoryBlock: - def __init__(self, adev:AMDev, paddr:int, size:int): self.adev, self.paddr, self.size = adev, paddr, size - def mc_addr(self): return self.adev.gmc.mc_base + self.paddr - def cpu_addr(self): return mv_address(self.adev.vram) + self.paddr - def cpu_view(self): return to_mv(self.cpu_addr(), self.size) - @dataclasses.dataclass(frozen=True) class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702 class AMPageTableEntry: - def __init__(self, pm, lv): self.pm, self.view, self.lv = pm, pm.cpu_view().cast('Q'), lv + def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv def set_table(self, entry_id, pte:AMPageTableEntry, valid=True): - self.view[entry_id] = (pte.pm.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0) + self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0) def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True): f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \ @@ -133,11 +127,11 @@ class AMPageTableTraverseContext: def level_down(self): pt, pte_idx, _ = self.pt_stack[-1] if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID: - assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.pm.paddr:#x}, {pte_idx=} {entry=:#x}" - child_page_table = AMPageTableEntry(AMPhysicalMemoryBlock(pt.pm.adev, entry & 0x0000FFFFFFFFF000, 0x1000), lv=pt.lv+1) + assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}" + child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1) else: assert self.create_pts, "Not allowed to create new page table" - pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1)) + pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1)) self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table))) return self.pt_stack[-1] @@ -145,7 +139,7 @@ class AMPageTableTraverseContext: def _try_free_pt(self) -> bool: pt, _, _ = self.pt_stack[-1] if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)): - self.adev.mm.pfree(AMPhysicalMemoryBlock(self.adev, pt.pm.paddr, 0x1000)) + self.adev.mm.pfree(pt.paddr) parent_pt, parent_pte_idx, _ = self.pt_stack[-2] parent_pt.set_page(parent_pte_idx, 0x0, valid=False) return True @@ -179,7 +173,7 @@ class AMMemoryManager: self.adev, self.vram_size = adev, vram_size self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device - self.root_page_table = AMPageTableEntry(self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1) + self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1) def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping: assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}" @@ -213,12 +207,12 @@ class AMMemoryManager: # Alloc physical memory and map it to the virtual address va = self.alloc_vaddr(size, align) - if contigous: paddrs = [(self.palloc(size, zero=True).paddr, size)] + if contigous: paddrs = [(self.palloc(size, zero=True), size)] else: paddrs = [] try: ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True) - for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False).paddr, seg_size) for _ in range(seg_cnt)] + for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)] except MemoryError: for paddr, _ in paddrs: self.pa_allocator.free(paddr) raise @@ -230,13 +224,13 @@ class AMMemoryManager: self.va_allocator.free(vm.va_addr) for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr) - def palloc(self, size, align=0x1000, zero=True, boot=False) -> AMPhysicalMemoryBlock: + def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int: assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated" - pm = AMPhysicalMemoryBlock(self.adev, (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align), size) - if zero: ctypes.memset(pm.cpu_addr(), 0, pm.size) - return pm + paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align) + if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size) + return paddr - def pfree(self, pm:AMPhysicalMemoryBlock): self.pa_allocator.free(pm.paddr) + def pfree(self, paddr:int): self.pa_allocator.free(paddr) class AMDev: def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): @@ -285,13 +279,10 @@ class AMDev: self.partial_boot = False if not self.partial_boot: - try: # do not interrupt the boot process - signal.signal(signal.SIGINT, signal.SIG_IGN) - if self.psp.is_sos_alive(): self.smu.mode1_reset() - for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]: - ip.init() - if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized") - finally: signal.signal(signal.SIGINT, signal.default_int_handler) + if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset() + for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]: + ip.init() + if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized") # Booting done self.is_booting = False @@ -309,6 +300,7 @@ class AMDev: for ip in [self.sdma, self.gfx]: ip.fini() def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr + def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg] @@ -337,8 +329,8 @@ class AMDev: self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg) self.reg("regBIF_BX_PF0_RSMU_DATA").write(val) - def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff) -> int: - for _ in range(10000): + def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int: + for _ in range(timeout): if ((rval:=reg.read()) & mask) == value: return rval time.sleep(0.001) raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}') @@ -348,9 +340,8 @@ class AMDev: # The table is located at the end of VRAM - 64KB and is 10KB in size. mmRCC_CONFIG_MEMSIZE = 0xde3 self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20 - self.discovery_pm = AMPhysicalMemoryBlock(self, self.vram_size - (64 << 10), 10 << 10) - bhdr = am.struct_binary_header.from_address(self.discovery_pm.cpu_addr()) + bhdr = am.struct_binary_header.from_address(self.paddr2cpu(self.vram_size - (64 << 10))) ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset) assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X}" diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 3c831511b5..2b08798825 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -1,4 +1,4 @@ -import ctypes, time +import ctypes, time, contextlib from typing import Literal from tinygrad.runtime.autogen.am import am, smu_v13_0_0 from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG @@ -25,8 +25,8 @@ class AM_GMC(AM_IP): self.vm_base = self.adev.mm.va_allocator.base self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1 - self.memscratch_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) - self.dummy_page_pm = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) + self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) + self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) self.hub_initted = {"MM": False, "GC": False} def init(self): self.init_hub("MM") @@ -55,7 +55,7 @@ class AM_GMC(AM_IP): def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid): self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12) self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12) - self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.pm.paddr | 1) + self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1) self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv)) def init_hub(self, ip:Literal["MM", "GC"]): @@ -66,8 +66,8 @@ class AM_GMC(AM_IP): self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18) self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18) - self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_pm.paddr >> 12) - self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_pm.paddr >> 12) + self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12) + self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12) self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1) @@ -106,22 +106,26 @@ class AM_SMU(AM_IP): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck, poll=True) self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck, poll=True) + def is_smu_alive(self): + with contextlib.suppress(RuntimeError): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100) + return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0 + def mode1_reset(self): if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset") self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True) time.sleep(0.5) # 500ms - def _smu_cmn_poll_stat(self): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1) + def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout) def _smu_cmn_send_msg(self, msg, param=0): self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg self.adev.mmMP1_SMN_C2PMSG_82.write(param) self.adev.mmMP1_SMN_C2PMSG_66.write(msg) - def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False): - if poll: self._smu_cmn_poll_stat() + def _smu_cmn_send_smc_msg_with_param(self, msg, param, poll=True, read_back_arg=False, timeout=10000): # 10s + if poll: self._smu_cmn_poll_stat(timeout=timeout) self._smu_cmn_send_msg(msg, param) - self._smu_cmn_poll_stat() + self._smu_cmn_poll_stat(timeout=timeout) return self.adev.rreg(self.adev.mmMP1_SMN_C2PMSG_82) if read_back_arg else None class AM_GFX(AM_IP): @@ -232,27 +236,28 @@ class AM_GFX(AM_IP): class AM_IH(AM_IP): def __init__(self, adev): super().__init__(adev) - self.rings = [(self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0), - (self.adev.mm.palloc(512 << 10, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)] + self.ring_size = 512 << 10 + self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0), + (self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)] def interrupt_handler(self): - ring_vm, rwptr_vm, suf, _ = self.rings[0] - wptr = to_mv(rwptr_vm.cpu_addr(), 8).cast('Q')[0] + _, rwptr_vm, suf, _ = self.rings[0] + wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0] if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1): self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0) self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1) self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0) - self.adev.regIH_RB_RPTR.write(wptr % ring_vm.size) + self.adev.regIH_RB_RPTR.write(wptr % self.ring_size) def init(self): for ring_vm, rwptr_vm, suf, ring_id in self.rings: - self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", ring_vm.mc_addr() >> 8) + self.adev.wreg_pair("regIH_RB_BASE", suf, f"_HI{suf}", self.adev.paddr2mc(ring_vm) >> 8) - self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(ring_vm.size//4).bit_length(), + self.adev.reg(f"regIH_RB_CNTL{suf}").write(mc_space=4, wptr_overflow_clear=1, rb_size=(self.ring_size//4).bit_length(), mc_snoop=1, mc_ro=0, mc_vmid=0, **({'wptr_overflow_enable': 1, 'rptr_rearm': 1} if ring_id == 0 else {'rb_full_drain_enable': 1})) - if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", rwptr_vm.mc_addr()) + if ring_id == 0: self.adev.wreg_pair("regIH_RB_WPTR_ADDR", "_LO", "_HI", self.adev.paddr2mc(rwptr_vm)) self.adev.reg(f"regIH_RB_WPTR{suf}").write(0) self.adev.reg(f"regIH_RB_RPTR{suf}").write(0) @@ -303,10 +308,12 @@ class AM_PSP(AM_IP): def __init__(self, adev): super().__init__(adev) - self.msg1_pm = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True) - self.cmd_pm = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) - self.fence_pm = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) - self.ring_pm = self.adev.mm.palloc(0x10000, zero=not self.adev.partial_boot, boot=True) + self.msg1_paddr = self.adev.mm.palloc(am.PSP_1_MEG, align=am.PSP_1_MEG, zero=not self.adev.partial_boot, boot=True) + self.cmd_paddr = self.adev.mm.palloc(am.PSP_CMD_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) + self.fence_paddr = self.adev.mm.palloc(am.PSP_FENCE_BUFFER_SIZE, zero=not self.adev.partial_boot, boot=True) + + self.ring_size = 0x10000 + self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True) def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0 def init(self): @@ -316,8 +323,9 @@ class AM_PSP(AM_IP): (am.PSP_FW_TYPE_PSP_INTF_DRV, am.PSP_BL__LOAD_INTFDRV), (am.PSP_FW_TYPE_PSP_DBG_DRV, am.PSP_BL__LOAD_DBGDRV), (am.PSP_FW_TYPE_PSP_RAS_DRV, am.PSP_BL__LOAD_RASDRV), (am.PSP_FW_TYPE_PSP_SOS, am.PSP_BL__LOAD_SOSDRV)] - for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid) - while not self.is_sos_alive(): time.sleep(0.01) + if not self.is_sos_alive(): + for fw, compid in sos_components_load_order: self._bootloader_load_component(fw, compid) + while not self.is_sos_alive(): time.sleep(0.01) self._ring_create() self._tmr_init() @@ -332,8 +340,8 @@ class AM_PSP(AM_IP): def _wait_for_bootloader(self): self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_35, mask=0xFFFFFFFF, value=0x80000000) def _prep_msg1(self, data): - ctypes.memset(self.msg1_pm.cpu_addr(), 0, self.msg1_pm.size) - self.msg1_pm.cpu_view()[:len(data)] = data + ctypes.memset(cpu_addr:=self.adev.paddr2cpu(self.msg1_paddr), 0, am.PSP_1_MEG) + to_mv(cpu_addr, len(data))[:] = data self.adev.gmc.flush_hdp() def _bootloader_load_component(self, fw, compid): @@ -342,7 +350,7 @@ class AM_PSP(AM_IP): self._wait_for_bootloader() self._prep_msg1(self.adev.fw.sos_fw[fw]) - self.adev.regMP0_SMN_C2PMSG_36.write(self.msg1_pm.mc_addr() >> 20) + self.adev.regMP0_SMN_C2PMSG_36.write(self.adev.paddr2mc(self.msg1_paddr) >> 20) self.adev.regMP0_SMN_C2PMSG_35.write(compid) return self._wait_for_bootloader() @@ -350,16 +358,22 @@ class AM_PSP(AM_IP): def _tmr_init(self): # Load TOC and calculate TMR size self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC]) - resp = self._load_toc_cmd(len(fwm)) - - self.tmr_pm = self.adev.mm.palloc(resp.resp.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True) + self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size + self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True) def _ring_create(self): + # If the ring is already created, destroy it + if self.adev.regMP0_SMN_C2PMSG_71.read() != 0: + self.adev.regMP0_SMN_C2PMSG_64.write(am.GFX_CTRL_CMD_ID_DESTROY_RINGS) + + # There might be handshake issue with hardware which needs delay + time.sleep(0.02) + # Wait until the sOS is ready self.adev.wait_reg(self.adev.regMP0_SMN_C2PMSG_64, mask=0x80000000, value=0x80000000) - self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.ring_pm.mc_addr()) - self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_pm.size) + self.adev.wreg_pair("regMP0_SMN_C2PMSG", "_69", "_70", self.adev.paddr2mc(self.ring_paddr)) + self.adev.regMP0_SMN_C2PMSG_71.write(self.ring_size) self.adev.regMP0_SMN_C2PMSG_64.write(am.PSP_RING_TYPE__KM << 16) # There might be handshake issue with hardware which needs delay @@ -369,28 +383,28 @@ class AM_PSP(AM_IP): def _ring_submit(self): prev_wptr = self.adev.regMP0_SMN_C2PMSG_67.read() - ring_entry_addr = self.ring_pm.cpu_addr() + prev_wptr * 4 + ring_entry_addr = self.adev.paddr2cpu(self.ring_paddr) + prev_wptr * 4 ctypes.memset(ring_entry_addr, 0, ctypes.sizeof(am.struct_psp_gfx_rb_frame)) write_loc = am.struct_psp_gfx_rb_frame.from_address(ring_entry_addr) - write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.cmd_pm.mc_addr()) - write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.fence_pm.mc_addr()) + write_loc.cmd_buf_addr_hi, write_loc.cmd_buf_addr_lo = data64(self.adev.paddr2mc(self.cmd_paddr)) + write_loc.fence_addr_hi, write_loc.fence_addr_lo = data64(self.adev.paddr2mc(self.fence_paddr)) write_loc.fence_value = prev_wptr # Move the wptr self.adev.regMP0_SMN_C2PMSG_67.write(prev_wptr + ctypes.sizeof(am.struct_psp_gfx_rb_frame) // 4) - while self.fence_pm.cpu_view().cast('I')[0] != prev_wptr: pass + while to_mv(self.adev.paddr2cpu(self.fence_paddr), 4).cast('I')[0] != prev_wptr: pass time.sleep(0.005) - resp = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr()) + resp = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr)) if resp.resp.status != 0: raise RuntimeError(f"PSP command failed {resp.cmd_id} {resp.resp.status}") return resp def _prep_ring_cmd(self, hdr): - ctypes.memset(self.cmd_pm.cpu_addr(), 0, 0x1000) - cmd = am.struct_psp_gfx_cmd_resp.from_address(self.cmd_pm.cpu_addr()) + ctypes.memset(self.adev.paddr2cpu(self.cmd_paddr), 0, 0x1000) + cmd = am.struct_psp_gfx_cmd_resp.from_address(self.adev.paddr2cpu(self.cmd_paddr)) cmd.cmd_id = hdr return cmd @@ -400,22 +414,22 @@ class AM_PSP(AM_IP): self._prep_msg1(fw_bytes) cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_IP_FW) - cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.msg1_pm.mc_addr()) + cmd.cmd.cmd_load_ip_fw.fw_phy_addr_hi, cmd.cmd.cmd_load_ip_fw.fw_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr)) cmd.cmd.cmd_load_ip_fw.fw_size = len(fw_bytes) cmd.cmd.cmd_load_ip_fw.fw_type = fw_type return self._ring_submit() def _tmr_load_cmd(self): cmd = self._prep_ring_cmd(am.GFX_CMD_ID_SETUP_TMR) - cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.tmr_pm.mc_addr()) - cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_pm.paddr) + cmd.cmd.cmd_setup_tmr.buf_phy_addr_hi, cmd.cmd.cmd_setup_tmr.buf_phy_addr_lo = data64(self.adev.paddr2mc(self.tmr_paddr)) + cmd.cmd.cmd_setup_tmr.system_phy_addr_hi, cmd.cmd.cmd_setup_tmr.system_phy_addr_lo = data64(self.tmr_paddr) cmd.cmd.cmd_setup_tmr.bitfield.virt_phy_addr = 1 - cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_pm.size + cmd.cmd.cmd_setup_tmr.buf_size = self.tmr_size return self._ring_submit() def _load_toc_cmd(self, toc_size): cmd = self._prep_ring_cmd(am.GFX_CMD_ID_LOAD_TOC) - cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.msg1_pm.mc_addr()) + cmd.cmd.cmd_load_toc.toc_phy_addr_hi, cmd.cmd.cmd_load_toc.toc_phy_addr_lo = data64(self.adev.paddr2mc(self.msg1_paddr)) cmd.cmd.cmd_load_toc.toc_size = toc_size return self._ring_submit() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5fbdbfd018..45a9b23749 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1856,8 +1856,8 @@ class Tensor(SimpleMathTrait): return self.std(axis, keepdim, correction), self.mean(axis, keepdim) def _softmax(self, axis, dtype:Optional[DTypeLike]=None): - x = self.cast(dtype) if dtype is not None else self - m = x - x.max(axis=axis, keepdim=True).detach() + m = self - self.max(axis=axis, keepdim=True).detach() + if dtype is not None: m = m.cast(dtype) e = m.exp() return m, e, e.sum(axis=axis, keepdim=True)