mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
@@ -1,5 +1,4 @@
|
||||
# much taken from https://github.com/cloneofsimo/minRF
|
||||
import math
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
|
||||
from tinygrad.helpers import getenv, trange
|
||||
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
|
||||
@@ -95,7 +94,7 @@ class DiT_Llama:
|
||||
cond = (Tensor.rand(cond.shape[0]) < dropout_prob).where(cond.full_like(self.num_classes), cond)
|
||||
|
||||
# this is rectified flow
|
||||
z1 = Tensor.randn(x.shape) # TODO: add Tensor.randn_like (and friends) to tinygrad
|
||||
z1 = x.randn_like()
|
||||
zt = (1 - texp) * x + texp * z1
|
||||
vtheta = self(zt, t, cond)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user