mirror of https://github.com/tinygrad/tinygrad.git
style: make __init__ first in Tensor class

@@ -83,21 +83,14 @@ class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
-  class train(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
-
   no_grad: ClassVar[bool] = False
-  class inference_mode(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
-
   def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
+    # tensors have gradients, buffers do not
+
+    # tensors can have gradients if you have called .backward
     self.grad: Optional[Tensor] = None

     # NOTE: this can be in three states. False and None: no gradient, True: gradient
@@ -138,6 +131,16 @@ class Tensor:
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)

+  class train(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
+
+  class inference_mode(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
   def __repr__(self):
     return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
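
Both helpers moved below __init__ subclass ContextDecorator, so each one works as a with block or as a decorator, saving and restoring the class-level flag it toggles. A minimal usage sketch, not part of the commit, assuming a working tinygrad install; the evaluate helper is made up for illustration:

# usage sketch for the moved context managers; evaluate() is a hypothetical helper
from tinygrad import Tensor

with Tensor.train():                  # __enter__ sets Tensor.training = True
  assert Tensor.training is True      # ... training steps would run here ...

@Tensor.inference_mode()              # ContextDecorator also allows decorator form
def evaluate(x: Tensor) -> Tensor:
  assert Tensor.no_grad is True       # gradient tracking is off inside the call
  return x * 2

evaluate(Tensor([1.0, 2.0]))
assert Tensor.training is False and Tensor.no_grad is False  # __exit__ restored both flags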
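For the __init__ shown in the first hunk: data may be None, a Python scalar, a List or Tuple, an np.ndarray, bytes, or an existing (Multi)LazyBuffer, while device, dtype, and requires_grad stay optional. A small construction sketch under the same tinygrad-install assumption; the printed device is whatever Device.canonicalize picks as the default on your machine:

# construction sketch based on the signature above; exact default device varies
import numpy as np
from tinygrad import Tensor, dtypes

a = Tensor([1.0, 2.0, 3.0], requires_grad=True)  # List data, opt-in gradient tracking
b = Tensor(np.arange(4, dtype=np.float32))       # np.ndarray data
c = Tensor(2.0, dtype=dtypes.float32)            # scalar (ConstType) with an explicit DType
print(a.requires_grad, b.grad, c.device)         # -> True None <default device>

Per the device = tuple(...) line in the diff, device may also be a tuple or list of device strings, each entry canonicalized individually.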