better alphas

This commit is contained in:
George Hotz
2022-09-05 16:48:26 -07:00
parent 0fda854b3e
commit 3728ef6d02

View File

@@ -515,6 +515,7 @@ class CLIPTextTransformer:
class StableDiffusion:
def __init__(self):
self.alphas_cumprod = Tensor.empty(1000)
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
self.first_stage_model = AutoencoderKL()
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
@@ -591,14 +592,20 @@ if __name__ == "__main__":
#alphas = [0.9983, 0.6722, 0.2750, 0.0557]
#alphas_prev = [0.9991499781608582, 0.9982960224151611, 0.6721514463424683, 0.27499905228614807]
alphas = [0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365, 0.0140]
alphas_prev = [1.0, 0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365]
#alphas = [0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365, 0.0140]
#alphas_prev = [1.0, 0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365]
#timesteps = [1, 101, 201, 301, 401, 501, 601, 701, 801, 901]
timesteps = list(np.arange(1, 1000, 1000//20))
print(timesteps)
alphas = [model.alphas_cumprod.numpy()[t] for t in timesteps]
alphas_prev = [1.0] + alphas[:-1]
def get_x_prev_and_pred_x0(x, e_t, index):
temperature = 1
a_t, a_prev = alphas[index], alphas_prev[index]
sigma_t = 0
sqrt_one_minus_at = math.sqrt(1-a_t)
print(a_t, a_prev, sigma_t, sqrt_one_minus_at)
pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
@@ -614,7 +621,7 @@ if __name__ == "__main__":
# is this the diffusion?
#for index, timestep in tqdm(list(enumerate([1, 251, 501, 751]))[::-1]):
for index, timestep in tqdm(list(enumerate([1, 101, 201, 301, 401, 501, 601, 701, 801, 901]))[::-1]):
for index, timestep in tqdm(list(enumerate(timesteps))[::-1]):
print(index, timestep)
e_t = get_model_output(latent, timestep)
#print(e_t.numpy())