mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
Test min (#932)
* fix __neg__ defaulting to float32 due to 0.0 * fixed __neg__ always defaulting to float32 * fixed openpilot (OpenCL) Test
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import math, functools, itertools, operator
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
|
||||
from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
|
||||
from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, ImageDType
|
||||
from tinygrad.lazy import Device, LazyBuffer
|
||||
from tinygrad.ops import LoadOps
|
||||
|
||||
@@ -503,9 +503,9 @@ class Tensor:
|
||||
def softsign(self): return self / (1 + self.abs())
|
||||
|
||||
# ***** broadcasted binary mlops *****
|
||||
|
||||
def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
|
||||
x,y = [Tensor(t, device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
|
||||
dtype = self.dtype if self.dtype != dtypes.bool and not isinstance(self.dtype,ImageDType) else dtypes.float32
|
||||
x,y = [Tensor(t, device=self.device, requires_grad=False, dtype=dtype) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
|
||||
x,y = [t.reshape([1]*(max(len(x.shape), len(y.shape))-len(t.shape)) + list(t.shape)) for t in [x,y]]
|
||||
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
|
||||
return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
|
||||
|
||||
Reference in New Issue
Block a user