mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
hotfix: add BS to beautiful_mnist
This commit is contained in:
@@ -27,7 +27,7 @@ if __name__ == "__main__":
|
||||
def train_step() -> Tensor:
|
||||
with Tensor.train():
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(512, high=X_train.shape[0])
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
|
||||
@@ -23,7 +23,7 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
|
||||
if getenv("UOP_IS_SYMBOLIC"):
|
||||
# TODO: change this once UOps is ready to replace symbolic. note: this doesn't work for variable shapetrackers now
|
||||
# TODO: change this once UOps is ready to replace symbolic
|
||||
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
|
||||
# TODO: dtypes.realint
|
||||
iexpr = variable_to_uop(view.offset)
|
||||
|
||||
Reference in New Issue
Block a user