diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 32486368c4..7d16f26a75 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -83,21 +83,14 @@ class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
-  class train(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
-
   no_grad: ClassVar[bool] = False
-  class inference_mode(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
   def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable],
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
-    # tensors have gradients, buffers do not
+
+    # tensors can have gradients if you have called .backward
     self.grad: Optional[Tensor] = None
 
     # NOTE: this can be in three states. False and None: no gradient, True: gradient
@@ -138,6 +131,16 @@ class Tensor:
     else:
       self.lazydata = data if data.device == device else data.copy_to_device(device)
 
+  class train(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
+
+  class inference_mode(ContextDecorator):
+    def __init__(self, mode:bool = True): self.mode = mode
+    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
+    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
+
   def __repr__(self):
     return f""
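
Note (outside the diff): a minimal usage sketch of the two ContextDecorator helpers this change relocates, assuming the class imports as tinygrad.tensor.Tensor. Each one flips a class-level flag on entry and restores the previous value on exit, so it works both as a with-block and as a decorator:

    from tinygrad.tensor import Tensor

    # as a context manager: Tensor.training is True inside the block, restored after
    with Tensor.train():
      assert Tensor.training

    # as a decorator: Tensor.no_grad is True for the duration of each call
    @Tensor.inference_mode()
    def evaluate(x: Tensor) -> Tensor:
      return (x * 2).sum()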