mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
better alphas
This commit is contained in:
@@ -515,6 +515,7 @@ class CLIPTextTransformer:
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self):
|
||||
self.alphas_cumprod = Tensor.empty(1000)
|
||||
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
|
||||
self.first_stage_model = AutoencoderKL()
|
||||
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
|
||||
@@ -591,14 +592,20 @@ if __name__ == "__main__":
|
||||
#alphas = [0.9983, 0.6722, 0.2750, 0.0557]
|
||||
#alphas_prev = [0.9991499781608582, 0.9982960224151611, 0.6721514463424683, 0.27499905228614807]
|
||||
|
||||
alphas = [0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365, 0.0140]
|
||||
alphas_prev = [1.0, 0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365]
|
||||
#alphas = [0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365, 0.0140]
|
||||
#alphas_prev = [1.0, 0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365]
|
||||
#timesteps = [1, 101, 201, 301, 401, 501, 601, 701, 801, 901]
|
||||
timesteps = list(np.arange(1, 1000, 1000//20))
|
||||
print(timesteps)
|
||||
alphas = [model.alphas_cumprod.numpy()[t] for t in timesteps]
|
||||
alphas_prev = [1.0] + alphas[:-1]
|
||||
|
||||
def get_x_prev_and_pred_x0(x, e_t, index):
|
||||
temperature = 1
|
||||
a_t, a_prev = alphas[index], alphas_prev[index]
|
||||
sigma_t = 0
|
||||
sqrt_one_minus_at = math.sqrt(1-a_t)
|
||||
print(a_t, a_prev, sigma_t, sqrt_one_minus_at)
|
||||
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
|
||||
|
||||
@@ -614,7 +621,7 @@ if __name__ == "__main__":
|
||||
|
||||
# is this the diffusion?
|
||||
#for index, timestep in tqdm(list(enumerate([1, 251, 501, 751]))[::-1]):
|
||||
for index, timestep in tqdm(list(enumerate([1, 101, 201, 301, 401, 501, 601, 701, 801, 901]))[::-1]):
|
||||
for index, timestep in tqdm(list(enumerate(timesteps))[::-1]):
|
||||
print(index, timestep)
|
||||
e_t = get_model_output(latent, timestep)
|
||||
#print(e_t.numpy())
|
||||
|
||||
Reference in New Issue
Block a user