mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
env var to change default float (#3902)
* env var to change default float to fp16 or bf16
looking for standard names for these. we have FLOAT16 that does something to IMAGE and HALF to convert weights.
working on default bf16 too.
```
RuntimeError: compile failed: <null>(6): error: identifier "__bf16" is undefined
__bf16 cast0 = (nv_bfloat16)(val0);
```
remove that in cifar
* DEFAULT_FLOAT
* default of default
* unit test
* don't check default
* tests work on linux
This commit is contained in:
@@ -20,13 +20,6 @@ GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
|
||||
assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
|
||||
assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
|
||||
|
||||
if getenv("HALF"):
|
||||
dtypes.default_float = dtypes.float16
|
||||
elif getenv("BF16"):
|
||||
dtypes.default_float = dtypes.bfloat16
|
||||
else:
|
||||
dtypes.default_float = dtypes.float32
|
||||
|
||||
class UnsyncedBatchNorm:
|
||||
def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
|
||||
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
|
||||
|
||||
Reference in New Issue
Block a user