least controversial (#7863)

This commit is contained in:
JaSpa99
2024-11-23 14:23:30 +01:00
committed by GitHub
parent 8c3d3181dd
commit 28e83e662e
3 changed files with 11 additions and 4 deletions

View File

@@ -107,8 +107,12 @@ class MultiLazyBuffer(MathTrait):
assert any(new_real), "output contains no real lb"
for mlb in msrcs:
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
elif mlb.axis is None and axis is not None:
assert bounds is not None
srcs.append(to_sharded(mlb.lbs, axis, bounds))
else:
assert axis is not None and bounds is not None
srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
# NOTE: const dtype should match real
new_dtype = next(iter(new_real_lbs.values())).dtype

View File

@@ -259,7 +259,8 @@ class LayerNorm:
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
self.weight: Optional[Tensor] = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
self.bias: Optional[Tensor] = Tensor.zeros(*self.normalized_shape) if elementwise_affine else None
def __call__(self, x:Tensor) -> Tensor:
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
@@ -341,7 +342,8 @@ class LSTMCell:
stdv = 1.0 / math.sqrt(hidden_size)
self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
self.bias_ih, self.bias_hh = (Tensor.zeros(hidden_size*4), Tensor.zeros(hidden_size*4)) if bias else (None, None)
self.bias_ih: Optional[Tensor] = Tensor.zeros(hidden_size*4) if bias else None
self.bias_hh: Optional[Tensor] = Tensor.zeros(hidden_size*4) if bias else None
def __call__(self, x:Tensor, hc:Optional[Tuple[Tensor, Tensor]]=None) -> Tuple[Tensor, Tensor]:
if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2

View File

@@ -16,6 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
"""
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
json_len = t[0:8].bitcast(dtypes.int64).item()
assert isinstance(json_len, int)
return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]: