style: make __init__ first in Tensor class

commit 3954f102aa
parent 273945df67
Author: George Hotz
Date:   2024-06-05 12:51:41 +02:00


@@ -83,21 +83,14 @@ class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
-  class train(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
   no_grad: ClassVar[bool] = False
-  class inference_mode(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
   def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
-    # tensors have gradients, buffers do not
+    # tensors can have gradients if you have called .backward
     self.grad: Optional[Tensor] = None
     # NOTE: this can be in three states. False and None: no gradient, True: gradient
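
For context, a minimal sketch of how the constructor shown above is called; this assumes the top-level tinygrad package exports Tensor and dtypes, and is not part of this diff:

from tinygrad import Tensor, dtypes

# data may be a scalar, list, bytes, or numpy array; dtype, if given, must be a
# DType, and the device argument is run through Device.canonicalize
t = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float32, requires_grad=True)
print(t.device)  # canonicalized device string, e.g. "CPU"
print(t.grad)    # None until backward() has run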
@@ -138,6 +131,16 @@ class Tensor:
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)
 
+  class train(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
+
+  class inference_mode(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
   def __repr__(self):
     return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"