Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00.
Commit: simplify nn.Embedding, support AFTER in CUSTOM_KERNEL (#14419).
@@ -2,6 +2,7 @@
|
||||
|
||||
# Environment configuration for llama3 training on AMD (emulating CDNA4).
export PYTHONPATH="."
export DEV=${DEV:-AMD}
export EMULATE="AMD_CDNA4"
export CHECK_OOB=0
export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000

export ALL2ALL=${ALL2ALL:-1}

export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
# DP is overridable from the caller's environment (${DP:-8}); the other batch
# settings are fixed. A previous hard-coded `export DP=8` would have clobbered
# a user-provided DP, so only the overridable form is kept.
export DP=${DP:-8} BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
# Global batch size = per-step batch size * gradient accumulation steps.
export GBS=$((BS * GRADIENT_ACC_STEPS))

export MODEL="llama3"
|
||||
|
||||
@@ -319,11 +319,9 @@ class Embedding:
|
||||
self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
  """Look up embedding rows for integer indices.

  idx: integer Tensor of arbitrary shape. Returns a Tensor of shape
  idx.shape + (embed_size,) in self.weight's dtype.
  Raises TypeError if idx is not an integer dtype.

  Implemented as a one-hot select: compare idx against an arange over the
  vocabulary, then use the boolean mask to pick rows of self.weight via
  where/sum. (The old cached self.arange / big_shp expansion was removed;
  broadcasting achieves the same result without per-instance state.)
  """
  if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
  # (vocab,) iota on the weight's device; the index helper never needs grad.
  arange = Tensor.arange(self.weight.shape[0], requires_grad=False, device=self.weight.device)
  # mask broadcasts to idx.shape + (vocab, 1) against weight (vocab, embed);
  # summing over the vocab axis (-2) selects exactly the matching row.
  return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(self.weight, 0).sum(-2, dtype=self.weight.dtype)
|
||||
|
||||
class LSTMCell:
|
||||
"""
|
||||
|
||||
@@ -821,7 +821,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.src[0].after(self.store(val).end(*argfix(end)))
|
||||
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
  """Wrap srcs in a CUSTOM_KERNEL UOp whose execution is provided by fxn.

  fxn: callable implementing the kernel; grad_fxn: optional backward callable.
  Each source is made contiguous first, EXCEPT sources that are already
  Ops.AFTER — those carry an ordering dependency and are passed through
  unchanged. Returns one UOp per source, each gated on the kernel via
  .after(...) so downstream reads are ordered after the custom kernel runs.
  """
  contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
  kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
  return [s.after(kernel) for s in contig_srcs]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user