mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
num_batches_tracked does not need is_dtype_supported (#15018)
This commit is contained in:
committed by
GitHub
parent
3244131f59
commit
e5c0db66d1
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import math
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.device import is_dtype_supported as is_dtype_supported
|
||||
from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
|
||||
from tinygrad.nn import optim, state, datasets # noqa: F401
|
||||
|
||||
@@ -36,7 +36,7 @@ class BatchNorm:
|
||||
self.weight: Tensor|None = Tensor.ones(sz) if affine else None
|
||||
self.bias: Tensor|None = Tensor.zeros(sz) if affine else None
|
||||
|
||||
self.num_batches_tracked = Tensor.zeros(dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
|
||||
self.num_batches_tracked = Tensor.zeros(dtype='long', requires_grad=False)
|
||||
if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
|
||||
|
||||
def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user