diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
index 5d3492fb58..797e7b1a40 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
@@ -2,6 +2,7 @@

 export PYTHONPATH="."
 export DEV=${DEV:-AMD}
+export EMULATE="AMD_CDNA4"
 export CHECK_OOB=0
 export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000

@@ -10,7 +11,7 @@ export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
 export ALL2ALL=${ALL2ALL:-1}

 export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
-export DP=8 BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
+export DP=${DP:-8} BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
 export GBS=$((BS * GRADIENT_ACC_STEPS))

 export MODEL="llama3"
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 5d5ced5c32..52f59a2a2f 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -319,11 +319,9 @@ class Embedding:
     self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)

   def __call__(self, idx:Tensor) -> Tensor:
-    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
     if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
-    big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
-    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
-    return (arange == idx).where(vals, 0).sum(-2, dtype=vals.dtype)
+    arange = Tensor.arange(self.weight.shape[0], requires_grad=False, device=self.weight.device)
+    return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(self.weight, 0).sum(-2, dtype=self.weight.dtype)

 class LSTMCell:
   """
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index fca3ea2ad8..56d01b1ad6 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -821,7 +821,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     return self.src[0].after(self.store(val).end(*argfix(end)))

   def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
-    contig_srcs = tuple(x.contiguous() for x in srcs)
+    contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
     kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
     return [s.after(kernel) for s in contig_srcs]