randn_like in minrf (#10298)

tested that it trains to a similar loss
Author: chenyu
Date: 2025-05-14 07:59:50 -04:00
Committed by: GitHub
Parent: 0788659d08
Commit: fbaa26247a


@@ -1,5 +1,4 @@
 # much taken from https://github.com/cloneofsimo/minRF
-import math
 from tinygrad import Tensor, nn, GlobalCounters, TinyJit
 from tinygrad.helpers import getenv, trange
 from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@@ -95,7 +94,7 @@ class DiT_Llama:
     cond = (Tensor.rand(cond.shape[0]) < dropout_prob).where(cond.full_like(self.num_classes), cond)
     # this is rectified flow
-    z1 = Tensor.randn(x.shape)  # TODO: add Tensor.randn_like (and friends) to tinygrad
+    z1 = x.randn_like()
     zt = (1 - texp) * x + texp * z1
     vtheta = self(zt, t, cond)
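For context, the change is behavior-preserving: randn_like samples standard-normal noise with the same shape (and dtype/device) as its receiver, which is exactly what the old Tensor.randn(x.shape) call spelled out by hand. A minimal sketch of the equivalence, assuming a tinygrad build that includes this commit; the batch shape below is a hypothetical stand-in for the MNIST-shaped inputs minrf trains on:

from tinygrad import Tensor

x = Tensor.randn(4, 1, 28, 28)  # hypothetical batch of MNIST-shaped images

z1_old = Tensor.randn(x.shape)  # before this commit: shape passed explicitly
z1_new = x.randn_like()         # after: shape, dtype, and device taken from x

assert z1_old.shape == z1_new.shape == x.shape

Either way, z1 only feeds the rectified-flow interpolation zt = (1 - texp) * x + texp * z1, so matching x's shape is all that matters for training to reach a similar loss.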