remove match_type in ops_torch and ops_cpu (#2817)
* remove match_type in ops_torch and ops_cpu: input dtypes are aligned and cast in mlops
* dict union only after python3.9
* fix that
* fix Sigmoid forward cast
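The core idea: dtype promotion for binary ops now happens once at the mlops level, where both operands are cast to a common dtype, so the per-backend match_types helpers are no longer needed. A minimal sketch of that contract in plain NumPy (not tinygrad code; np.result_type stands in for tinygrad's priority-based least_upper_dtype):

import numpy as np

# Promote both operands up front, the way mlops now does, so the backend op
# can be a plain function like np.add with no per-op type matching.
def promote(a, b):
  up = np.result_type(a.dtype, b.dtype)   # stand-in for least_upper_dtype
  return a.astype(up, copy=False), b.astype(up, copy=False)

x = np.arange(4, dtype=np.int32)
y = np.ones(4, dtype=np.float32)
xa, ya = promote(x, y)
assert xa.dtype == ya.dtype                # backend now always sees matching dtypes
print(np.add(xa, ya))                      # so BinaryOps.ADD can map to plain np.add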
tinygrad/mlops.py
@@ -78,8 +78,8 @@ class Sqrt(Function):
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(ftype:=least_upper_float(x.dtype)).const(1).e(
-      BinaryOps.DIV, x.cast(ftype).const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.cast(ftype).const(-1/math.log(2))).e(UnaryOps.EXP2)))
+    x = x.cast(least_upper_float(x.dtype))
+    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
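For context on the Sigmoid forward touched above: it evaluates 1/(1 + exp2(x * (-1/ln 2))), which equals the usual 1/(1 + e^-x) since 2^(x/ln 2) = e^x; the change just casts x once up front instead of sprinkling casts through the expression. A quick numerical check of that identity (illustrative, plain NumPy):

import math
import numpy as np

# Verify 1/(1 + e**-x) == 1/(1 + 2**(x * (-1/ln 2))) on a small grid.
x = np.linspace(-6, 6, 25, dtype=np.float32)
ref = 1 / (1 + np.exp(-x))
via_exp2 = 1 / (1 + np.exp2(x * np.float32(-1 / math.log(2))))
assert np.allclose(ref, via_exp2, atol=1e-6)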
@@ -89,11 +89,13 @@ class Sigmoid(Function):

 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.CMPLT, y)
+    output_dtype = least_upper_dtype(x.dtype, y.dtype)
+    return x.cast(output_dtype).e(BinaryOps.CMPLT, y.cast(output_dtype))

 class Xor(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.XOR, y)
+    output_dtype = least_upper_dtype(x.dtype, y.dtype)
+    return x.cast(output_dtype).e(BinaryOps.XOR, y.cast(output_dtype))

 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
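Less and Xor now promote both inputs to a shared dtype via least_upper_dtype before emitting CMPLT/XOR, so backends no longer see mixed-dtype operands. A rough NumPy analogue of what the operands look like by the time the backend runs (np.result_type again standing in for least_upper_dtype):

import numpy as np

# Cast both operands to a common dtype, then compare / xor on matching dtypes.
x = np.array([1, 2, 3], dtype=np.int8)
y = np.array([0.5, 2.0, 3.5], dtype=np.float32)
common = np.result_type(x.dtype, y.dtype)          # float32 here
print(x.astype(common) < y.astype(common))         # CMPLT analogue, bool result

a = np.array([0b1010, 0b0110], dtype=np.uint8)
b = np.array([0b0011, 0b0101], dtype=np.int32)
ci = np.result_type(a.dtype, b.dtype)              # int32 here
print(np.bitwise_xor(a.astype(ci), b.astype(ci)))  # XOR analogue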
tinygrad/runtime/ops_cpu.py
@@ -10,9 +10,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple

 # TODO: this should be global infrastructure
 def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
-def match_types(x, y):
-  up = output_type(x, y)
-  return x.astype(up, copy=False), y.astype(up, copy=False)

 def einsum_mulacc(einsum, get_strides, expand):
   def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
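Note that output_type stays even though match_types is gone: the CMPLT and DIV entries below still use it to pick the result dtype. For DIV the cast matters because np.divide is true division and returns float64 for integer inputs; casting back presumably keeps integer inputs producing integer (truncated) results. A small illustration with stand-in dtypes:

import numpy as np

# np.divide on integers yields float64; the astype back to the operands' dtype
# (what the DIV entry does via output_type) truncates to integers again.
x = np.array([7, 9], dtype=np.int32)
y = np.array([2, 3], dtype=np.int32)                  # already matched upstream by mlops
out = np.divide(x, y).astype(np.int32, copy=False)    # np.int32 stands in for output_type(x, y)
print(out.dtype, out)                                 # int32 [3 3]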
@@ -24,20 +21,23 @@ def einsum_mulacc(einsum, get_strides, expand):
     return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
   return mulacc

+def as_strided(x, arg):
+  return np.ndarray(shape=arg[0], dtype=x.dtype, buffer=np.require(x, requirements='C'),
+                    offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1]))
+
 numpy_fxn_for_op: Dict[Op, Callable] = {
   BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
   UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
   UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
-  BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
-  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False),
-  BinaryOps.XOR: lambda x, y: np.bitwise_xor(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)),
+  BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
+  BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
+  BinaryOps.XOR: np.bitwise_xor, UnaryOps.SQRT: np.sqrt,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
   ReduceOps.MAX: lambda x, new_shape: x.max(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
-  MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])), # noqa: E501
-  MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
-  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
+  MovementOps.AS_STRIDED: as_strided, MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
+  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
   TernaryOps.WHERE: np.where,
 }

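The new as_strided helper reinterprets a contiguous buffer through an explicit (shape, strides, offset) triple, with strides and offset given in elements and scaled to bytes. A standalone illustration of the same np.ndarray construction (toy values rather than tinygrad's real view arguments):

import numpy as np

# Reinterpret a contiguous 6-element buffer as a 2x3 strided view.
base = np.arange(6, dtype=np.float32)                 # [0. 1. 2. 3. 4. 5.]
shape, strides_in_elems, offset_in_elems = (2, 3), (1, 2), 0
view = np.ndarray(shape, dtype=base.dtype, buffer=np.require(base, requirements='C'),
                  offset=offset_in_elems * base.dtype.itemsize,
                  strides=tuple(s * base.dtype.itemsize for s in strides_in_elems))
print(view)   # [[0. 2. 4.]
              #  [1. 3. 5.]]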
tinygrad/runtime/ops_torch.py
@@ -7,17 +7,14 @@ from tinygrad.helpers import getenv, dtypes
 from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis

 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
-type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32,
-            torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64,
-            torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16, torch.bfloat16: dtypes.bfloat16}
-inverse_type_map = dict([(v, k) for k,v in type_map.items()] + [(dtypes.ushort, torch.int16), (dtypes.uint, torch.int32)])
+type_map = {torch.bool: dtypes.bool, torch.int8: dtypes.int8, torch.uint8: dtypes.uint8, torch.int16: dtypes.int16, torch.int32: dtypes.int32,
+            torch.int64: dtypes.int64, torch.float16: dtypes.float16, torch.bfloat16: dtypes.bfloat16, torch.float32: dtypes.float32,
+            torch.float64: dtypes.float64}
+inverse_type_map = {v: k for k,v in type_map.items()}
+inverse_type_map.update({dtypes.uint16: torch.int16, dtypes.uint32: torch.int32, dtypes.uint64: torch.int64})
 def np_type_cvt(t): return {np.uint32: np.int32}.get(t, t)

 def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype
-def match_types(x, y, disallow_bool=False):
-  up = output_type(x, y)
-  if disallow_bool and up == torch.bool: up = torch.float
-  return x.type(up), y.type(up)

 def as_strided(x, arg):
   if any(i < 0 for i in arg[1]):
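The "dict union only after python3.9" bullet in the commit message refers to the inverse_type_map construction above: merging dicts with the | operator (PEP 584) requires Python 3.9+, so the extra unsigned-int mappings are added with dict.update() instead. A toy comparison with made-up keys:

# Works on Python 3.8 and earlier:
base = {"a": 1}
merged = dict(base)
merged.update({"b": 2})
print(merged)            # {'a': 1, 'b': 2}

# Equivalent but Python 3.9+ only:
# merged = base | {"b": 2}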
@@ -33,11 +30,11 @@ torch_fxn_for_op: Dict[Op, Callable] = {
   UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
   UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
-  BinaryOps.ADD: lambda x,y: torch.add(*match_types(x, y)).type(output_type(x,y)),
-  BinaryOps.SUB: lambda x,y: torch.sub(*match_types(x, y, disallow_bool=True)).type(output_type(x,y)),
-  BinaryOps.MUL: lambda x,y: torch.mul(*match_types(x, y)).type(output_type(x,y)),
-  BinaryOps.DIV: lambda x,y: torch.div(*match_types(x, y)).type(torch.promote_types(x.dtype, y.dtype)),
-  BinaryOps.XOR: lambda x,y: torch.bitwise_xor(*match_types(x, y)),
+  BinaryOps.ADD: lambda x,y: torch.add(x, y).type(output_type(x,y)),
+  BinaryOps.SUB: lambda x,y: torch.sub(x, y).type(output_type(x,y)),
+  BinaryOps.MUL: lambda x,y: torch.mul(x, y).type(output_type(x,y)),
+  BinaryOps.DIV: lambda x,y: torch.div(x, y).type(torch.promote_types(x.dtype, y.dtype)),
+  BinaryOps.XOR: torch.bitwise_xor,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
   ReduceOps.MAX: lambda x, new_shape: x.amax(shape_to_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
   MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
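On the torch side, CMPLT and DIV keep using torch.promote_types for the result dtype, while ADD/SUB/MUL cast back to output_type (the higher-priority input dtype). A small check of the promote_types behavior being relied on (illustrative values):

import torch

x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([0.5, 2.0, 3.5], dtype=torch.float32)
out_dtype = torch.promote_types(x.dtype, y.dtype)    # torch.float32
print((x < y).type(out_dtype))                       # comparison cast the way CMPLT does
print(torch.div(x, y).type(out_dtype))               # true-division result cast the way DIV does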