simplify nn.Embedding, support AFTER in CUSTOM_KERNEL (#14419)

This commit is contained in:
George Hotz
2026-01-29 17:22:13 +08:00
committed by GitHub
parent 0c855d6149
commit 793afbd473
3 changed files with 5 additions and 6 deletions

View File

@@ -2,6 +2,7 @@
export PYTHONPATH="."
export DEV=${DEV:-AMD}
export EMULATE="AMD_CDNA4"
export CHECK_OOB=0
export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
@@ -10,7 +11,7 @@ export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=8 BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
export DP=${DP:-8} BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
export GBS=$((BS * GRADIENT_ACC_STEPS))
export MODEL="llama3"

View File

@@ -319,11 +319,9 @@ class Embedding:
self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
def __call__(self, idx:Tensor) -> Tensor:
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
return (arange == idx).where(vals, 0).sum(-2, dtype=vals.dtype)
arange = Tensor.arange(self.weight.shape[0], requires_grad=False, device=self.weight.device)
return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(self.weight, 0).sum(-2, dtype=self.weight.dtype)
class LSTMCell:
"""

View File

@@ -821,7 +821,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0].after(self.store(val).end(*argfix(end)))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() for x in srcs)
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
return [s.after(kernel) for s in contig_srcs]