mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Make ShapeTracker Immutable (#1909)
* ugh * ops test pass * fix shapetracker tests * sym shapetracker * shapetracker is a tuple of views now * from_shape * fix has variable shape * key isn't needed * post init assert
This commit is contained in:
@@ -326,22 +326,22 @@ void E_1(float* data0) {
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
# create a virtual (10, 10) Tensor. this is just a shape, there's no actual tensor
|
||||
a = ShapeTracker((10, 10))
|
||||
a = ShapeTracker.from_shape((10, 10))
|
||||
|
||||
# you'll see it has one view. the (10, 1 are the strides)
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
|
||||
|
||||
# we can permute it, and the strides change
|
||||
a.permute((1,0))
|
||||
a = a.permute((1,0))
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
||||
|
||||
# we can then reshape it, and the strides change again
|
||||
# note how the permute stays applied
|
||||
a.reshape((5,2,5,2))
|
||||
a = a.reshape((5,2,5,2))
|
||||
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
|
||||
|
||||
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
|
||||
a.reshape((100,))
|
||||
a = a.reshape((100,))
|
||||
print(a) # ShapeTracker(shape=(100,), views=[
|
||||
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
|
||||
# View((100,), (1,), 0)])
|
||||
@@ -352,7 +352,7 @@ idx, _ = a.expr_idxs()
|
||||
print(idx.render()) # (((idx0%10)*10)+(idx0//10))
|
||||
|
||||
# of course, if we reshape it back, the indexes get simple again
|
||||
a.reshape((10,10))
|
||||
a = a.reshape((10,10))
|
||||
idx, _ = a.expr_idxs()
|
||||
print(idx.render()) # ((idx1*10)+idx0)
|
||||
|
||||
@@ -362,11 +362,11 @@ print(a) # ShapeTracker(shape=(10, 10), views=[
|
||||
# View((10, 10), (10, 1), 0)])
|
||||
|
||||
# ...until we simplify it!
|
||||
a.simplify()
|
||||
a = a.simplify()
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
||||
|
||||
# and now we permute it back
|
||||
a.permute((1,0))
|
||||
a = a.permute((1,0))
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
|
||||
|
||||
# and it's even contiguous
|
||||
|
||||
Reference in New Issue
Block a user