mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 02:21:40 -05:00
inf, -inf support for pad (#1436)
This commit is contained in:
@@ -7,7 +7,7 @@ import operator
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
|
||||
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
|
||||
from math import ceil, pi, prod, sqrt, log, cos, copysign
|
||||
from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
|
||||
from tinygrad.lazy import Device, LazyBuffer
|
||||
from tinygrad.ops import LoadOps
|
||||
|
||||
@@ -246,6 +246,7 @@ class Tensor:
|
||||
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
|
||||
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
|
||||
if isinf(value): return ret + copysign(1,value)/mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg)
|
||||
return ret if 0 == value else ret + (value - mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg))
|
||||
|
||||
# ***** movement hlops *****
|
||||
|
||||
Reference in New Issue
Block a user