Summing a bool tensor uses default_float as the accumulator dtype by default. Without casting to float first, the sum can overflow with a large BS when default_float=HALF, because half cannot represent values above 65504. This fixes clsf_accuracy being inf in mi300x bert.
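A minimal pure-Python sketch of the failure mode follows. It assumes HALF is IEEE float16. The helpers `round_to` and `bool_sum` are hypothetical. They model only the storage of the final count in the accumulator format, using the `struct` module's `e` (float16) and `f` (float32) formats, and do not model the order in which the device actually reduces.

```python
import math
import struct

# Hypothetical helpers: emulate storing a bool count in a given IEEE format.
# "<e" = float16 (half), "<f" = float32. A value too large for the format
# becomes inf, matching IEEE overflow behaviour.

def round_to(fmt: str, x: float) -> float:
    """Round x to the struct float format `fmt`; return +/-inf on overflow."""
    try:
        return struct.unpack(fmt, struct.pack(fmt, x))[0]
    except OverflowError:  # struct raises instead of producing inf
        return math.copysign(math.inf, x)

def bool_sum(mask: list[bool], fmt: str) -> float:
    """Count the True entries, then store the count in the accumulator format."""
    return round_to(fmt, float(sum(mask)))

if __name__ == "__main__":
    mask = [True] * 100_000                          # e.g. a large batch, all correct
    print(bool_sum(mask, "<e"))                      # half accumulator overflows to inf
    print(bool_sum(mask, "<f") / len(mask))          # float32 accumulator: accuracy 1.0
```

With 100,000 True entries the half accumulator becomes inf, so any accuracy computed from it is also inf. A float32 accumulator holds the exact count. Half also has only 11 significant bits, so counts above 2048 are not all exactly representable, which is another reason to cast to float before the sum.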