inf, -inf support for pad (#1436)

This commit is contained in:
Umut Zengin
2023-08-04 22:05:25 +03:00
committed by GitHub
parent 7325bc914f
commit 52db7d7435
2 changed files with 4 additions and 1 deletions

View File

@@ -7,7 +7,7 @@ import operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from math import ceil, pi, prod, sqrt, log, cos, copysign
from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
from tinygrad.lazy import Device, LazyBuffer
from tinygrad.ops import LoadOps
@@ -246,6 +246,7 @@ class Tensor:
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
if isinf(value): return ret + copysign(1,value)/mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg)
return ret if 0 == value else ret + (value - mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg))
# ***** movement hlops *****