hotfix: add BS to beautiful_mnist

This commit is contained in:
George Hotz
2024-07-11 10:55:05 -07:00
parent 3e40211e45
commit 5232e405ce
2 changed files with 2 additions and 2 deletions

View File

@@ -27,7 +27,7 @@ if __name__ == "__main__":
def train_step() -> Tensor:
with Tensor.train():
opt.zero_grad()
samples = Tensor.randint(512, high=X_train.shape[0])
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()

View File

@@ -23,7 +23,7 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
if getenv("UOP_IS_SYMBOLIC"):
# TODO: change this once UOps is ready to replace symbolic. note: this doesn't work for variable shapetrackers now
# TODO: change this once UOps is ready to replace symbolic
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
# TODO: dtypes.realint
iexpr = variable_to_uop(view.offset)