mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
embedding doesn't cast (#5952)
* embedding doesn't cast * test the right thing * too much annoying with that test
This commit is contained in:
@@ -77,5 +77,28 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
conditions = if_uops[0].src[0].sparents
|
||||
self.assertLessEqual(len(conditions), 8)
|
||||
|
||||
# this was a bug in embedding, someday we should fold this anyway
|
||||
def test_llama_embedding(self):
|
||||
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
|
||||
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))), src=(
|
||||
LazyOp(UnaryOps.CAST, arg=dtypes.half, src=(
|
||||
LazyOp(ReduceOps.SUM, arg=(1,), src=(
|
||||
LazyOp(UnaryOps.CAST, arg=dtypes.float, src=(
|
||||
LazyOp(BinaryOps.MUL, arg=None, src=(
|
||||
LazyOp(UnaryOps.CAST, arg=dtypes.half, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BinaryOps.ADD, arg=None, src=(
|
||||
LazyOp(ReduceOps.SUM, arg=(2,), src=(
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False)))), src=()),)),
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous= False),))), src=()),)),
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)
|
||||
)), src=()),)),)),)),)),)),))
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
prg = k.to_program()
|
||||
print(prg.src)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user