fix wino conv output dtype for half inputs (#7829)

This commit is contained in:
chenyu
2024-11-21 12:13:54 -05:00
committed by GitHub
parent cf1ec90ad4
commit 69e382216d
4 changed files with 13 additions and 104 deletions

View File

@@ -290,5 +290,5 @@ if __name__ == "__main__":
if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
assert distance < 70e-5, colored(f"validation failed with {distance=}", "red")
assert distance < 3e-3, colored(f"validation failed with {distance=}", "red") # higher distance with WINO
print(colored(f"output validated with {distance=}", "green"))