env var to change default float (#3902)

* env var to change default float to fp16 or bf16

looking for standard names for these. we already have FLOAT16, which does something to IMAGE, and HALF, which converts weights.

still working on default bf16 too; compiling currently fails with:
```
RuntimeError: compile failed: <null>(6): error: identifier "__bf16" is undefined
    __bf16 cast0 = (nv_bfloat16)(val0);
```

remove the HALF/BF16 default_float override in cifar

* DEFAULT_FLOAT

* default of default

* unit test

* don't check default

* tests work on linux
chenyu
2024-03-24 20:33:57 -04:00
committed by GitHub
parent 03899a74bb
commit 83f39a8ceb
5 changed files with 30 additions and 15 deletions


@@ -20,13 +20,6 @@ GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
 assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
 assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
 
-if getenv("HALF"):
-  dtypes.default_float = dtypes.float16
-elif getenv("BF16"):
-  dtypes.default_float = dtypes.bfloat16
-else:
-  dtypes.default_float = dtypes.float32
-
 class UnsyncedBatchNorm:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
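
The hunk above drops the per-script HALF/BF16 switch in favor of a single DEFAULT_FLOAT environment variable. As a rough, standalone sketch of that idea (not the code from this commit): the helper `resolve_default_float`, the `FLOAT_DTYPES` tuple, and the assumption that the variable holds a case-insensitive dtype name are all hypothetical and chosen for illustration only.

```python
import os

# Hypothetical list of accepted float dtype names; a real library would
# consult its own dtype registry instead of a hard-coded tuple.
FLOAT_DTYPES = ("float16", "bfloat16", "float32", "float64")

def resolve_default_float(environ=None, fallback="float32"):
    """Pick the default float dtype name from DEFAULT_FLOAT, else use the fallback."""
    env = os.environ if environ is None else environ
    # Assumption: the value is a dtype name, matched case-insensitively.
    requested = env.get("DEFAULT_FLOAT", "").strip().lower()
    if not requested:
        return fallback  # variable unset or empty: keep the built-in default
    if requested not in FLOAT_DTYPES:
        raise ValueError(f"DEFAULT_FLOAT={requested!r} is not a known float dtype")
    return requested

if __name__ == "__main__":
    # Reads the real process environment when run as a script.
    print(resolve_default_float())
```

Reading the variable in one place means individual example scripts no longer need their own `if getenv(...)` ladder, and an unset variable leaves the existing default untouched.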