reflect changes to shapetracker in doc printouts (#3349)

This commit is contained in:
Mason Mahaffey
2024-02-08 10:20:30 -05:00
committed by GitHub
parent 2266152b28
commit 3ebf7a3e38

View File

@@ -306,23 +306,23 @@ from tinygrad.shape.shapetracker import ShapeTracker
# create a virtual (10, 10) Tensor. this is just a shape, there's no actual tensor
a = ShapeTracker.from_shape((10, 10))
# you'll see it has one view. the (10, 1 are the strides)
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
# you'll see it has one view
print(a) # ShapeTracker(views=(View(shape=(10, 10), strides=(10, 1))))
# we can permute it, and the strides change
a = a.permute((1,0))
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
print(a) # ShapeTracker(views=(View(shape=(10, 10), strides=(1, 10))))
# we can then reshape it, and the strides change again
# note how the permute stays applied
a = a.reshape((5,2,5,2))
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
print(a) # ShapeTracker(views=(View(shape=(5, 2, 5, 2), strides=(2, 1, 20, 10))))
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
a = a.reshape((100,))
print(a) # ShapeTracker(shape=(100,), views=[
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
# View((100,), (1,), 0)])
print(a) # ShapeTracker(views=(
# View(shape=(5, 2, 5, 2), strides=(2, 1, 20, 10)),
# View(shape=(100,), strides=(1,))))
# Views stack on top of each other, to allow zero copy for any number of MovementOps
# we can render a Python expression for the index at any time
@@ -335,17 +335,17 @@ idx, _ = a.expr_idxs()
print(idx.render()) # ((idx1*10)+idx0)
# the ShapeTracker still has two views though...
print(a) # ShapeTracker(shape=(10, 10), views=[
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
# View((10, 10), (10, 1), 0)])
print(a) # ShapeTracker(views=(
# View(shape=(5, 2, 5, 2), strides=(2, 1, 20, 10),
# View(shape=(10, 10), strides=(10, 1))))
# ...until we simplify it!
a = a.simplify()
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
print(a) # ShapeTracker(views=(View(shape=(10, 10), strides=(1, 10), offset=0)))
# and now we permute it back
a = a.permute((1,0))
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
print(a) # ShapeTracker(views=(View(shape=(10, 10), strides=(10, 1), offset=0)))
# and it's even contiguous
assert a.contiguous == True