mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
least controversial (#7863)
This commit is contained in:
@@ -107,8 +107,12 @@ class MultiLazyBuffer(MathTrait):
|
||||
assert any(new_real), "output contains no real lb"
|
||||
for mlb in msrcs:
|
||||
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
|
||||
elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
|
||||
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
|
||||
elif mlb.axis is None and axis is not None:
|
||||
assert bounds is not None
|
||||
srcs.append(to_sharded(mlb.lbs, axis, bounds))
|
||||
else:
|
||||
assert axis is not None and bounds is not None
|
||||
srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
|
||||
new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
|
||||
# NOTE: const dtype should match real
|
||||
new_dtype = next(iter(new_real_lbs.values())).dtype
|
||||
|
||||
@@ -259,7 +259,8 @@ class LayerNorm:
|
||||
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
|
||||
self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
|
||||
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
|
||||
self.weight: Optional[Tensor] = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
|
||||
self.bias: Optional[Tensor] = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
|
||||
@@ -341,7 +342,8 @@ class LSTMCell:
|
||||
stdv = 1.0 / math.sqrt(hidden_size)
|
||||
self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
|
||||
self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
|
||||
self.bias_ih, self.bias_hh = (Tensor.zeros(hidden_size*4), Tensor.zeros(hidden_size*4)) if bias else (None, None)
|
||||
self.bias_ih: Optional[Tensor] = Tensor.zeros(hidden_size*4) if bias else None
|
||||
self.bias_hh: Optional[Tensor] = Tensor.zeros(hidden_size*4) if bias else None
|
||||
|
||||
def __call__(self, x:Tensor, hc:Optional[Tuple[Tensor, Tensor]]=None) -> Tuple[Tensor, Tensor]:
|
||||
if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
|
||||
|
||||
@@ -16,6 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
||||
"""
|
||||
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
json_len = t[0:8].bitcast(dtypes.int64).item()
|
||||
assert isinstance(json_len, int)
|
||||
return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
|
||||
|
||||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user