Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
fix: revive torch backend (#13280)
* fix: revive torch backend
* as_strided view vs copy
* Revert "as_strided view vs copy"
This reverts commit 82a61223f2.
* add extra tests (move inplace, add fusion tests)
* better fusion with inplace_op
* no optimizer hooks (break mnist training fusion)
* split off fusion tests in separate file, assert on resnet fusion
* fix: remove comments
* cleanup, reduce diff
* reduce diff
* better fusion and identity checks
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
.github/workflows/test.yml
@@ -86,65 +86,67 @@ jobs:
|
||||
clang -O2 recognize.c -lm -o recognize
|
||||
cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
|
||||
|
||||
# TODO: fix the torch backend and reenable
|
||||
# torchbackend:
|
||||
# name: Torch Backend Tests
|
||||
# runs-on: ubuntu-latest
|
||||
# timeout-minutes: 15
|
||||
# steps:
|
||||
# - name: Checkout Code
|
||||
# uses: actions/checkout@v4
|
||||
# - name: Setup Environment
|
||||
# uses: ./.github/actions/setup-tinygrad
|
||||
# with:
|
||||
# key: torch-backend-pillow-torchvision-et-pt
|
||||
# deps: testing_minimal
|
||||
# pydeps: "pillow torchvision expecttest"
|
||||
# llvm: 'true'
|
||||
# - name: Install ninja
|
||||
# run: |
|
||||
# sudo apt update || true
|
||||
# sudo apt install -y --no-install-recommends ninja-build
|
||||
# - name: Lint with ruff
|
||||
# run: |
|
||||
# pip3 install --upgrade --force-reinstall ruff==0.11.0
|
||||
# python3 -m ruff check extra/torch_backend/backend.py
|
||||
# - name: Test one op
|
||||
# run: FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add
|
||||
# - name: Test ResNet-18
|
||||
# run: DEBUG=2 python3 extra/torch_backend/example.py
|
||||
# - name: My (custom) tests
|
||||
# run: python3 extra/torch_backend/test.py
|
||||
# - name: Test one op in torch tests
|
||||
# run: DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
|
||||
# - name: Test Ops with TINY_BACKEND
|
||||
# run: CPU=1 CPU_LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20
|
||||
# - name: Test in-place operations on views
|
||||
# run: TORCH_DEBUG=1 python3 extra/torch_backend/test_inplace.py
|
||||
# - name: Test multi-gpu
|
||||
# run: CPU=1 CPU_LLVM=1 GPUS=4 TORCH_DEBUG=1 python3 extra/torch_backend/test_multigpu.py
|
||||
torchbackend:
|
||||
name: Torch Backend Tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: torch-backend-pillow-torchvision-et-pt
|
||||
deps: testing_minimal
|
||||
pydeps: "pillow torchvision expecttest"
|
||||
llvm: 'true'
|
||||
- name: Install ninja
|
||||
run: |
|
||||
sudo apt update || true
|
||||
sudo apt install -y --no-install-recommends ninja-build
|
||||
- name: Lint with ruff
|
||||
run: |
|
||||
pip3 install --upgrade --force-reinstall ruff==0.11.0
|
||||
python3 -m ruff check extra/torch_backend/backend.py
|
||||
- name: Test one op
|
||||
run: FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add
|
||||
- name: Test ResNet-18
|
||||
run: DEBUG=2 python3 extra/torch_backend/example.py
|
||||
- name: My (custom) tests
|
||||
run: python3 extra/torch_backend/test.py
|
||||
- name: Test one op in torch tests
|
||||
run: DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
|
||||
- name: Test Ops with TINY_BACKEND
|
||||
run: CPU=1 CPU_LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20
|
||||
- name: Test in-place operations on views
|
||||
run: TORCH_DEBUG=1 python3 extra/torch_backend/test_inplace.py
|
||||
- name: Test multi-gpu
|
||||
run: CPU=1 CPU_LLVM=1 GPUS=4 TORCH_DEBUG=1 python3 extra/torch_backend/test_multigpu.py
|
||||
- name: Test kernel fusion
|
||||
run: python3 extra/torch_backend/test_kernel_fusion.py
|
||||
|
||||
# torchbackendmore:
|
||||
# name: Torch Backend Tests More
|
||||
# runs-on: ubuntu-latest
|
||||
# timeout-minutes: 15
|
||||
# steps:
|
||||
# - name: Checkout Code
|
||||
# uses: actions/checkout@v4
|
||||
# - name: Setup Environment
|
||||
# uses: ./.github/actions/setup-tinygrad
|
||||
# with:
|
||||
# key: torch-backend-pillow-torchvision-et-pt
|
||||
# deps: testing_minimal
|
||||
# llvm: 'true'
|
||||
# - name: Install ninja
|
||||
# run: |
|
||||
# sudo apt update || true
|
||||
# sudo apt install -y --no-install-recommends ninja-build
|
||||
# - name: Test beautiful_mnist in torch with TINY_BACKEND
|
||||
# run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
|
||||
# - name: Test some torch tests (expect failure)
|
||||
# run: python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
|
||||
|
||||
torchbackendmore:
|
||||
name: Torch Backend Tests More
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: torch-backend-pillow-torchvision-et-pt
|
||||
deps: testing_minimal
|
||||
llvm: 'true'
|
||||
- name: Install ninja
|
||||
run: |
|
||||
sudo apt update || true
|
||||
sudo apt install -y --no-install-recommends ninja-build
|
||||
- name: Test beautiful_mnist in torch with TINY_BACKEND
|
||||
run: STEPS=20 CPU=1 TARGET_EVAL_ACC_PCT=90.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
|
||||
- name: Test some torch tests (expect failure)
|
||||
run: python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
|
||||
|
||||
bepython:
|
||||
name: Python Backend
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
# A006 Lambda argument `input` is shadowing a Python builtin
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.helpers import getenv, prod
|
||||
from tinygrad.helpers import getenv, prod, strides_for_shape, argfix
|
||||
import torch.lib
|
||||
TORCH_DEBUG = getenv("TORCH_DEBUG")
|
||||
import torch, pathlib, math, operator, functools, inspect
|
||||
import torch, pathlib, math, operator, functools, weakref
|
||||
torch.autograd.grad_mode.set_multithreading_enabled(False)
|
||||
from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
|
||||
|
||||
@@ -18,7 +18,17 @@ def _to_torch_device(device: str): return torch.device("tiny", int(device.partit

import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")])
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
def calculate_storage_offset(x: Tensor) -> int:
  offset = 0
  for u in x.uop.toposort():
    if u.op == Ops.SHRINK:
      u_strides = strides_for_shape(u.src[0].shape)
      for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
  return offset
def wrap(x: Tensor) -> torch.Tensor:
  x._strides = strides_for_shape(x.shape) # always recalculate
  if (not hasattr(x, '_storage_offset')) or (not x.uop.is_realized): x._storage_offset = calculate_storage_offset(x)
  return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
def unwrap(x:torch.Tensor) -> Tensor:
  assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
  return mod.unwrap(x)
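Note: the storage offset the new wrap() reports to torch is just the dot product of each SHRINK's start indices with the strides of the pre-shrink shape. A minimal sanity check in plain NumPy (not backend code; strides_for_shape below is a local stand-in for the tinygrad helper of the same name):

# Minimal sketch of what calculate_storage_offset computes: every shrink adds
# sum(start_i * stride_i) over the strides of the shape being shrunk.
import numpy as np

def strides_for_shape(shape):  # row-major element strides
  strides, acc = [], 1
  for s in reversed(shape):
    strides.insert(0, acc)
    acc *= s
  return tuple(strides)

base_shape = (4, 4)
shrink_arg = ((1, 3), (2, 4))                    # keep rows 1:3 and cols 2:4
strides = strides_for_shape(base_shape)          # (4, 1)
offset = sum(start * st for (start, _), st in zip(shrink_arg, strides))
print(offset)                                    # 1*4 + 2*1 = 6

# cross-check against NumPy's own view bookkeeping
a = np.arange(16).reshape(base_shape)
v = a[1:3, 2:4]
assert (v.__array_interface__['data'][0] - a.__array_interface__['data'][0]) // a.itemsize == offset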
@@ -35,17 +45,20 @@ torch.utils.generate_methods_for_privateuse1_backend()
aten = torch.ops.aten

# track view relationships for in place operations
def is_view(tensor: Tensor): return hasattr(tensor, "_view_base")
def canonical_base(view: Tensor): return getattr(view, "_view_base", view)
def derived_views(base: Tensor): return [t for tref in getattr(base, "_views", set()) if (t:=tref()) is not None]
def unwrap_args(args, kwargs):
  return [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args], {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
def wrap_view_op(fn):
  def _wrap(*args,**kwargs):
    args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
    kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
    ret = fn(*args,**kwargs)
    ret._view_base = base = canonical_base(args[0])
    if not hasattr(base, "_views"): base._views = set()
  @functools.wraps(fn)
  def _wrap(*args, **kwargs):
    args, kwargs = unwrap_args(args, kwargs)
    ret = fn(*args, **kwargs)
    base = canonical_base(args[0])
    ret._view_base = base
    base._views = getattr(base, "_views", set())
    base._views.add(weakref.ref(ret))
    ret._view_ops = _get_view_ops(args[0]) + [(fn, args[1:], kwargs)]
    return wrap(ret)
  return _wrap

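The pattern wrap_view_op implements is record and replay: each view remembers the chain of (fn, args, kwargs) that produced it from its base, and the base keeps weakrefs to its live views so they can be refreshed after an in-place write. A toy illustration in plain Python/NumPy (not the backend itself; Tracked, make_view and write_base are names invented for this sketch):

import weakref
import numpy as np

class Tracked:
  def __init__(self, data): self.data, self._views, self._view_ops = data, set(), []

def make_view(base, fn, *args, **kwargs):
  v = Tracked(fn(base.data, *args, **kwargs))
  v._view_ops = base._view_ops + [(fn, args, kwargs)]   # record how this view was built
  base._views.add(weakref.ref(v))                       # base -> view only via weakref
  return v

def write_base(base, new_data):
  base.data = new_data
  for ref in base._views:                               # replay the recorded ops for every live view
    if (v := ref()) is not None:
      d = base.data
      for fn, args, kwargs in v._view_ops: d = fn(d, *args, **kwargs)
      v.data = d

a = Tracked(np.zeros((3, 3)))
b = make_view(a, lambda x, i: x[i:], 1)                 # a "shrink" view of the last two rows
write_base(a, np.ones((3, 3)))
assert b.data.shape == (2, 3) and b.data.sum() == 6     # the view reflects the new base contents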
@@ -60,46 +73,79 @@ view_ops = {
  "aten.unsqueeze": Tensor.unsqueeze,
  "aten.detach": Tensor.detach,
  "aten.select.int": lambda self, dim, idx: self[(slice(None),) * (dim%self.ndim) + (idx,)],
}
  "aten.permute": Tensor.permute,
  "aten.alias": lambda self: self,
}

for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))

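The aten.select.int lambda builds an indexing key that keeps every leading dimension whole and picks a single index along dim, dropping that dimension. In NumPy terms (illustration only):

import numpy as np
x = np.arange(24).reshape(2, 3, 4)
dim, idx = 1, 2
key = (slice(None),) * (dim % x.ndim) + (idx,)
assert np.array_equal(x[key], x[:, 2])     # select index 2 along dim 1, dropping that dimension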
# in place operations with views
|
||||
def realize_with_views(self: Tensor, views: Tensor):
|
||||
if not self.uop.st.contiguous: self.replace(self.contiguous())
|
||||
self.replace(self.clone().realize())
|
||||
for v in views:
|
||||
if v.uop.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view
|
||||
ret = self
|
||||
st = ShapeTracker(self.uop.st.views + v.uop.st.views) # TODO: is this right?
|
||||
for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
|
||||
v.replace(ret)
|
||||
def maybe_realize_storage(self: Tensor) -> bool:
|
||||
if realize:=is_view(self): realize_with_views((base:=canonical_base(self)), derived_views(base))
|
||||
return realize
|
||||
def inplace_fn(outvars: str|list[str]):
|
||||
if type(outvars) is str: outvars = [outvars]
|
||||
def decorator(fn):
|
||||
sig = inspect.signature(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
bound = sig.bind(*args, **kwargs)
|
||||
outs = [kwargs.get(v, bound.arguments.get(v)) for v in outvars]
|
||||
outs = [unwrap(o) if isinstance(o, torch.Tensor) else o for o in outs]
|
||||
realize = any(maybe_realize_storage(o) for o in outs)
|
||||
ret = fn(*args, **kwargs)
|
||||
if realize: Tensor.realize(*(o for o in outs))
|
||||
return ret
|
||||
return wrapper
|
||||
return decorator
|
||||
def _get_view_ops(view): return getattr(view, "_view_ops", [])

def _apply_view_ops(target, ops):
  for fn, args, kwargs in ops: target = fn(target, *args, **kwargs)
  return target

# similar to https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/InferSize.h
def _reshape_target_shape(shape:tuple[int, ...], args) -> tuple[int, ...]|None:
  if not (req := argfix(*args)): return None
  new_shape, infer_idx = [], -1
  for i, s in enumerate(req):
    if s is None: s = shape[i] if i < len(shape) else None
    if not isinstance(s, int): return None
    if s == -1:
      if infer_idx != -1: return None
      infer_idx = len(new_shape)
    new_shape.append(s)
  total = prod(shape)
  if infer_idx != -1:
    known = prod(x for x in new_shape if x != -1)
    if known == 0:
      if total != 0: return None
      new_shape[infer_idx] = 0
    else: new_shape[infer_idx] = total // known
  return tuple(new_shape) if prod(new_shape) == total else None

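For reference, the -1 inference above follows the usual reshape rule: the unknown dimension is the total element count divided by the product of the known dimensions, and the request is rejected when the products cannot match. A standalone sketch (infer_shape is a hypothetical helper written for this example and is simplified relative to _reshape_target_shape, e.g. it skips the None and zero-size handling):

from math import prod

def infer_shape(shape, req):
  if req.count(-1) > 1: return None                    # at most one dimension may be inferred
  new_shape, total = list(req), prod(shape)
  if -1 in req:
    known = prod(x for x in req if x != -1)
    if known == 0: return None                         # simplified: the real helper also handles total == 0
    new_shape[req.index(-1)] = total // known
  return tuple(new_shape) if prod(new_shape) == total else None

assert infer_shape((2, 3, 4), (-1, 4)) == (6, 4)       # 24 elements / known 4 -> infer 6
assert infer_shape((2, 3, 4), (2, -1)) == (2, 12)
assert infer_shape((2, 3, 4), (5, -1)) is None         # 24 // 5 = 4 gives (5, 4) = 20 elements, rejected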
# TODO: can we get rid of this? only for test_flatten_reshape_add
def _try_simple_reshape_view_write(base: Tensor, view: Tensor, val: Tensor) -> bool:
  if not (ops := _get_view_ops(view)): return False
  shapes = [base.shape]
  for fn, args, _ in ops:
    if fn is Tensor.reshape:
      if not (next_shape := _reshape_target_shape(shapes[-1], args)): return False
      shapes.append(next_shape)
  if shapes[-1] != view.shape: return False
  for s in reversed(shapes[:-1]): val = val.reshape(s)
  base.assign(val)
  return True

def _view_write(base: Tensor, view: Tensor, value: Tensor) -> None:
  val = value if value.dtype == base.dtype else value.cast(base.dtype)
  if view.shape == base.shape: return base.assign(val)
  if _try_simple_reshape_view_write(base, view, val): return
  idx_base = Tensor.arange(base.numel(), device=base.device, dtype=dtypes.int32).reshape(base.shape)
  idx_view = _apply_view_ops(idx_base, _get_view_ops(view)).reshape(-1)
  flat_base = base.reshape(base.numel()).contiguous()
  flat_base[idx_view] = val.reshape(-1)
  base.assign(flat_base.reshape(base.shape))

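_view_write's fallback is an index-gather trick: number every element of the base, push those numbers through the same view ops, and the result tells you exactly which flat positions of the base the view covers. The same idea in plain NumPy (illustration only):

import numpy as np

base = np.zeros((3, 3), dtype=np.float32)
view_ops = [lambda t: t[1:, 1:], lambda t: t.T]       # view = base[1:, 1:].T
val = np.array([[1., 2.], [3., 4.]], dtype=np.float32)

idx = np.arange(base.size).reshape(base.shape)        # element index of every position in base
for op in view_ops: idx = op(idx)                     # apply the same view ops to the indices
flat = base.reshape(-1).copy()
flat[idx.reshape(-1)] = val.reshape(-1)               # scatter the new values into the right slots
base = flat.reshape(base.shape)
assert np.array_equal(base[1:, 1:].T, val)            # writing "through the view" worked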
def _apply_inplace(target: Tensor, value: Tensor) -> None:
|
||||
val = value if value.dtype == target.dtype else value.cast(target.dtype)
|
||||
base = canonical_base(target)
|
||||
views = derived_views(base)
|
||||
if not views: return target.assign(val)
|
||||
view_ops_map = {v: _get_view_ops(v) for v in views}
|
||||
if target is base or target.uop is base.uop: base.assign(val)
|
||||
else: _view_write(base, target, val)
|
||||
for v in views: v.replace(_apply_view_ops(base, view_ops_map[v]))
|
||||
|
||||
# *** bad functions on CPU ***
|
||||
|
||||
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
|
||||
@inplace_fn("self")
|
||||
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
|
||||
# TODO: move to tinygrad
|
||||
ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device)
|
||||
return wrap(unwrap(self).assign(unwrap(ret)))
|
||||
unwrap(self).assign(unwrap(ret))
|
||||
return self
|
||||
|
||||
@torch.library.impl("aten::index_put", "privateuseone")
|
||||
def index_put(self, indices, values, accumulate=False):
|
||||
@@ -150,43 +196,23 @@ for i in [
|
||||
def index_tensor(x, y):
|
||||
return wrap(unwrap(x)[[unwrap(_y.to(x.device)) if _y is not None else slice(None) for _y in y]])
|
||||
|
||||
@torch.library.impl("aten::zero_", "privateuseone")
|
||||
@inplace_fn("x")
|
||||
def zero_(x):
|
||||
if TORCH_DEBUG: print(f"zero_ {x.shape}")
|
||||
tt = unwrap(x)
|
||||
tt.assign(tt.zeros_like())
|
||||
|
||||
@torch.library.impl("aten::fill_.Scalar", "privateuseone")
|
||||
@inplace_fn("x")
|
||||
def fill_scalar(x, y):
|
||||
if TORCH_DEBUG: print(f"fill_.Scalar {x.shape} {y}")
|
||||
tt = unwrap(x)
|
||||
tt.assign(tt.full_like(y))
|
||||
|
||||
@torch.library.impl("aten::_local_scalar_dense", "privateuseone")
|
||||
def _local_scalar_dense(tensor): return unwrap(tensor).item()
|
||||
|
||||
@functools.cache
|
||||
def cached_to_movement_ops(shape, st) -> list:
|
||||
mops = to_movement_ops(st)
|
||||
if mops[0] == (MovementOps.RESHAPE, shape): mops = mops[1:]
|
||||
return mops
|
||||
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
|
||||
|
||||
@wrap_view_op
|
||||
def _as_strided(tensor:Tensor, size, stride, storage_offset=None):
|
||||
# multiple as_strided do not compound
|
||||
base = canonical_base(tensor)
|
||||
# TODO: this is heavyweight
|
||||
st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
|
||||
ret = base
|
||||
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
|
||||
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
|
||||
for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo)
|
||||
return ret
|
||||
def _as_strided(tensor:Tensor, size, stride, storage_offset=0):
  base = getattr(tensor, "_as_strided_base", canonical_base(tensor)).flatten()
  if prod(size) == 1: return base[storage_offset].reshape(size)
  indices = Tensor.zeros(size, dtype=dtypes.int32, device=base.device) + storage_offset
  for dim, (sz, st) in enumerate(zip(size, stride)):
    if st != 0:
      dim_range = Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st
      shape_for_broadcast = [1] * dim + [sz] + [1] * (len(size) - dim - 1)
      indices = indices + dim_range.reshape(shape_for_broadcast)
  result = base[indices.flatten()].reshape(size)
  result._as_strided_base = base
  return result

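The new _as_strided builds one gather index per output element: storage_offset plus the dot product of each coordinate with its stride. A quick cross-check of that arithmetic against torch's own CPU as_strided (illustration only; strides are in elements, matching aten semantics):

import torch

base = torch.arange(12, dtype=torch.float32)          # flat storage of 12 elements
size, stride, storage_offset = (2, 3), (4, 1), 2      # rows 4 apart, columns adjacent, start at element 2

idx = torch.full(size, storage_offset, dtype=torch.int64)
for dim, (sz, st) in enumerate(zip(size, stride)):
  shape = [1] * dim + [sz] + [1] * (len(size) - dim - 1)
  idx = idx + (torch.arange(sz, dtype=torch.int64) * st).reshape(shape)

gathered = base[idx.reshape(-1)].reshape(size)
reference = torch.as_strided(base, size, stride, storage_offset)
assert torch.equal(gathered, reference)
print(gathered)   # tensor([[2., 3., 4.], [6., 7., 8.]])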
@torch.library.impl("aten::as_strided", "privateuseone")
|
||||
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
|
||||
@@ -245,15 +271,14 @@ def convolution_overrideable(input, weight, bias, stride, padding, dilation, tra
|
||||
if TORCH_DEBUG >= 1:
|
||||
print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
|
||||
input, weight, bias = unwrap(input), unwrap(weight), unwrap(bias) if bias is not None else None
|
||||
# TODO: fix test_biased_conv2d fails without realize()
|
||||
if not transposed: return wrap(input.conv2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).realize())
|
||||
return wrap(input.conv_transpose2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding).realize())
|
||||
if not transposed: return wrap(input.conv2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding))
|
||||
return wrap(input.conv_transpose2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding))
|
||||
|
||||
@torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
|
||||
def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
|
||||
if TORCH_DEBUG >= 1:
|
||||
print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
|
||||
grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0], device=_from_torch_device(weight.device))
|
||||
grad_out, input, weight, bias = unwrap(grad_out).detach(), unwrap(input).detach(), unwrap(weight).detach(), Tensor.zeros(weight.shape[0], device=_from_torch_device(weight.device))
|
||||
if not transposed: out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
|
||||
else:
|
||||
bias = Tensor.zeros(weight.shape[1] * groups)
|
||||
@@ -315,55 +340,57 @@ for i,pre in enumerate(["", "bi", "tri"]):
|
||||
torch.library.impl(f"aten::_upsample_nearest_exact{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest-exact"))
|
||||
|
||||
@torch.library.impl("aten::scatter_add.out", "privateuseone")
|
||||
@inplace_fn("out")
|
||||
def scatter_add(self, dim, index, src, out):
|
||||
self, index, src, out = unwrap(self), unwrap(index), unwrap(src), unwrap(out)
|
||||
if self.shape == (): return wrap(out.assign(src))
|
||||
return wrap(out.assign(Tensor.scatter_reduce(self, dim, index, src, reduce='sum')))
|
||||
self, index, src, out_unwrapped = unwrap(self), unwrap(index), unwrap(src), unwrap(out)
|
||||
if self.shape == (): _apply_inplace(out_unwrapped, src)
|
||||
else: _apply_inplace(out_unwrapped, Tensor.scatter_reduce(self, dim, index, src, reduce='sum'))
|
||||
return out
|
||||
|
||||
@torch.library.impl("aten::_copy_from", "privateuseone")
|
||||
def _copy_from(src: torch.Tensor, dest, non_blocking=False):
|
||||
realize = dest.is_tiny and maybe_realize_storage(unwrap(dest))
|
||||
cast_dtype = _from_torch_dtype(dest.dtype)
|
||||
def _copy_between_devices(src, dest, cast_dtype, to_device, non_blocking=False):
|
||||
if src.is_tiny and dest.is_tiny:
|
||||
to_device = _from_torch_device(dest.device)
|
||||
src,dest = unwrap(src),unwrap(dest)
|
||||
# TODO we need to properly match dest shape and strides, not blindly assign
|
||||
if dest.uop.st.contiguous or dest.uop.is_realized: src = src.contiguous() # this only solves some cases
|
||||
dest.assign(src.cast(cast_dtype).to(to_device))
|
||||
if realize: Tensor.realize(dest)
|
||||
src_t, dest_t = unwrap(src), unwrap(dest)
|
||||
if dest_t.uop.is_contiguous() or dest_t.uop.is_realized: src_t = src_t.contiguous()
|
||||
_apply_inplace(dest_t, src_t.cast(cast_dtype).to(to_device))
|
||||
elif src.is_tiny and dest.is_cpu:
|
||||
# TODO: is there a better way?
|
||||
dest.resize_(src.numel()).resize_(src.shape)
|
||||
dest.copy_(torch.from_numpy(unwrap(src).cast(cast_dtype).numpy()))
|
||||
elif src.is_cpu and dest.is_tiny:
|
||||
to_device = _from_torch_device(dest.device)
|
||||
# TODO we need to properly match dest shape and strides, not blindly assign
|
||||
unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype).to(to_device))
|
||||
if realize: Tensor.realize(unwrap(dest))
|
||||
else:
|
||||
raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
|
||||
|
||||
@torch.library.impl("aten::_copy_from", "privateuseone")
|
||||
def _copy_from(src: torch.Tensor, dest, non_blocking=False):
|
||||
cast_dtype = _from_torch_dtype(dest.dtype)
|
||||
to_device = _from_torch_device(dest.device)
|
||||
_copy_between_devices(src, dest, cast_dtype, to_device, non_blocking)
|
||||
return dest
|
||||
|
||||
@torch.library.impl("aten::copy_", "privateuseone")
|
||||
def copy_(self, src, non_blocking=False):
|
||||
cast_dtype = _from_torch_dtype(self.dtype)
|
||||
to_device = _from_torch_device(self.device)
|
||||
_copy_between_devices(src, self, cast_dtype, to_device, non_blocking)
|
||||
return self
|
||||
|
||||
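Both aten::_copy_from and aten::copy_ now share _copy_between_devices, which picks a path based on where src and dest live: tiny-to-tiny assigns lazily, tiny-to-cpu goes through numpy, cpu-to-tiny wraps the numpy buffer in a tinygrad Tensor. From user code this is roughly just .to() and .cpu() (an illustrative sketch; assumes the backend module is importable as extra.torch_backend.backend and that the "tiny" device is registered, as in the tests):

import torch
import extra.torch_backend.backend  # noqa: F401  # registers the "tiny" privateuseone device

a = torch.arange(6, dtype=torch.float32)     # cpu tensor
b = a.to("tiny")                             # cpu -> tiny path: Tensor(src.numpy()).cast(...).to(device)
c = (b * 2).cpu()                            # tiny -> cpu path: dest filled from unwrap(src).numpy()
assert torch.equal(c, a * 2)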
@torch.library.impl("aten::cat.out", "privateuseone")
|
||||
@inplace_fn("out")
|
||||
def cat_out(tensors, dim=0, out=None):
|
||||
unwrap(out).assign(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim))
|
||||
_apply_inplace(unwrap(out), Tensor.cat(*[unwrap(x) for x in tensors], dim=dim))
|
||||
return out
|
||||
|
||||
@torch.library.impl("aten::topk.values", "privateuseone")
|
||||
@inplace_fn(["values", "indices"])
|
||||
def topk_values(input, k, dim=None, largest=True, sorted=True, values=None, indices=None):
|
||||
out_values, out_indices = unwrap(input).topk(k, dim if dim is not None else -1, largest, sorted)
|
||||
unwrap(values).assign(out_values)
|
||||
unwrap(indices).assign(out_indices.cast(dtypes.int64))
|
||||
return wrap(out_values), wrap(out_indices)
|
||||
_apply_inplace(unwrap(values), out_values)
|
||||
_apply_inplace(unwrap(indices), out_indices.cast(dtypes.int64))
|
||||
return values, indices
|
||||
|
||||
@torch.library.impl("aten::sort.values_stable", "privateuseone")
|
||||
@inplace_fn(["values", "indices"])
|
||||
def sort_values(input, dim=-1, descending=False, stable=True, values=None, indices=None):
|
||||
out_values, out_indices = unwrap(input).sort(dim, descending)
|
||||
unwrap(values).assign(out_values)
|
||||
unwrap(indices).assign(out_indices.cast(dtypes.int64))
|
||||
return wrap(out_values), wrap(out_indices)
|
||||
_apply_inplace(unwrap(values), out_values)
|
||||
_apply_inplace(unwrap(indices), out_indices.cast(dtypes.int64))
|
||||
return values, indices
|
||||
|
||||
@torch.library.impl("aten::_linalg_svd", "privateuseone")
|
||||
def _linalg_svd(self, full_matrices=False):
|
||||
@@ -373,7 +400,6 @@ def _linalg_svd(self, full_matrices=False):
|
||||
# register some decompositions
|
||||
from torch._decomp import get_decompositions
|
||||
decomps = [
|
||||
aten.native_batch_norm, aten.native_batch_norm_backward,
|
||||
aten.native_layer_norm_backward,
|
||||
aten.linalg_cross,
|
||||
aten.addmm,
|
||||
@@ -510,7 +536,6 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
||||
|
||||
# we add the "out" here
|
||||
def wrap_out(f):
|
||||
@inplace_fn("out")
|
||||
def _wrap_out(*args, **kwargs):
|
||||
out = kwargs.pop('out')
|
||||
assigned = f(*args, **kwargs)
|
||||
@@ -518,22 +543,33 @@ def wrap_out(f):
|
||||
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
|
||||
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
|
||||
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
|
||||
if out.uop.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
|
||||
return out.assign(assigned)
|
||||
return _wrap_out
|
||||
|
||||
def _inplace_op(t, new_value):
  if not hasattr(t, "_view_base") and not getattr(canonical_base(t), "_views", set()): t.replace(new_value)
  else: _apply_inplace(t, new_value)
  return t

tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.remainder.Scalar_Tensor": lambda x,y: x%y,
|
||||
"aten.floor_divide": lambda x,y: x//y,
|
||||
"aten.floor_divide_.Tensor": inplace_fn("x")(lambda x,y: x.assign(x//y)),
|
||||
"aten.floor_divide_.Tensor": lambda x,y: x//y,
|
||||
# TODO: use tinygrad methods, but they require x to be unsigned
|
||||
"aten.__lshift__.Scalar": lambda x,y: x*(2**y),
|
||||
"aten.__ilshift__.Scalar": inplace_fn("x")(lambda x,y: x.assign(x*(2**y))),
|
||||
"aten.__ilshift__.Scalar": lambda x,y: x*(2**y),
|
||||
"aten.__rshift__.Scalar": lambda x,y: x//(2**y),
|
||||
"aten.__irshift__.Scalar": inplace_fn("x")(lambda x,y: x.assign(x//(2**y))),
|
||||
"aten.__irshift__.Scalar": lambda x,y: x//(2**y),
|
||||
# inplace ops using replace for fusion
|
||||
"aten.zero_": lambda x: x.zeros_like(),
|
||||
"aten.fill_.Scalar": lambda x, y: x.full_like(y),
|
||||
"aten.add_.Tensor": lambda self, other, alpha=1.0: self + other * alpha,
|
||||
"aten.add_.Scalar": lambda self, other, alpha=1.0: self + other * alpha,
|
||||
"aten.mul_.Tensor": lambda self, other: self * other,
|
||||
"aten.mul_.Scalar": lambda self, other: self * other,
|
||||
# relu doesn't have an out form?
|
||||
"aten.relu": Tensor.relu,
|
||||
"aten.relu_": inplace_fn("x")(lambda x: x.assign(x.relu())),
|
||||
"aten.relu_": lambda x: x.relu(),
|
||||
"aten.mean": Tensor.mean,
|
||||
"aten.mean.dim": Tensor.mean,
|
||||
"aten.min": Tensor.min,
|
||||
@@ -554,19 +590,17 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.repeat": lambda x,*repeats: Tensor.repeat(x,*repeats).contiguous(), # not a view
|
||||
"aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
|
||||
"aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
|
||||
"aten.random_": inplace_fn("self")(lambda self:
|
||||
self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype))),
|
||||
"aten.random_.from": inplace_fn("self")(lambda self, from_, to:
|
||||
self.assign(Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype))),
|
||||
"aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype))),
|
||||
"aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype))),
|
||||
"aten.random_": lambda self: Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype),
|
||||
"aten.random_.from": lambda self, from_, to: Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype),
|
||||
"aten.uniform_": lambda self, low=0, high=1: Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype),
|
||||
"aten.normal_": lambda self, mean=0, std=1: Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype),
|
||||
# these don't work in out form, they have size 0
|
||||
"aten.abs": Tensor.abs,
|
||||
"aten.logical_not": Tensor.logical_not,
|
||||
"aten.logical_or_": inplace_fn("x")(lambda x, y: x.assign(x | y)),
|
||||
"aten.logical_or_": lambda x, y: x | y,
|
||||
"aten.multinomial": Tensor.multinomial,
|
||||
"aten.masked_fill_.Scalar": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
|
||||
"aten.masked_fill_.Tensor": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
|
||||
"aten.masked_fill_.Scalar": lambda self, mask, value: self.masked_fill(mask, value),
|
||||
"aten.masked_fill_.Tensor": lambda self, mask, value: self.masked_fill(mask, value),
|
||||
"aten.masked_fill.Scalar": Tensor.masked_fill,
|
||||
"aten.masked_fill.Tensor": Tensor.masked_fill,
|
||||
"aten.masked_select": Tensor.masked_select,
|
||||
@@ -580,7 +614,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.asinh": Tensor.asinh,
|
||||
"aten.mul": Tensor.mul,
|
||||
"aten.atanh": Tensor.atanh,
|
||||
"aten.fill_.Tensor": Tensor.full, # TODO: looks wrong
|
||||
"aten.fill_.Tensor": lambda self, value: Tensor.full(self.shape, value.reshape(()).item(), device=self.device, dtype=self.dtype),
|
||||
"aten.flip": Tensor.flip,
|
||||
"aten.scatter_reduce.two": Tensor.scatter_reduce,
|
||||
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True), # TODO: inplace view op, here?
|
||||
@@ -601,20 +635,51 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.unfold": Tensor.unfold,
|
||||
}}
|
||||
|
||||
# operations that need inplace treatment (use _inplace_op instead of wrap_fxn) AKA return original tensor
|
||||
inplace_ops = {
|
||||
"aten.zero_",
|
||||
"aten.fill_.Scalar",
|
||||
"aten.fill_.Tensor",
|
||||
"aten.add_.Tensor",
|
||||
"aten.add_.Scalar",
|
||||
"aten.mul_.Tensor",
|
||||
"aten.mul_.Scalar",
|
||||
"aten.floor_divide_.Tensor",
|
||||
"aten.__ilshift__.Scalar",
|
||||
"aten.__irshift__.Scalar",
|
||||
"aten.relu_",
|
||||
"aten.random_",
|
||||
"aten.random_.from",
|
||||
"aten.uniform_",
|
||||
"aten.normal_",
|
||||
"aten.logical_or_",
|
||||
"aten.masked_fill_.Scalar",
|
||||
"aten.masked_fill_.Tensor",
|
||||
}
|
||||
|
||||
def wrap_fxn(k,f):
|
||||
def nf(*args, **kwargs):
|
||||
if TORCH_DEBUG:
|
||||
print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
|
||||
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
|
||||
args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
|
||||
kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
|
||||
args, kwargs = unwrap_args(args, kwargs)
|
||||
out = f(*args, **kwargs)
|
||||
if isinstance(out, Tensor): return wrap(out)
|
||||
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
|
||||
else: raise RuntimeError(f"unknown output type {type(out)}")
|
||||
return nf
|
||||
|
||||
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
|
||||
def wrap_inplace(k,f):
  def nf(*args, **kwargs):
    orig = args[0]
    args, kwargs = unwrap_args(args, kwargs)
    _inplace_op(args[0], f(*args, **kwargs))
    return orig
  return nf

for k,v in tiny_backend.items():
  wrapper = wrap_inplace if k in inplace_ops else wrap_fxn
  torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrapper(k,v))

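So an op listed in inplace_ops is implemented functionally in tiny_backend, and wrap_inplace routes the result through _inplace_op, which either fuses via Tensor.replace (when no views are involved) or falls back to _apply_inplace. Roughly what that looks like from the torch side (illustrative only; assumes the backend is importable as extra.torch_backend.backend and the "tiny" device is registered, as in the tests):

import torch
import extra.torch_backend.backend  # noqa: F401  # registers the privateuseone "tiny" backend

x = torch.ones(4, device="tiny")
x += 2          # aten.add_ is in inplace_ops: computed functionally as self + other*alpha, then _inplace_op
x.relu_()       # aten.relu_ likewise stays a functional x.relu() under the hood
print(x.cpu())  # tensor([3., 3., 3., 3.]); with no views involved both updates go through replace(),
                # so they stay lazy and can fuse into one kernel when the tensor is finally materialized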
@torch.library.impl("aten::equal", "privateuseone")
|
||||
def equal(x: torch.Tensor, y: torch.Tensor): return (x==y).all().item()
|
||||
@@ -628,42 +693,72 @@ if TORCH_DEBUG:
|
||||
return func(*args, **(kwargs or {}))
|
||||
(_dispatch_log:=DispatchLog()).__enter__() # NOTE: must be kept alive
|
||||
|
||||
# NOTE: patch torch optimizer step to avoid continously growing the computation graph
|
||||
import weakref
|
||||
_torch_modules_with_buffers: weakref.WeakSet[torch.nn.Module] = weakref.WeakSet()
|
||||
def register_torch_buffer(mod, _name, _buffer): _torch_modules_with_buffers.add(mod)
|
||||
def get_real_tinygrad_buffers():
|
||||
res = set()
|
||||
for mod in _torch_modules_with_buffers:
|
||||
for _,b in mod.named_buffers(recurse=False):
|
||||
if b is not None and b.is_tiny:
|
||||
res.add(unwrap(b))
|
||||
return res
|
||||
torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
|
||||
# this implementation is needed to allow the batchnorm kernels to fuse in e.g. mnist training
# aten::native_batch_norm does more than Tensor.batchnorm
@torch.library.impl("aten::native_batch_norm", "privateuseone")
def native_batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps):
  input_t, weight_t, bias_t = unwrap(input), unwrap(weight) if weight is not None else None, unwrap(bias) if bias is not None else None
  running_mean_t, running_var_t = unwrap(running_mean) if running_mean is not None else None, unwrap(running_var) if running_var is not None else None
  if training:
    batch_var, batch_mean = input_t.var_mean(axis=tuple(x for x in range(input_t.ndim) if x != 1), correction=0)
    batch_invstd = batch_var.add(eps).rsqrt()
    out = input_t.batchnorm(weight_t, bias_t, batch_mean, batch_invstd)
    if running_mean_t is not None and running_var_t is not None:
      numel_ratio = input_t.numel() / (input_t.numel() - input_t.shape[1])
      running_mean_t.assign((1 - momentum) * running_mean_t + momentum * batch_mean.detach())
      running_var_t.assign((1 - momentum) * running_var_t + momentum * numel_ratio * batch_var.detach())
    return wrap(out), wrap(batch_mean), wrap(batch_invstd)
  else:
    out = input_t.batchnorm(weight_t, bias_t, running_mean_t, running_var_t.add(eps).rsqrt())
    return wrap(out), wrap(running_mean_t), wrap(running_var_t.add(eps).rsqrt())

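For reference, the training-mode statistics match torch's native_batch_norm convention: the biased (correction=0) batch variance is used for normalization, while the running variance is updated with the unbiased estimate, which is what the numel_ratio factor provides (numel/(numel - C) equals n/(n - 1) for n samples per channel). A NumPy sketch of the same math (illustration only, ignoring weight/bias):

import numpy as np

x = np.random.randn(8, 4, 3, 3).astype(np.float32)    # (N, C, H, W)
momentum, eps = 0.1, 1e-5
running_mean, running_var = np.zeros(4, np.float32), np.ones(4, np.float32)

axes = (0, 2, 3)                                       # every axis except the channel axis
batch_mean = x.mean(axis=axes)
batch_var = x.var(axis=axes)                           # correction=0: biased variance, used to normalize
invstd = 1.0 / np.sqrt(batch_var + eps)

n = x.size / x.shape[1]                                # per-channel sample count N*H*W
numel_ratio = x.size / (x.size - x.shape[1])
assert np.isclose(numel_ratio, n / (n - 1))            # the unbiasing factor for the running variance
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * numel_ratio * batch_var

out = (x - batch_mean.reshape(1, -1, 1, 1)) * invstd.reshape(1, -1, 1, 1)   # then scale/shift by weight/bias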
from torch.nn.modules import Module
|
||||
def param_hook(_grad):
|
||||
if _grad is not None and _grad.is_tiny: Tensor.realize(unwrap(_grad))
|
||||
def module_hook(module:Module, _name, _submodule):
|
||||
for param in _submodule.parameters(recurse=False):
|
||||
if param.requires_grad: param.register_hook(param_hook)
|
||||
torch.nn.modules.module.register_module_module_registration_hook(module_hook)
|
||||
@torch.library.impl("aten::native_batch_norm_backward", "privateuseone")
|
||||
def native_batch_norm_backward(grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask):
|
||||
grad_out_t, input_t = unwrap(grad_out), unwrap(input)
|
||||
weight_t = unwrap(weight) if weight is not None else None
|
||||
save_mean_t = unwrap(save_mean)
|
||||
save_invstd_t = unwrap(save_invstd)
|
||||
out = input_t.batchnorm(weight_t, None, save_mean_t, save_invstd_t)
|
||||
targets = [t for t, m in zip([input_t, weight_t], output_mask[:2]) if t is not None and m]
|
||||
if targets:
|
||||
grads = out.gradient(*targets, gradient=grad_out_t)
|
||||
grad_input = grads.pop(0) if output_mask[0] else None
|
||||
grad_weight = grads.pop(0) if output_mask[1] and weight_t is not None else None
|
||||
else:
|
||||
grad_input, grad_weight = None, None
|
||||
grad_bias = grad_out_t.sum(axis=tuple(x for x in range(grad_out_t.ndim) if x != 1)) if output_mask[2] else None
|
||||
return (wrap(grad_input) if grad_input is not None else None,
|
||||
wrap(grad_weight) if grad_weight is not None else None,
|
||||
wrap(grad_bias) if grad_bias is not None else None)
|
||||
|
||||
def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
|
||||
tinygrad_tensors = []
|
||||
for param_group in optimizer.param_groups:
|
||||
for param in param_group["params"]:
|
||||
if param is None: continue
|
||||
tinygrad_tensors.append(param.data)
|
||||
for state_dict in optimizer.state.values():
|
||||
for _, value in state_dict.items():
|
||||
if torch.is_tensor(value): tinygrad_tensors.append(value)
|
||||
real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if x.is_tiny]
|
||||
real_tinygrad_tensors += get_real_tinygrad_buffers()
|
||||
if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
|
||||
# _pad_circular is not CompositeImplicitAutograd (unlike reflect/replicate pad)
|
||||
# we need torch.autograd.Function with explicit AutogradPrivateUse1 registration
|
||||
class _PadCircular(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, padding):
|
||||
ctx.save_for_backward(input)
|
||||
ctx.padding = padding
|
||||
return pad_forward(input, padding, mode="circular")
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
return pad_backward(grad_output, input, ctx.padding, mode="circular"), None
|
||||
|
||||
_optimizer_init = torch.optim.Optimizer.__init__
|
||||
def _optimizer_patched_init(self, *args, **kwargs):
|
||||
_optimizer_init(self, *args, **kwargs)
|
||||
self.register_step_post_hook(realize_optimizer_step)
|
||||
torch.optim.Optimizer.__init__ = _optimizer_patched_init
|
||||
@torch.library.impl("aten::_pad_circular", "privateuseone")
|
||||
def _pad_circular(self, padding): return _PadCircular.apply(self, padding)
|
||||
|
||||
@torch.library.impl("aten::_pad_circular", "AutogradPrivateUse1")
|
||||
def _pad_circular_autograd(self, padding): return _PadCircular.apply(self, padding)
|
||||
|
||||
# only needed for test_diag_backward_gradient_values
# was going through torch before, but now we are using tinygrad directly and tracking views
# Tensor.diagonal does not support all cases exercised by the tests
@torch.library.impl("aten::diagonal", "privateuseone")
@wrap_view_op
def diagonal(self, offset=0, dim1=0, dim2=1):
  if offset != 0: raise NotImplementedError(f"diagonal with {offset=} not implemented")
  dim1, dim2 = dim1 % self.ndim, dim2 % self.ndim
  if dim1 != self.ndim - 2 or dim2 != self.ndim - 1: raise NotImplementedError(f"diagonal with {dim1=}, {dim2=} not implemented, only last two dims supported")
  batch_shape, m, n = self.shape[:-2], self.shape[-2], self.shape[-1]
  diag_len = min(m, n)
  return self.reshape(*batch_shape, m*n).pad(tuple((0,0) for _ in batch_shape) + ((0, diag_len),)).reshape(*batch_shape, diag_len, n+1)[..., :, 0]

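The reshape-pad-reshape trick works because, after flattening a row-major (m, n) block and padding it by diag_len, re-viewing it as (diag_len, n+1) makes column 0 step through the storage with stride n+1, which is exactly the diagonal. A NumPy check (illustration only):

import numpy as np
m, n = 3, 4
x = np.arange(m * n).reshape(m, n).astype(np.float32)
diag_len = min(m, n)
flat = np.concatenate([x.reshape(m * n), np.zeros(diag_len, np.float32)])  # pad the flat block by diag_len
diag = flat.reshape(diag_len, n + 1)[:, 0]       # column 0 walks the storage with stride n+1: the diagonal
assert np.array_equal(diag, np.diag(x))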
@@ -1,12 +1,13 @@
from PIL import Image
from tinygrad.helpers import getenv
import torch, torchvision, pathlib
from tinygrad.helpers import getenv, GlobalCounters
import torch, torchvision, pathlib, warnings
import torchvision.transforms as transforms
import extra.torch_backend.backend
device = "tiny"
torch.set_default_device(device)

if __name__ == "__main__":
  GlobalCounters.reset()
  img = Image.open(pathlib.Path(__file__).parent.parent.parent / "test/models/efficientnet/Chicken.jpg").convert('RGB')
  transform = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
@@ -19,3 +20,10 @@ if __name__ == "__main__":
  out = model(img).detach().cpu().numpy()
  print("output:", out.shape, out.argmax())
  assert out.argmax() == 7 # cock

  kernel_count = GlobalCounters.kernel_count
  assert kernel_count > 0, "No kernels, test failed"
  expected_kernels = 228
  expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
  if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
  assert kernel_count <= expected_kernels, f"{expectation}"
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import torch
|
||||
import numpy as np
|
||||
from tinygrad.helpers import getenv, Context, GlobalCounters
|
||||
from tinygrad.helpers import getenv, GlobalCounters
|
||||
if getenv("TINY_BACKEND2"):
|
||||
import extra.torch_backend.backend2
|
||||
device = "cpu"
|
||||
@@ -25,7 +25,7 @@ class TestTorchBackend(unittest.TestCase):
|
||||
a = torch.ones(4, device=device)
|
||||
np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])
|
||||
|
||||
def test_numpy_ones(self):
|
||||
def test_numpy_ones_int32(self):
|
||||
a = torch.ones(4, dtype=torch.int32, device=device)
|
||||
assert a.dtype == torch.int32
|
||||
np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])
|
||||
@@ -219,7 +219,6 @@ class TestTorchBackend(unittest.TestCase):
|
||||
a = torch.ones(4, device=device)
|
||||
print(str(a))
|
||||
|
||||
@unittest.skip("failed")
|
||||
def test_floor_div(self):
|
||||
a = torch.tensor([10., 7., 5.], device=device)
|
||||
b = torch.tensor([3., 2., 2.], device=device)
|
||||
@@ -248,5 +247,672 @@ class TestTorchBackend(unittest.TestCase):
|
||||
def test_diagonal_rectangular(self): self._test_diagonal(4, 5, 6)
|
||||
def test_diagonal_4d(self): self._test_diagonal(2, 3, 4, 5)
|
||||
|
||||
def test_pad_circular_simple(self):
|
||||
a = torch.arange(4, dtype=torch.float32, device=device).reshape(1,1,2,2)
|
||||
padded = torch.nn.functional.pad(a, (1,1,1,1), mode="circular")
|
||||
expected = np.array([[[[3.,2.,3.,2.], [1.,0.,1.,0.], [3.,2.,3.,2.], [1.,0.,1.,0.]]]], dtype=np.float32)
|
||||
np.testing.assert_allclose(padded.cpu().numpy(), expected)
|
||||
|
||||
def test_pad_circular_backward(self):
|
||||
a = torch.arange(4, dtype=torch.float32, device=device).reshape(1,1,2,2).requires_grad_(True)
|
||||
padded = torch.nn.functional.pad(a, (1,1,1,1), mode="circular")
|
||||
loss = padded.sum()
|
||||
loss.backward()
|
||||
expected_grad = np.array([[[[4., 4.], [4., 4.]]]], dtype=np.float32)
|
||||
np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad)
|
||||
|
||||
|
||||
def test_matmul_backward(self):
|
||||
x = torch.randn(3, 4, device=device, dtype=torch.float32, requires_grad=True)
|
||||
y = torch.randn(4, 5, device=device, dtype=torch.float32, requires_grad=True)
|
||||
z = (x @ y).sum()
|
||||
z.backward()
|
||||
assert x.grad is not None
|
||||
assert y.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
assert y.grad.shape == y.shape
|
||||
|
||||
def test_matmul_broadcast_backward(self):
|
||||
x = torch.randn(2, 3, 4, device=device, dtype=torch.float32, requires_grad=True)
|
||||
y = torch.randn(4, 5, device=device, dtype=torch.float32, requires_grad=True)
|
||||
z = (x @ y).sum()
|
||||
z.backward()
|
||||
assert x.grad is not None
|
||||
assert y.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
assert y.grad.shape == y.shape
|
||||
|
||||
def test_diag_vector_to_matrix(self):
|
||||
vec = torch.tensor([1., 2., 3., 4., 5.], dtype=torch.float32, device=device)
|
||||
mat = torch.diag(vec)
|
||||
expected = np.diag([1., 2., 3., 4., 5.])
|
||||
np.testing.assert_allclose(mat.cpu().numpy(), expected, rtol=1e-5)
|
||||
assert mat.shape == (5, 5)
|
||||
|
||||
def test_diagonal_matrix_to_vector(self):
|
||||
mat = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], dtype=torch.float32, device=device)
|
||||
vec = torch.linalg.diagonal(mat)
|
||||
expected = np.array([1., 5., 9.])
|
||||
np.testing.assert_allclose(vec.cpu().numpy(), expected, rtol=1e-5)
|
||||
assert vec.shape == (3,)
|
||||
|
||||
def test_permute_2(self):
|
||||
a = torch.randn(2, 3, 4, dtype=torch.float32, device=device)
|
||||
b = a.permute(2, 0, 1)
|
||||
assert b.shape == (4, 2, 3)
|
||||
np.testing.assert_equal(b.cpu().numpy(), a.cpu().numpy().transpose(2, 0, 1))
|
||||
|
||||
def test_batchnorm_unsqueeze(self):
|
||||
bn = torch.nn.BatchNorm2d(4).to(device)
|
||||
x = torch.randn(8, 4, 3, 3, device=device)
|
||||
out = bn(x)
|
||||
self.assertEqual(out.shape, x.shape)
|
||||
|
||||
def test_slice_inplace_zero(self):
|
||||
a = torch.ones((3, 3), device=device)
|
||||
b = a[1:, 1:]
|
||||
b.zero_()
|
||||
expected = np.array([[1., 1., 1.],
|
||||
[1., 0., 0.],
|
||||
[1., 0., 0.]])
|
||||
np.testing.assert_equal(a.cpu().numpy(), expected)
|
||||
|
||||
def test_slice_inplace_fill(self):
|
||||
a = torch.ones((3, 3), device=device)
|
||||
b = a[1:, 1:]
|
||||
b.fill_(5.0)
|
||||
expected = np.array([[1., 1., 1.],
|
||||
[1., 5., 5.],
|
||||
[1., 5., 5.]])
|
||||
np.testing.assert_equal(a.cpu().numpy(), expected)
|
||||
|
||||
def test_fill_tensor_value(self):
|
||||
a = torch.zeros((2, 2), dtype=torch.float32, device=device)
|
||||
value = torch.tensor(3, dtype=torch.int64, device=device)
|
||||
a.fill_(value)
|
||||
expected = np.full((2, 2), 3, dtype=np.float32)
|
||||
np.testing.assert_equal(a.cpu().numpy(), expected)
|
||||
|
||||
def test_slice_inplace_mul(self):
|
||||
a = torch.ones((3, 3), device=device)
|
||||
b = a[1:, 1:]
|
||||
b *= 2
|
||||
expected = np.array([[1., 1., 1.],
|
||||
[1., 2., 2.],
|
||||
[1., 2., 2.]])
|
||||
np.testing.assert_equal(a.cpu().numpy(), expected)
|
||||
|
||||
def test_permute_slice_zero(self):
|
||||
a = torch.ones((3, 3), device=device)
|
||||
b = a[1:, 1:].permute(1, 0)
|
||||
b.zero_()
|
||||
expected = np.array([[1., 1., 1.],
|
||||
[1., 0., 0.],
|
||||
[1., 0., 0.]])
|
||||
np.testing.assert_equal(a.cpu().numpy(), expected)
|
||||
|
||||
def test_permute_slice_mul(self):
|
||||
a = torch.ones((3, 3), device=device)
|
||||
b = a[1:, 1:].permute(1, 0)
|
||||
b *= 2
|
||||
expected = np.array([[1., 1., 1.],
|
||||
[1., 2., 2.],
|
||||
[1., 2., 2.]])
|
||||
np.testing.assert_equal(a.cpu().numpy(), expected)
|
||||
|
||||
def test_simple_slice_setitem(self):
|
||||
a = torch.tensor([10, 20, 30], device=device)
|
||||
a[1] = 99
|
||||
np.testing.assert_equal(a.cpu().numpy(), [10, 99, 30])
|
||||
|
||||
def test_2d_slice_setitem(self):
|
||||
a = torch.zeros((3, 3), device=device)
|
||||
a[1, 2] = 99
|
||||
self.assertEqual(a[1, 2].item(), 99)
|
||||
self.assertEqual(a.sum().item(), 99)
|
||||
|
||||
def test_view_copy(self):
|
||||
a = torch.tensor([10, 20, 30], device=device)
|
||||
view = a[1]
|
||||
view.copy_(torch.tensor(88, device=device))
|
||||
np.testing.assert_equal(a.cpu().numpy(), [10, 88, 30])
|
||||
|
||||
def test_diag_2d_input(self):
|
||||
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
|
||||
d = torch.diag(a)
|
||||
np.testing.assert_equal(d.cpu().numpy(), [1, 5, 9])
|
||||
|
||||
def test_diag_1d_input(self):
|
||||
a = torch.tensor([1, 2, 3], device=device)
|
||||
d = torch.diag(a)
|
||||
expected = [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
|
||||
np.testing.assert_equal(d.cpu().numpy(), expected)
|
||||
|
||||
def test_permute_view_tracking(self):
|
||||
a = torch.ones((2, 3, 4), device=device)
|
||||
b = a.permute(2, 0, 1)
|
||||
self.assertEqual(b.shape, (4, 2, 3))
|
||||
|
||||
def test_detach_view_creation(self):
|
||||
a = torch.tensor([1.0, 2.0, 3.0], device=device)
|
||||
b = a.detach()
|
||||
np.testing.assert_equal(b.cpu().numpy(), [1.0, 2.0, 3.0])
|
||||
|
||||
def test_view_zero_inplace(self):
|
||||
a = torch.ones((4, 4), device=device)
|
||||
view = a[1:3, 1:3]
|
||||
view.zero_()
|
||||
self.assertEqual(view.sum().item(), 0)
|
||||
|
||||
def test_view_fill_inplace(self):
|
||||
a = torch.zeros((4, 4), device=device)
|
||||
view = a[1:3, 1:3]
|
||||
view.fill_(5)
|
||||
self.assertEqual(view.sum().item(), 20)
|
||||
|
||||
def test_permute_contiguous(self):
|
||||
a = torch.tensor([[1, 2], [3, 4]], device=device)
|
||||
b = a.permute(1, 0)
|
||||
c = b.contiguous()
|
||||
expected = [[1, 3], [2, 4]]
|
||||
np.testing.assert_equal(c.cpu().numpy(), expected)
|
||||
|
||||
def test_diag_2d_extract_diagonal(self):
|
||||
a = torch.tensor([[1, 2], [3, 4]], device=device)
|
||||
result = torch.diag(a)
|
||||
np.testing.assert_equal(result.cpu().numpy(), [1, 4])
|
||||
|
||||
def test_slice_inplace_multiply_offset_preservation(self):
|
||||
a = torch.tensor([1, 2, 3], device=device)
|
||||
a[1:] *= 2
|
||||
np.testing.assert_equal(a.cpu().numpy(), [1, 4, 6])
|
||||
|
||||
def test_slice_inplace_mul_pattern(self):
|
||||
a = torch.tensor([1, 2, 3, 4], device=device)
|
||||
a[:2] *= 3
|
||||
a[2:] *= 2
|
||||
np.testing.assert_equal(a.cpu().numpy(), [3, 6, 6, 8])
|
||||
|
||||
def test_chained_slice_column(self):
|
||||
a = torch.arange(16, dtype=torch.float32, device=device).reshape(4, 4)
|
||||
torch_res = a[:, 1:2][:, 0:1].cpu().numpy()
|
||||
cpu_res = torch.arange(16, dtype=torch.float32).reshape(4, 4)[:, 1:2][:, 0:1].numpy()
|
||||
np.testing.assert_equal(torch_res, cpu_res)
|
||||
|
||||
def test_slice_with_step(self):
|
||||
a = torch.arange(20, dtype=torch.float32, device=device)
|
||||
torch_res = a[::2][1:4].cpu().numpy()
|
||||
cpu_res = torch.arange(20, dtype=torch.float32)[::2][1:4].numpy()
|
||||
np.testing.assert_equal(torch_res, cpu_res)
|
||||
|
||||
def test_slice_negative_dim(self):
|
||||
a = torch.arange(13, dtype=torch.int32, device=device).repeat(8, 1)
|
||||
torch_chunks = a.chunk(3, -1)
|
||||
cpu_chunks = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(3, -1)
|
||||
assert len(torch_chunks) == len(cpu_chunks)
|
||||
for i in range(len(torch_chunks)):
|
||||
np.testing.assert_equal(torch_chunks[i].cpu().numpy(), cpu_chunks[i].numpy())
|
||||
|
||||
def test_dot_vector_matrix(self):
|
||||
a = torch.arange(65, dtype=torch.float32, device=device)
|
||||
b = torch.arange(65*45, dtype=torch.float32, device=device).reshape(65, 45)
|
||||
torch_res = a.matmul(b).reshape(-1).cpu().numpy()
|
||||
cpu_res = torch.arange(65, dtype=torch.float32).matmul(torch.arange(65*45, dtype=torch.float32).reshape(65, 45)).numpy()
|
||||
np.testing.assert_equal(torch_res, cpu_res)
|
||||
|
||||
def test_alias_passthrough(self):
|
||||
a = torch.randn(3, 3, device=device)
|
||||
alias_view = torch.ops.aten.alias(a)
|
||||
alias_view += 1
|
||||
np.testing.assert_equal(a.cpu().numpy(), alias_view.cpu().numpy())
|
||||
|
||||
def test_split_simple_vector(self):
|
||||
a = torch.arange(10, dtype=torch.float32, device=device)
|
||||
torch_chunks = a.split([1,4,5])
|
||||
cpu_chunks = torch.arange(10, dtype=torch.float32).split([1,4,5])
|
||||
for tc, cc in zip(torch_chunks, cpu_chunks):
|
||||
np.testing.assert_equal(tc.cpu().numpy(), cc.cpu().numpy())
|
||||
|
||||
def test_split_matches_torch(self):
|
||||
a = torch.arange(10, dtype=torch.float32, device=device)
|
||||
torch_chunks = a.split([1,4,5])
|
||||
tiny_chunks = [chunk.cpu().numpy() for chunk in torch_chunks]
|
||||
cpu_chunks = [torch.arange(10, dtype=torch.float32).split([1,4,5])[i].numpy() for i in range(3)]
|
||||
for tr, cr in zip(tiny_chunks, cpu_chunks): np.testing.assert_equal(tr, cr)
|
||||
|
||||
def test_sum_matches_torch(self):
|
||||
a = torch.arange(6, dtype=torch.float32, device=device).reshape(2,3)
|
||||
torch_res = a.sum().cpu().numpy()
|
||||
cpu_res = torch.arange(6, dtype=torch.float32).reshape(2,3).sum().numpy()
|
||||
np.testing.assert_equal(torch_res, cpu_res)
|
||||
|
||||
def test_view_matches_torch(self):
|
||||
a = torch.arange(6, dtype=torch.float32, device=device)
|
||||
torch_res = a.view(2, 3).cpu().numpy()
|
||||
cpu_res = torch.arange(6, dtype=torch.float32).view(2, 3).numpy()
|
||||
np.testing.assert_equal(torch_res, cpu_res)
|
||||
|
||||
def test_view_zero_with_indices(self):
|
||||
a = torch.tensor([1, 2, 3, 4], device=device)
|
||||
a[1:3].zero_()
|
||||
np.testing.assert_equal(a.cpu().numpy(), [1, 0, 0, 4])
|
||||
|
||||
def test_view_fill_with_indices(self):
|
||||
a = torch.tensor([1, 2, 3, 4], device=device)
|
||||
a[::2].fill_(9)
|
||||
np.testing.assert_equal(a.cpu().numpy(), [9, 2, 9, 4])
|
||||
|
||||
def test_nested_slice_inplace_ops(self):
|
||||
a = torch.tensor([1, 2, 3, 4, 5, 6], device=device)
|
||||
a[:3] += 10
|
||||
a[3:] *= 2
|
||||
np.testing.assert_equal(a.cpu().numpy(), [11, 12, 13, 8, 10, 12])
|
||||
|
||||
def test_diag_1d(self):
|
||||
a = torch.tensor([1, 2, 3], device=device)
|
||||
result = torch.diag(a)
|
||||
expected = [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
|
||||
np.testing.assert_equal(result.cpu().numpy(), expected)
|
||||
|
||||
def test_diag_backward(self):
|
||||
a = torch.randn(5, dtype=torch.float32, device=device, requires_grad=True)
|
||||
b = torch.diag(a)
|
||||
b.sum().backward()
|
||||
assert a.grad is not None
|
||||
|
||||
def test_diagonal(self):
|
||||
a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], dtype=torch.float32, device=device, requires_grad=True)
|
||||
b = torch.diagonal(a)
|
||||
expected = torch.tensor([1., 5., 9.], dtype=torch.float32)
|
||||
self.assertEqual(b.shape, (3,))
|
||||
np.testing.assert_allclose(b.detach().cpu().numpy(), expected.numpy(), rtol=1e-5)
|
||||
|
||||
def test_diagonal_backward(self):
|
||||
a = torch.randn(5, 5, dtype=torch.float32, device=device, requires_grad=True)
|
||||
b = torch.diagonal(a)
|
||||
b.sum().backward()
|
||||
assert a.grad is not None
|
||||
|
||||
def test_expand_backward(self):
|
||||
a = torch.randn(4, 3, 1, 6, dtype=torch.float32, device=device, requires_grad=True)
    b = a.expand(4, 3, 2, 6)
    b.sum().backward()
    assert a.grad is not None

  def test_einsum_backward(self):
    a = torch.randn(10, 10, dtype=torch.float32, device=device, requires_grad=True)
    b = torch.einsum('ij->ji', a)
    b.sum().backward()
    assert a.grad is not None

  def test_diag_backward_gradient_values(self):
    a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.diag(a)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.ones(3, dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_diag_backward_gradient_values_2d_to_1d(self):
    a = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0],
                      [7.0, 8.0, 9.0]], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.diagonal(a)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.tensor([[1.0, 0.0, 0.0],
                                  [0.0, 1.0, 0.0],
                                  [0.0, 0.0, 1.0]], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_expand_backward_gradient_values(self):
    a = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float32, device=device, requires_grad=True)
    b = a.expand(3, 4)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.tensor([[4.0], [4.0], [4.0]], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_expand_backward_with_leading_dims(self):
    a = torch.tensor([[1.0, 2.0]], dtype=torch.float32, device=device, requires_grad=True)
    b = a.expand(3, 1, 2)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.tensor([[3.0, 3.0]], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_diag_2d_to_1d_backward(self):
    a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.diag(a)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_expand_complex_backward(self):
    a = torch.tensor([[[1.0, 2.0]]], dtype=torch.float32, device=device, requires_grad=True)
    b = a.expand(2, 3, 2)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.tensor([[[6.0, 6.0]]], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_diag_backward_with_scaling(self):
    a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.diag(a)
    loss = (b * torch.tensor([[2.0, 0.0, 0.0],
                              [0.0, 3.0, 0.0],
                              [0.0, 0.0, 4.0]], device=device)).sum()
    loss.backward()
    expected_grad = torch.tensor([2.0, 3.0, 4.0], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_repeat_basic(self):
    a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
    b = a.repeat(2, 1)
    expected = torch.tensor([[1, 2, 3], [1, 2, 3]], dtype=torch.float32)
    np.testing.assert_equal(b.cpu().numpy(), expected.numpy())

  def test_repeat_multidim(self):
    a = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
    b = a.repeat(2, 3)
    expected = torch.arange(6, dtype=torch.float32).reshape(2, 3).repeat(2, 3)
    np.testing.assert_equal(b.cpu().numpy(), expected.numpy())

  def test_repeat_backward(self):
    a = torch.tensor([[1.0, 2.0]], dtype=torch.float32, device=device, requires_grad=True)
    b = a.repeat(3, 2)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.tensor([[6.0, 6.0]], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_cumsum_1d(self):
    a = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device=device)
    b = torch.cumsum(a, dim=0)
    expected = torch.tensor([1, 3, 6, 10], dtype=torch.float32)
    np.testing.assert_equal(b.cpu().numpy(), expected.numpy())

  def test_cumsum_2d(self):
    a = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
    b = torch.cumsum(a, dim=0)
    expected = torch.arange(12, dtype=torch.float32).reshape(3, 4).cumsum(dim=0)
    np.testing.assert_equal(b.cpu().numpy(), expected.numpy())

    c = torch.cumsum(a, dim=1)
    expected = torch.arange(12, dtype=torch.float32).reshape(3, 4).cumsum(dim=1)
    np.testing.assert_equal(c.cpu().numpy(), expected.numpy())

  def test_cumsum_backward(self):
    a = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.cumsum(a, dim=0)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.tensor([4.0, 3.0, 2.0, 1.0], dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_constant_pad_nd_1d(self):
    a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
    b = torch.nn.functional.pad(a, (1, 2), mode='constant', value=0)
    expected = torch.tensor([0, 1, 2, 3, 0, 0], dtype=torch.float32)
    np.testing.assert_equal(b.cpu().numpy(), expected.numpy())

  def test_constant_pad_nd_2d(self):
    a = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
    b = torch.nn.functional.pad(a, (1, 1, 1, 1), mode='constant', value=0)
    expected = torch.nn.functional.pad(torch.arange(6, dtype=torch.float32).reshape(2, 3), (1, 1, 1, 1), mode='constant', value=0)
    np.testing.assert_equal(b.cpu().numpy(), expected.numpy())

  def test_constant_pad_nd_2d_backward(self):
    a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.nn.functional.pad(a, (1, 1, 1, 1), mode='constant', value=0)
    loss = b.sum()
    loss.backward()
    expected_grad = torch.ones((2, 2), dtype=torch.float32)
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected_grad.numpy(), rtol=1e-5)

  def test_negative_strides_cumsum_backward(self):
    a = torch.randn(5, device=device, requires_grad=True)
    b = torch.cumsum(a, dim=0)
    b.sum().backward()
    grad = a.grad.cpu().numpy()
    self.assertEqual(len(grad), 5)

  def test_cumsum_fix_gradient_values(self):
    a = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.cumsum(a, dim=0)
    loss = b.sum()
    loss.backward()
    expected = np.array([4.0, 3.0, 2.0, 1.0])
    np.testing.assert_allclose(a.grad.cpu().numpy(), expected, rtol=1e-5)

  def test_diag_1d_to_2d(self):
    a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device, requires_grad=True)
    b = torch.diag(a)
    expected = [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
    np.testing.assert_equal(b.detach().cpu().numpy(), expected)

  def test_diag_2d_to_1d(self):
    c = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device)
    d = torch.diag(c)
    np.testing.assert_equal(d.cpu().numpy(), [1, 5, 9])

  def test_biased_conv2d(self):
    # Test case for two sequential conv2d with the same weights/bias and a ReLU in between; this is a special case from test_ops.py
    torch.manual_seed(0)
    C = 8
    x_cpu = torch.randn(1, C, 5, 5, requires_grad=True)
    w_cpu = torch.randn(C, C, 1, 1, requires_grad=True)
    b_cpu = torch.randn(C, requires_grad=True)
    x_tiny = x_cpu.detach().to(device).requires_grad_(True)
    w_tiny = w_cpu.detach().to(device).requires_grad_(True)
    b_tiny = b_cpu.detach().to(device).requires_grad_(True)
    out_cpu = torch.nn.functional.conv2d(torch.nn.functional.conv2d(x_cpu, w_cpu, b_cpu).relu(), w_cpu, b_cpu)
    out_tiny = torch.nn.functional.conv2d(torch.nn.functional.conv2d(x_tiny, w_tiny, b_tiny).relu(), w_tiny, b_tiny)
    grad_out = torch.randn_like(out_cpu)
    out_cpu.backward(grad_out)
    out_tiny.backward(grad_out.to(device))
    np.testing.assert_allclose(x_tiny.grad.cpu().numpy(), x_cpu.grad.numpy(), atol=1e-4, rtol=1e-3)
    np.testing.assert_allclose(w_tiny.grad.cpu().numpy(), w_cpu.grad.numpy(), atol=1e-4, rtol=1e-3)
    np.testing.assert_allclose(b_tiny.grad.cpu().numpy(), b_cpu.grad.numpy(), atol=1e-4, rtol=1e-3)

from tinygrad import Tensor
class TestBackendHelpers(unittest.TestCase):

  def test_calculate_storage_offset_no_shrink(self):
    t = Tensor.ones(3, 4)
    assert extra.torch_backend.backend.calculate_storage_offset(t) == 0

  def test_calculate_storage_offset_with_shrink(self):
    t = Tensor.ones(10, 10)[2:5, 3:7]
    # strides for (10, 10) are [10, 1]
    # offset = 2*10 + 3*1 = 23
    assert extra.torch_backend.backend.calculate_storage_offset(t) == 23

  def test_calculate_storage_offset_multiple_shrinks(self):
    t = Tensor.ones(5, 6, 7)[1:3, 2:4, 3:5]
    # strides for (5, 6, 7) are [42, 7, 1]
    # offset = 1*42 + 2*7 + 3*1 = 42 + 14 + 3 = 59
    assert extra.torch_backend.backend.calculate_storage_offset(t) == 59

  def test_calculate_storage_offset_with_reshape(self):
    t = Tensor.ones(10, 10)
    orig_offset = extra.torch_backend.backend.calculate_storage_offset(t)
    assert orig_offset == 0
    t = t.reshape(100)
    assert extra.torch_backend.backend.calculate_storage_offset(t) == orig_offset

  def test_slice_values_match_torch(self):
    torch_cpu = torch.arange(100, dtype=torch.float32).reshape(10, 10)
    torch_tiny = torch_cpu.to(device)
    sliced_cpu = torch_cpu[2:5, 3:7]
    sliced_tiny = torch_tiny[2:5, 3:7]
    np.testing.assert_equal(sliced_tiny.cpu().numpy(), sliced_cpu.numpy())

  def test_slice_values_match_torch_3d(self):
    torch_cpu_3d = torch.arange(210, dtype=torch.float32).reshape(5, 6, 7)
    torch_tiny_3d = torch_cpu_3d.to(device)
    sliced_cpu_3d = torch_cpu_3d[1:3, 2:4, 3:5]
    sliced_tiny_3d = torch_tiny_3d[1:3, 2:4, 3:5]
    np.testing.assert_equal(sliced_tiny_3d.cpu().numpy(), sliced_cpu_3d.numpy())

  def test_topk_out(self):
    a = torch.tensor([1, 3, 2, 4], device=device)
    values = torch.empty(2, device=device)
    indices = torch.empty(2, dtype=torch.int64, device=device)
    ret_values, ret_indices = torch.topk(a, k=2, out=(values, indices))
    np.testing.assert_equal(values.cpu().numpy(), [4, 3])
    np.testing.assert_equal(indices.cpu().numpy(), [3, 1])
    assert ret_values is values
    assert ret_indices is indices

  def test_sort_out(self):
    a = torch.tensor([3, 1, 4, 2], device=device)
    values = torch.empty(4, device=device)
    indices = torch.empty(4, dtype=torch.int64, device=device)
    ret_values, ret_indices = torch.sort(a, out=(values, indices))
    np.testing.assert_equal(values.cpu().numpy(), [1, 2, 3, 4])
    np.testing.assert_equal(indices.cpu().numpy(), [1, 3, 0, 2])
    assert ret_values is values
    assert ret_indices is indices

  def test_cat_out(self):
    a = torch.tensor([1, 2], device=device)
    b = torch.tensor([3, 4], device=device)
    out = torch.empty(4, device=device)
    ret = torch.cat([a, b], out=out)
    np.testing.assert_equal(out.cpu().numpy(), [1, 2, 3, 4])
    assert ret is out

  def test_scatter_add_out(self):
    src = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device, dtype=torch.float32)
    index = torch.tensor([[0, 1, 2], [0, 1, 2]], device=device)
    input = torch.zeros(3, 3, device=device, dtype=torch.float32)
    out = torch.zeros(3, 3, device=device, dtype=torch.float32)
    ret = torch.scatter_add(input, 0, index, src, out=out)
    expected = torch.tensor([[5, 0, 0], [0, 7, 0], [0, 0, 9]], dtype=torch.float32)
    np.testing.assert_allclose(out.cpu().numpy(), expected.cpu().numpy())
    assert ret is out

  def test_floor_divide_inplace_identity(self):
    x = torch.tensor([10, 20, 30, 40], dtype=torch.int32, device=device)
    y = torch.tensor([2, 4, 5, 8], dtype=torch.int32, device=device)
    ret = x.floor_divide_(y)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [5, 5, 6, 5])

  def test_lshift_inplace_identity(self):
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
    ret = x.__ilshift__(2)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [4, 8, 12, 16])

  def test_rshift_inplace_identity(self):
    x = torch.tensor([16, 32, 48, 64], dtype=torch.int32, device=device)
    ret = x.__irshift__(2)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [4, 8, 12, 16])

  def test_relu_inplace_identity(self):
    x = torch.tensor([-1.0, 2.0, -3.0, 4.0], device=device)
    ret = x.relu_()
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [0.0, 2.0, 0.0, 4.0])

  def test_random_inplace_identity(self):
    x = torch.zeros(10, dtype=torch.int32, device=device)
    ret = x.random_()
    assert ret is x
    assert x.shape == (10,)

  def test_random_from_inplace_identity(self):
    x = torch.zeros(10, dtype=torch.int32, device=device)
    ret = x.random_(5, 10)
    assert ret is x
    # values should be in range [5, 10)
    assert torch.all(x >= 5).item() and torch.all(x < 10).item()

  def test_uniform_inplace_identity(self):
    x = torch.zeros(10, device=device)
    ret = x.uniform_(0.0, 1.0)
    assert ret is x
    # values should be in range [0, 1)
    assert torch.all(x >= 0.0).item() and torch.all(x < 1.0).item()

  def test_normal_inplace_identity(self):
    x = torch.zeros(100, device=device)
    ret = x.normal_(0.0, 1.0)
    assert ret is x
    # just check that values changed from zeros
    assert not torch.all(x == 0.0).item()

  def test_logical_or_inplace_identity(self):
    x = torch.tensor([True, False, True, False], device=device)
    y = torch.tensor([False, False, True, True], device=device)
    ret = x.logical_or_(y)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [True, False, True, True])

  def test_masked_fill_scalar_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    mask = torch.tensor([True, False, True, False], device=device)
    ret = x.masked_fill_(mask, 0.0)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [0.0, 2.0, 0.0, 4.0])

  def test_masked_fill_tensor_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    mask = torch.tensor([True, False, True, False], device=device)
    value = torch.tensor(99.0, device=device)
    ret = x.masked_fill_(mask, value)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [99.0, 2.0, 99.0, 4.0])

  def test_zero_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    ret = x.zero_()
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [0.0, 0.0, 0.0, 0.0])

  def test_fill_scalar_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    ret = x.fill_(5.0)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [5.0, 5.0, 5.0, 5.0])

  def test_fill_tensor_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    value = torch.tensor(7.0, device=device)
    ret = x.fill_(value)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [7.0, 7.0, 7.0, 7.0])

  def test_add_tensor_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    y = torch.tensor([10.0, 20.0, 30.0, 40.0], device=device)
    ret = x.add_(y)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [11.0, 22.0, 33.0, 44.0])

  def test_add_scalar_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    ret = x.add_(10.0)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [11.0, 12.0, 13.0, 14.0])

  def test_mul_tensor_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    y = torch.tensor([2.0, 3.0, 4.0, 5.0], device=device)
    ret = x.mul_(y)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [2.0, 6.0, 12.0, 20.0])

  def test_mul_scalar_inplace_identity(self):
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
    ret = x.mul_(2.0)
    assert ret is x
    np.testing.assert_equal(x.cpu().numpy(), [2.0, 4.0, 6.0, 8.0])

if __name__ == "__main__":
  unittest.main()
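The storage-offset tests above boil down to plain stride arithmetic: the offset of a shrunk view is the dot product of the slice starts with the row-major strides of the base shape. A minimal sketch of that arithmetic follows; `row_major_strides` and `expected_storage_offset` are hypothetical helpers written here for illustration, not the backend's actual `calculate_storage_offset`.

def row_major_strides(shape):
  # e.g. (10, 10) -> [10, 1] and (5, 6, 7) -> [42, 7, 1], matching the comments in the tests above
  strides, acc = [], 1
  for dim in reversed(shape):
    strides.append(acc)
    acc *= dim
  return strides[::-1]

def expected_storage_offset(shape, starts):
  # offset = sum(start_i * stride_i) for a slice that begins at `starts` in each dimension
  return sum(start * stride for start, stride in zip(starts, row_major_strides(shape)))

assert expected_storage_offset((10, 10), (2, 3)) == 23    # 2*10 + 3*1
assert expected_storage_offset((5, 6, 7), (1, 2, 3)) == 59  # 1*42 + 2*7 + 3*1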
144
extra/torch_backend/test_kernel_fusion.py
Normal file
@@ -0,0 +1,144 @@
# simple tests
import unittest
import torch
import warnings
from tinygrad.helpers import getenv, GlobalCounters
if getenv("TINY_BACKEND2"):
  import extra.torch_backend.backend2
  device = "cpu"
else:
  import extra.torch_backend.backend
  device = "tiny"


class TestKernelFusionRegression(unittest.TestCase):
  def _realize(self, t): _ = t.detach().cpu().numpy()

  def _check_kernel_count(self, fn, expected_kernels):
    torch.manual_seed(42)
    GlobalCounters.reset()
    fn().detach().cpu().numpy()
    expectation = f"{GlobalCounters.kernel_count} vs {expected_kernels} expected."
    if GlobalCounters.kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
    self.assertLessEqual(GlobalCounters.kernel_count, expected_kernels, f"{expectation}")

  def test_elementwise_fusion(self):
    def fn():
      x = torch.randn(128, 128, device=device)
      return (x + 1.0) * 2.0 - 0.5
    self._check_kernel_count(fn, 6)

  def test_relu_fusion(self):
    def fn():
      x = torch.randn(1, 3, 32, 32, device=device)
      conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)
      with torch.no_grad():
        return torch.nn.functional.relu(conv(x))
    self._check_kernel_count(fn, 8)

  def test_batchnorm_fusion(self):
    def fn():
      x = torch.randn(2, 3, 16, 16, device=device)
      conv = torch.nn.Conv2d(3, 8, 3, padding=1).to(device)
      bn = torch.nn.BatchNorm2d(8).to(device)
      bn.eval()
      with torch.no_grad():
        return torch.nn.functional.relu(bn(conv(x)))
    self._check_kernel_count(fn, 16)

  def test_reduce_fusion(self):
    def fn():
      x = torch.randn(64, 64, device=device)
      return (x * 2.0).sum()
    self._check_kernel_count(fn, 7)

  def test_matmul_elementwise_fusion(self):
    def fn():
      x = torch.randn(32, 32, device=device)
      w = torch.randn(32, 32, device=device)
      return torch.nn.functional.relu(x @ w + 1.0)
    self._check_kernel_count(fn, 6)

  def test_pooling_fusion(self):
    def fn():
      x = torch.randn(1, 8, 16, 16, device=device)
      return torch.nn.functional.max_pool2d(x * 2.0, 2)
    self._check_kernel_count(fn, 5)

  def test_residual_add_relu_fusion(self):
    def fn():
      x = torch.randn(1, 8, 16, 16, device=device)
      identity = torch.randn(1, 8, 16, 16, device=device)
      out = x + identity
      return torch.nn.functional.relu(out)
    self._check_kernel_count(fn, 6)

  def test_inplace_add_relu_fusion(self):
    def fn():
      x = torch.randn(1, 16, 32, 32, device=device)
      y = torch.randn(1, 16, 32, 32, device=device)
      x += y
      return torch.nn.functional.relu(x)
    self._check_kernel_count(fn, 6)

  def test_conv_bn_add_relu_fusion(self):
    def fn():
      x = torch.randn(1, 8, 16, 16, device=device)
      identity = torch.randn(1, 8, 16, 16, device=device)
      conv = torch.nn.Conv2d(8, 8, 3, padding=1, bias=False).to(device)
      bn = torch.nn.BatchNorm2d(8).to(device)
      bn.eval()
      with torch.no_grad():
        out = bn(conv(x))
        out += identity
        return torch.nn.functional.relu(out)
    self._check_kernel_count(fn, 16)

  def test_multiple_inplace_ops_fusion(self):
    def fn():
      x = torch.randn(64, 64, device=device)
      x += 1.0
      x *= 2.0
      return torch.nn.functional.relu(x)
    self._check_kernel_count(fn, 4)

  def test_view_inplace_no_fusion_break(self):
    def fn():
      x = torch.randn(4, 64, device=device)
      view = x[1:3]
      view += 1.0
      return x.sum()
    self._check_kernel_count(fn, 8)

  def test_batchnorm_running_stats_update(self):
    def fn():
      x = torch.randn(2, 8, 8, 8, device=device)
      bn = torch.nn.BatchNorm2d(8).to(device)
      bn.train()
      with torch.no_grad():
        return bn(x)
    self._check_kernel_count(fn, 10)

  # a minimal version of extra/other_mnist/beautiful_mnist_torch.py to cover fusion during training with an optimizer
  def test_mnist_training_fusion(self):
    def fn():
      model = torch.nn.Sequential(
        torch.nn.Conv2d(1, 8, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2),
        torch.nn.Flatten(),
        torch.nn.Linear(8*14*14, 10)
      ).to(device)
      optimizer = torch.optim.Adam(model.parameters(), 1e-3)
      x = torch.randn(32, 1, 28, 28, device=device)
      labels = torch.randint(0, 10, (32,), device=device)
      out = model(x)
      loss = torch.nn.functional.cross_entropy(out, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      return loss
    self._check_kernel_count(fn, 33)

if __name__ == "__main__":
  unittest.main()
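Every check in the file above follows the same pattern: reset GlobalCounters, realize the result of fn on the host, then read GlobalCounters.kernel_count. A standalone sketch of that pattern, assuming the "tiny" device is registered by importing extra.torch_backend.backend exactly as the file does:

import torch
import extra.torch_backend.backend  # registers the "tiny" device, as in the test file above
from tinygrad.helpers import GlobalCounters

def count_kernels(fn):
  # reset the counter, force the result onto the host, and report how many kernels were launched
  GlobalCounters.reset()
  fn().detach().cpu().numpy()
  return GlobalCounters.kernel_count

n = count_kernels(lambda: (torch.randn(128, 128, device="tiny") + 1.0) * 2.0 - 0.5)
print(f"elementwise chain launched {n} kernels")  # test_elementwise_fusion allows at most 6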
@@ -113,16 +113,9 @@ int register_hook() {
int temp_register_hook = register_hook();

at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceIndex device_index) {
  // TODO: we have to get the dtype and the shape from the tinygrad Tensor
  std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();

  py::list views = py_obj.attr("uop").attr("st").attr("views");
  std::vector<int64_t> strides = views[views.size() - 1].attr("strides").cast<std::vector<int64_t>>();
  int64_t storage_offset = 0;
  for (auto& v: views) {
    storage_offset += v.attr("offset").cast<int64_t>(); // TODO: is this correct?
  }

  std::vector<int64_t> strides = py_obj.attr("_strides").cast<std::vector<int64_t>>();
  int64_t storage_offset = py_obj.attr("_storage_offset").cast<int64_t>();
  return at::detail::make_tensor<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
    at::DispatchKeySet(at::DispatchKey::PrivateUse1),
    c10::scalarTypeToTypeMeta(dtype),
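The hunk above moves view bookkeeping out of wrap_tensor: instead of reading strides from the last ShapeTracker view and summing every view's offset in C++, it casts precomputed `_strides` and `_storage_offset` attributes from the Python object. A rough Python sketch of what the removed block computed, assuming (as the old C++ did) that the wrapped tensor's uop exposes a ShapeTracker via uop.st; the real backend may derive these values differently:

from tinygrad import Tensor

def strides_and_offset_from_views(t: Tensor):
  # mirrors the removed C++: strides come from the last view, the offset is summed over all views
  views = t.uop.st.views
  strides = list(views[-1].strides)
  storage_offset = sum(view.offset for view in views)  # the old TODO questioned whether summing is right
  return strides, storage_offset

v = Tensor.empty(10, 10)[2:5, 3:7]
strides, offset = strides_and_offset_from_views(v)  # the shrink arithmetic tested above gives ([10, 1], 23)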