randn_like in minrf (#10298)

tested that it trains to similar loss
This commit is contained in:
chenyu
2025-05-14 07:59:50 -04:00
committed by GitHub
parent 0788659d08
commit fbaa26247a

View File

@@ -1,5 +1,4 @@
# much taken from https://github.com/cloneofsimo/minRF
import math
from tinygrad import Tensor, nn, GlobalCounters, TinyJit
from tinygrad.helpers import getenv, trange
from extra.models.llama import Attention, FeedForward, precompute_freqs_cis
@@ -95,7 +94,7 @@ class DiT_Llama:
cond = (Tensor.rand(cond.shape[0]) < dropout_prob).where(cond.full_like(self.num_classes), cond)
# this is rectified flow
z1 = Tensor.randn(x.shape) # TODO: add Tensor.randn_like (and friends) to tinygrad
z1 = x.randn_like()
zt = (1 - texp) * x + texp * z1
vtheta = self(zt, t, cond)