Make ShapeTracker Immutable (#1909)

* ugh

* ops test pass

* fix shapetracker tests

* sym shapetracker

* shapetracker is a tuple of views now

* from_shape

* fix has variable shape

* key isn't needed

* post init assert
This commit is contained in:
George Hotz
2023-09-24 21:09:03 +08:00
committed by GitHub
parent 45f02393f0
commit 20059dc55b
11 changed files with 143 additions and 128 deletions

View File

@@ -326,22 +326,22 @@ void E_1(float* data0) {
from tinygrad.shape.shapetracker import ShapeTracker
# create a virtual (10, 10) Tensor. this is just a shape, there's no actual tensor
a = ShapeTracker((10, 10))
a = ShapeTracker.from_shape((10, 10))
# you'll see it has one view. the (10, 1 are the strides)
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
# we can permute it, and the strides change
a.permute((1,0))
a = a.permute((1,0))
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
# we can then reshape it, and the strides change again
# note how the permute stays applied
a.reshape((5,2,5,2))
a = a.reshape((5,2,5,2))
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
a.reshape((100,))
a = a.reshape((100,))
print(a) # ShapeTracker(shape=(100,), views=[
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
# View((100,), (1,), 0)])
@@ -352,7 +352,7 @@ idx, _ = a.expr_idxs()
print(idx.render()) # (((idx0%10)*10)+(idx0//10))
# of course, if we reshape it back, the indexes get simple again
a.reshape((10,10))
a = a.reshape((10,10))
idx, _ = a.expr_idxs()
print(idx.render()) # ((idx1*10)+idx0)
@@ -362,11 +362,11 @@ print(a) # ShapeTracker(shape=(10, 10), views=[
# View((10, 10), (10, 1), 0)])
# ...until we simplify it!
a.simplify()
a = a.simplify()
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
# and now we permute it back
a.permute((1,0))
a = a.permute((1,0))
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
# and it's even contiguous