mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
clean up old tests (#12708)
@@ -131,8 +131,7 @@ class TestIndexing(unittest.TestCase):
     # llama3 is 128256
     vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
     emb = nn.Embedding(vocab_size, embed_size)
-    # TODO: why is a new realize needed here
-    emb_w = emb.weight.realize().numpy()
+    emb_w = emb.weight.numpy()
     x = Tensor([1,2,3,4])
     with Context(NOOPT=noopt):
       GlobalCounters.reset()
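Note on the hunk above: Tensor.numpy() realizes the tensor itself, so the explicit realize() the TODO was asking about is redundant and can be dropped. A minimal sketch of the equivalence, reusing the small CI sizes from the test:

  from tinygrad import Tensor, nn

  emb = nn.Embedding(10, 3)             # small vocab/embed, as in the CI branch
  w_old = emb.weight.realize().numpy()  # old style: explicit realize, then copy out
  w_new = emb.weight.numpy()            # new style: numpy() realizes implicitly
  assert (w_old == w_new).all()         # same realized buffer either way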
@@ -4,7 +4,6 @@

 import unittest
 from tinygrad import Device, dtypes
 from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import UOp, Ops, AxisType, KernelInfo
-from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.codegen.opt.search import Opt, OptOps
@@ -31,48 +30,6 @@ class TestLinearizerFailure(unittest.TestCase):
     _ = get_program(ast, Device["METAL"].renderer)

 class TestLinearizerDumb(unittest.TestCase):
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
-  @unittest.skip("Ops.VALID no longer exists")
-  def test_max_simplify_and_cancel(self):
-    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1000), arg=0, src=())
-    c1 = c0.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))
-    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1000), arg=1, src=())
-    c3 = c2.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))
-    c4 = c3.load()
-    c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=2, src=())
-    c6 = c5.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)))
-    c7 = c6.load()
-    c8 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=())
-    c9 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=())
-    c10 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1000), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=())
-    c11 = c1.store((c4.alu(Ops.CMPNE, c7).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c8)).cast(dtypes.int)*(c9.f(Ops.VALID, dtype=dtypes.bool).where(UOp.const(dtypes.int, -1, src=c10), UOp.const(dtypes.int, 0, src=c10)).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (1,)))+UOp.const(dtypes.int, 1000, src=c8))))
-    ast = c11.sink()
-    #opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
-    opts = [Opt(op=OptOps.LOCAL, axis=0, arg=8)]
-    prg = get_program(ast, Device[Device.DEFAULT].renderer, opts)
-    print(prg.src)
-    assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"
-
-  # this was a bug in embedding, someday we should fold this anyway
-  @unittest.skipUnless(is_dtype_supported(dtypes.half), f"half dtype not supported on {Device.DEFAULT}")
-  @unittest.skip("UOp.view is no longer supported")
-  def test_llama_embedding(self):
-    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(4096), arg=0, src=())
-    c1 = c0.view(ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)))
-    c2 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=())
-    c3 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 32000), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=())
-    c4 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=())
-    c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), arg=1, src=())
-    c6 = c5.view(ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)))
-    c7 = c6.load()
-    c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(131072000), arg=2, src=())
-    c9 = c8.view(ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)))
-    c10 = c9.load()
-    c11 = c1.store(((c2.f(Ops.VALID, dtype=dtypes.bool).where(UOp.const(dtypes.int, 1, src=c3), UOp.const(dtypes.int, 0, src=c3)).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (2,)))+UOp.const(dtypes.int, -1, src=c4)).alu(Ops.CMPNE, c7).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c4)).cast(dtypes.half)*c10).cast(dtypes.float).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (1,))).cast(dtypes.half))
-    ast = c11.sink()
-    prg = get_program(ast, Device[Device.DEFAULT].renderer)
-    print(prg.src)
-
   @unittest.expectedFailure
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
   def test_unrolled_float4_align(self):
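Both deleted tests were already permanently skipped: they depend on Ops.VALID and UOp.view, which no longer exist. Each one hand-built an AST out of DEFINE_GLOBAL/VIEW UOps, compiled it with get_program, and inspected prg.src or prg.uops. A rough public-API sketch of that inspection pattern (an illustration, not the deleted tests' exact mechanism; it assumes ScheduleItem exposes the kernel AST as si.ast):

  from tinygrad import Tensor

  out = (Tensor.arange(10) * 2).sum()
  for si in out.schedule():  # one ScheduleItem per kernel
    print(si.ast)            # the AST the deleted tests constructed by hand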
@@ -837,7 +837,6 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(len(si.metadata), 3)
     self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})

-  @unittest.skip("not accurate")
   def test_complex_backward(self):
     x = Tensor.rand(3, requires_grad=True).realize()
     y = Tensor.rand(3, requires_grad=True).realize()
@@ -849,11 +848,12 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
     self.assertTrue(y.grad.uop.metadata[0].backward)
     si = Tensor.schedule(out, x.grad, y.grad)[-1]
-    self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
-    self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
+    self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
+    self.assertSetEqual(set(m.name for m in si.metadata), {"__mul__", "sigmoid", "relu"})
     bw = [m for m in si.metadata if m.backward]
-    self.assertEqual(len(bw), 1)
-    self.assertEqual(bw[0].name, "sigmoid")
+    self.assertEqual(len(bw), 2)
+    self.assertEqual(bw[0].name, "__mul__")
+    self.assertEqual(bw[1].name, "sigmoid")

 class TestIdxUpcast(unittest.TestCase):
   def _find_op(self, ast: UOp, op: Ops):
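The updated assertions say the fused kernel's metadata now also records __mul__, both in the total count and among the backward entries. The hunk does not show how out is defined; a hypothetical graph consistent with the asserted names ({"__mul__", "sigmoid", "relu"}) is a relu/sigmoid product, sketched here purely for illustration:

  from tinygrad import Tensor

  x = Tensor.rand(3, requires_grad=True).realize()
  y = Tensor.rand(3, requires_grad=True).realize()
  out = x.relu() * y.sigmoid()  # hypothetical: the test's real out is defined above this hunk
  out.sum().backward()
  si = Tensor.schedule(out, x.grad, y.grad)[-1]
  print(sorted(m.name for m in si.metadata))  # expect __mul__/relu/sigmoid-style entries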
@@ -26,14 +26,6 @@ pm_quant = symbolic+PatternMatcher([
   # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
   (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
    lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
-  # mul 0 * c1 is 0
-  #(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
-  #  UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
-  # mul (with plus) 0 * c1 is 0
-  #(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
-  #  (UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
-  #   UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
-  #  lambda ld,v,c1: ld*c1),

   # const push through add
   ((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
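For reference, both surviving rewrites are value-preserving: the first folds x*c1 + y*c2 into (x+y)*c1 only when the two constants agree to within 1e-9, and the const-push rule is plain distributivity. A plain-Python spot check of the two identities (illustration only; the real rules match and rewrite UOp graphs):

  x, y, c1, c3 = 2.0, 3.0, 0.5, 4.0
  c2 = c1 + 1e-10  # within the rule's 1e-9 tolerance

  # x*c1 + y*c2 -> (x+y)*c1 when |c1 - c2| < 1e-9
  assert abs((x*c1 + y*c2) - (x + y)*c1) < 1e-8

  # const push through add: (x*c1 + y*c2)*c3 -> x*c1*c3 + y*c2*c3
  assert abs((x*c1 + y*c2)*c3 - (x*c1*c3 + y*c2*c3)) < 1e-9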