mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 13:45:50 -05:00
Remove Zeroview (#748)
* no zeroview start * closer * stride mask * st tests pass, delete ZeroView * byebye zv * close to working * not contiguous with mask * subtract, don't add * mask on view * ugh, that shouldn't have been in there * shape merge * bugfixes * fuzzer + 4 fuzzer failures * fuzzer for symbolic * more fuzzing and nothing * that fuzzer doesn't hit either * fixes padding...ugh * no more offsets * working * rewrite load and store * all checks * fix idxs * progress * bugfix * float4_axis * works * cleanups * complex valids_okay
This commit is contained in:
60
test/external/fuzz_shapetracker.py
vendored
Normal file
60
test/external/fuzz_shapetracker.py
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
import random
|
||||
from test.unit.test_shapetracker import CheckingShapeTracker
|
||||
|
||||
def do_permute(st):
|
||||
perm = list(range(0, len(st.shape)))
|
||||
random.shuffle(perm)
|
||||
perm = tuple(perm)
|
||||
print("st.permute(", perm, ")")
|
||||
st.permute(perm)
|
||||
|
||||
def do_pad(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
pad = tuple((random.randint(0,2), random.randint(0,2)) if i==c else (0,0) for i in range(len(st.shape)))
|
||||
print("st.pad(", pad, ")")
|
||||
st.pad(pad)
|
||||
|
||||
def do_reshape_split_one(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
poss = [n for n in [1,2,3,4,5] if st.shape[c]%n == 0]
|
||||
spl = random.choice(poss)
|
||||
shp = st.shape[0:c] + (st.shape[c]//spl, spl) + st.shape[c+1:]
|
||||
print("st.reshape(", shp, ")")
|
||||
st.reshape(shp)
|
||||
|
||||
def do_reshape_combine_two(st):
|
||||
if len(st.shape) < 2: return
|
||||
c = random.randint(0, len(st.shape)-2)
|
||||
shp = st.shape[:c] + (st.shape[c] * st.shape[c+1], ) + st.shape[c+2:]
|
||||
print("st.reshape(", shp, ")")
|
||||
st.reshape(shp)
|
||||
|
||||
def do_shrink(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
while 1:
|
||||
shrink = tuple((random.randint(0,s), random.randint(0,s)) if i == c else (0,s) for i,s in enumerate(st.shape))
|
||||
if all(x<y for (x,y) in shrink): break
|
||||
print("st.shrink(", shrink, ")")
|
||||
st.shrink(shrink)
|
||||
|
||||
def do_stride(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape)))
|
||||
print("st.stride(", stride, ")")
|
||||
st.stride(stride)
|
||||
|
||||
def do_expand(st):
|
||||
c = [i for i,s in enumerate(st.shape) if s==1]
|
||||
if len(c) == 0: return
|
||||
c = random.choice(c)
|
||||
expand = tuple(random.choice([2,3,4]) if i==c else s for i,s in enumerate(st.shape))
|
||||
print("st.expand(", expand, ")")
|
||||
st.expand(expand)
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand]
|
||||
while 1:
|
||||
st = CheckingShapeTracker((3, 3, 3))
|
||||
for i in range(8): random.choice(ops)(st)
|
||||
#st.simplify()
|
||||
st.assert_same()
|
||||
50
test/external/fuzz_symbolic.py
vendored
Normal file
50
test/external/fuzz_symbolic.py
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
import random
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
def add_v(expr, rng=None):
|
||||
if rng is None: rng = random.randint(0,2)
|
||||
return expr + v[rng], rng
|
||||
|
||||
def div(expr, rng=None):
|
||||
if rng is None: rng = random.randint(1,9)
|
||||
return expr // rng, rng
|
||||
|
||||
def mul(expr, rng=None):
|
||||
if rng is None: rng = random.randint(-4,4)
|
||||
return expr * rng, rng
|
||||
|
||||
def mod(expr, rng=None):
|
||||
if rng is None: rng = random.randint(1,9)
|
||||
return expr % rng, rng
|
||||
|
||||
def add_num(expr, rng=None):
|
||||
if rng is None: rng = random.randint(-4,4)
|
||||
return expr + rng, rng
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops = [add_v, div, mul, add_num]
|
||||
while 1:
|
||||
u1 = Variable("v1", 0, 2)
|
||||
u2 = Variable("v2", 0, 3)
|
||||
u3 = Variable("v3", 0, 4)
|
||||
v = [u1,u2,u3]
|
||||
tape = [random.choice(ops) for _ in range(20)]
|
||||
expr = Variable.num(0)
|
||||
rngs = []
|
||||
for t in tape:
|
||||
expr, rng = t(expr)
|
||||
print(t.__name__, rng)
|
||||
rngs.append(rng)
|
||||
print(expr)
|
||||
for v1 in range(u1.min, u1.max+1):
|
||||
for v2 in range(u2.min, u2.max+1):
|
||||
for v3 in range(u3.min, u3.max+1):
|
||||
v = [v1,v2,v3]
|
||||
rn = 0
|
||||
for t,r in zip(tape, rngs):
|
||||
rn, _ = t(rn, r)
|
||||
num = eval(expr.render())
|
||||
assert num == rn, f"mismatch at {v1} {v2} {v3}, {num} != {rn}"
|
||||
#print(v1, v2, v3, num, rn)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user