Remove Zeroview (#748)

* no zeroview start

* closer

* stride mask

* st tests pass, delete ZeroView

* byebye zv

* close to working

* not contiguous with mask

* subtract, don't add

* mask on view

* ugh, that shouldn't have been in there

* shape merge

* bugfixes

* fuzzer + 4 fuzzer failures

* fuzzer for symbolic

* more fuzzing and nothing

* that fuzzer doesn't hit either

* fixes padding...ugh

* no more offsets

* working

* rewrite load and store

* all checks

* fix idxs

* progress

* bugfix

* float4_axis

* works

* cleanups

* complex valids_okay
This commit is contained in:
George Hotz
2023-04-17 08:21:46 -07:00
committed by GitHub
parent 4e17d27d09
commit 8b7ecd63bb
9 changed files with 391 additions and 181 deletions

60
test/external/fuzz_shapetracker.py vendored Normal file
View File

@@ -0,0 +1,60 @@
import random
from test.unit.test_shapetracker import CheckingShapeTracker
def do_permute(st):
perm = list(range(0, len(st.shape)))
random.shuffle(perm)
perm = tuple(perm)
print("st.permute(", perm, ")")
st.permute(perm)
def do_pad(st):
c = random.randint(0, len(st.shape)-1)
pad = tuple((random.randint(0,2), random.randint(0,2)) if i==c else (0,0) for i in range(len(st.shape)))
print("st.pad(", pad, ")")
st.pad(pad)
def do_reshape_split_one(st):
c = random.randint(0, len(st.shape)-1)
poss = [n for n in [1,2,3,4,5] if st.shape[c]%n == 0]
spl = random.choice(poss)
shp = st.shape[0:c] + (st.shape[c]//spl, spl) + st.shape[c+1:]
print("st.reshape(", shp, ")")
st.reshape(shp)
def do_reshape_combine_two(st):
if len(st.shape) < 2: return
c = random.randint(0, len(st.shape)-2)
shp = st.shape[:c] + (st.shape[c] * st.shape[c+1], ) + st.shape[c+2:]
print("st.reshape(", shp, ")")
st.reshape(shp)
def do_shrink(st):
c = random.randint(0, len(st.shape)-1)
while 1:
shrink = tuple((random.randint(0,s), random.randint(0,s)) if i == c else (0,s) for i,s in enumerate(st.shape))
if all(x<y for (x,y) in shrink): break
print("st.shrink(", shrink, ")")
st.shrink(shrink)
def do_stride(st):
c = random.randint(0, len(st.shape)-1)
stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape)))
print("st.stride(", stride, ")")
st.stride(stride)
def do_expand(st):
c = [i for i,s in enumerate(st.shape) if s==1]
if len(c) == 0: return
c = random.choice(c)
expand = tuple(random.choice([2,3,4]) if i==c else s for i,s in enumerate(st.shape))
print("st.expand(", expand, ")")
st.expand(expand)
if __name__ == "__main__":
ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand]
while 1:
st = CheckingShapeTracker((3, 3, 3))
for i in range(8): random.choice(ops)(st)
#st.simplify()
st.assert_same()

50
test/external/fuzz_symbolic.py vendored Normal file
View File

@@ -0,0 +1,50 @@
import random
from tinygrad.shape.symbolic import Variable
def add_v(expr, rng=None):
if rng is None: rng = random.randint(0,2)
return expr + v[rng], rng
def div(expr, rng=None):
if rng is None: rng = random.randint(1,9)
return expr // rng, rng
def mul(expr, rng=None):
if rng is None: rng = random.randint(-4,4)
return expr * rng, rng
def mod(expr, rng=None):
if rng is None: rng = random.randint(1,9)
return expr % rng, rng
def add_num(expr, rng=None):
if rng is None: rng = random.randint(-4,4)
return expr + rng, rng
if __name__ == "__main__":
ops = [add_v, div, mul, add_num]
while 1:
u1 = Variable("v1", 0, 2)
u2 = Variable("v2", 0, 3)
u3 = Variable("v3", 0, 4)
v = [u1,u2,u3]
tape = [random.choice(ops) for _ in range(20)]
expr = Variable.num(0)
rngs = []
for t in tape:
expr, rng = t(expr)
print(t.__name__, rng)
rngs.append(rng)
print(expr)
for v1 in range(u1.min, u1.max+1):
for v2 in range(u2.min, u2.max+1):
for v3 in range(u3.min, u3.max+1):
v = [v1,v2,v3]
rn = 0
for t,r in zip(tape, rngs):
rn, _ = t(rn, r)
num = eval(expr.render())
assert num == rn, f"mismatch at {v1} {v2} {v3}, {num} != {rn}"
#print(v1, v2, v3, num, rn)