mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -1,5 +1,4 @@
|
|||||||
# much taken from https://github.com/cloneofsimo/minRF
|
# much taken from https://github.com/cloneofsimo/minRF
|
||||||
import math
|
|
||||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
|
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
|
||||||
from tinygrad.helpers import getenv, trange
|
from tinygrad.helpers import getenv, trange
|
||||||
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
|
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
|
||||||
@@ -95,7 +94,7 @@ class DiT_Llama:
|
|||||||
cond = (Tensor.rand(cond.shape[0]) < dropout_prob).where(cond.full_like(self.num_classes), cond)
|
cond = (Tensor.rand(cond.shape[0]) < dropout_prob).where(cond.full_like(self.num_classes), cond)
|
||||||
|
|
||||||
# this is rectified flow
|
# this is rectified flow
|
||||||
z1 = Tensor.randn(x.shape) # TODO: add Tensor.randn_like (and friends) to tinygrad
|
z1 = x.randn_like()
|
||||||
zt = (1 - texp) * x + texp * z1
|
zt = (1 - texp) * x + texp * z1
|
||||||
vtheta = self(zt, t, cond)
|
vtheta = self(zt, t, cond)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user