Make ShapeTracker Immutable (#1909)
* ugh
* ops test pass
* fix shapetracker tests
* sym shapetracker
* shapetracker is a tuple of views now
* from_shape
* fix has variable shape
* key isn't needed
* post init assert
@@ -326,22 +326,22 @@ void E_1(float* data0) {
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # create a virtual (10, 10) Tensor. this is just a shape, there's no actual tensor
-a = ShapeTracker((10, 10))
+a = ShapeTracker.from_shape((10, 10))
 
 # you'll see it has one view. the (10, 1) are the strides
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
 
 # we can permute it, and the strides change
-a.permute((1,0))
+a = a.permute((1,0))
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
 
 # we can then reshape it, and the strides change again
 # note how the permute stays applied
-a.reshape((5,2,5,2))
+a = a.reshape((5,2,5,2))
 print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
 
 # now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
-a.reshape((100,))
+a = a.reshape((100,))
 print(a) # ShapeTracker(shape=(100,), views=[
 #          View((5, 2, 5, 2), (2, 1, 20, 10), 0),
 #          View((100,), (1,), 0)])
@@ -352,7 +352,7 @@ idx, _ = a.expr_idxs()
 print(idx.render()) # (((idx0%10)*10)+(idx0//10))
 
 # of course, if we reshape it back, the indexes get simple again
-a.reshape((10,10))
+a = a.reshape((10,10))
 idx, _ = a.expr_idxs()
 print(idx.render()) # ((idx1*10)+idx0)
 
@@ -362,11 +362,11 @@ print(a) # ShapeTracker(shape=(10, 10), views=[
 #          View((10, 10), (10, 1), 0)])
 
 # ...until we simplify it!
-a.simplify()
+a = a.simplify()
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
 
 # and now we permute it back
-a.permute((1,0))
+a = a.permute((1,0))
 print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
 
 # and it's even contiguous
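The practical consequence of the new API is visible in this walkthrough: every movement op now returns a fresh tracker instead of mutating in place, so the result must be rebound. A minimal illustrative sketch (not part of the diff) of the calling convention:

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((10, 10))
st2 = st.permute((1, 0))   # a brand-new ShapeTracker comes back
assert st is not st2       # the original object is never touched
st = st2.reshape((100,))   # callers rebind, as every hunk below does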
@@ -43,7 +43,7 @@ class ATan2(Function):
     assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
     self.a, self.b = a, b
     ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
-    return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
+    return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
     return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \
@@ -6,18 +6,18 @@ from tinygrad.tensor import Tensor
 class TestSymbolic(unittest.TestCase):
   def test_symbolic_st(self):
     x = Variable("x", 1, 100)
-    st = ShapeTracker((x, 3))
+    st = ShapeTracker.from_shape((x, 3))
     assert st.shape == (x, 3)
     assert st.real_strides() == (3, 1)
 
   def test_expr_idxs(self):
     x = Variable("x", 1, 100)
-    st = ShapeTracker((x, 3))
+    st = ShapeTracker.from_shape((x, 3))
     idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
     e1, e2 = st.expr_idxs(idxs)
     assert e1.render() == "((x*3)+y)"
     assert e2.render() == "1"
-    st.permute((1, 0))
+    st = st.permute((1, 0))
     e1, e2 = st.expr_idxs(idxs)
     assert e1.render() == "((y*3)+x)"
     assert e2.render() == "1"
@@ -142,8 +142,7 @@ class TestSymbolicShapeExpr(unittest.TestCase):
     idx = (gidx0, lidx1, Variable.num(1))
     shape = (i+1, 8, 4)
     strides = (1, (i*4)+4, i+1)
-    view = View.create(shape, strides)
-    st = ShapeTracker(shape, [view])
+    st = ShapeTracker((View.create(shape, strides), ))
     idx, valid = st.expr_idxs(idx)
     assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
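The constructor change is visible here: a ShapeTracker is now built from a tuple of Views directly, and the old shape argument is gone because shape is always derived from the last view. A small sketch of the new construction path (the shape and stride values are illustrative):

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

st = ShapeTracker((View.create((8, 4), (1, 8)),))
assert st.shape == (8, 4)   # shape comes from views[-1], not a separate argument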
@@ -1,13 +1,13 @@
 #!/usr/bin/env python
 import unittest
 from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
-from tinygrad.shape.view import View
+from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.helpers import dtypes
 
 class TestFlopCounter(unittest.TestCase):
   def setUp(self):
-    self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, (View.create((4,)),)))
-    self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, (View.create((4,)),)))
+    self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
+    self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
 
   def test_flops_add(self):
     op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
@@ -14,42 +14,51 @@ def shapetracker_getitem(st, val):
 
 class CheckingShapeTracker:
   def __init__(self, shape):
-    self.st = ShapeTracker(shape)
+    self.st = ShapeTracker.from_shape(shape)
     self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
 
   @property
   def shape(self):
     return self.t.shape
 
-  def simplify(self): self.st.simplify()
+  def simplify(self):
+    self.st = self.st.simplify()
+    return self
 
   def reshape(self, new_shape):
-    self.st.reshape(new_shape)
+    self.st = self.st.reshape(new_shape)
     self.t = self.t.reshape(new_shape)
+    return self
 
   def permute(self, axis):
-    self.st.permute(axis)
+    self.st = self.st.permute(axis)
     self.t = np.transpose(self.t, axis)
+    return self
 
   def expand(self, new_shape):
-    self.st.expand(new_shape)
+    self.st = self.st.expand(new_shape)
     self.t = np.broadcast_to(self.t, new_shape)
+    return self
 
   def flip(self, axis):
-    self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
+    self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
    self.t = np.flip(self.t, axis)
+    return self
 
   def shrink(self, arg):
-    self.st.shrink(arg)
+    self.st = self.st.shrink(arg)
     self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
+    return self
 
   def pad(self, arg):
-    self.st.pad(arg)
+    self.st = self.st.pad(arg)
     self.t = np.pad(self.t, arg, constant_values=-1)
+    return self
 
   def stride(self, arg):
-    self.st.stride(arg)
+    self.st = self.st.stride(arg)
     self.t = self.t[tuple([slice(None, None, x) for x in arg])]
+    return self
 
   def __getitem__(self, val):
     return self.t.flatten()[val]
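Each wrapper method above now rebinds self.st and returns self, so the numpy mirror and the tracker stay in sync and calls can be chained. A hypothetical usage sketch of the class above (the concrete shapes are illustrative):

t = CheckingShapeTracker((10, 10))
t.permute((1, 0)).reshape((5, 20)).pad(((0, 0), (2, 2)))
assert t[0] == -1   # the first flattened element lands in the -1 padding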
@@ -70,7 +79,7 @@ class CheckingShapeTracker:
 
 class TestRealIssues(unittest.TestCase):
   def test_reshape_doesnt_multiview(self):
-    self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
+    self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
     self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
     assert len(self.st.views) == 1
@@ -78,27 +87,27 @@ class TestRealDoesntSimplify(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
     print(st)
-    self.st.simplify()
+    self.st = self.st.simplify()
     assert len(self.st.views) != 1
     assert None in st
 
   def test_1(self):
-    self.st = ShapeTracker((8, 6, 11), views=[
+    self.st = ShapeTracker((
       View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
-      View.create((8, 6, 11), (66, 11, 1), 0, None)])
+      View.create((8, 6, 11), (66, 11, 1), 0, None)))
     assert self.st.real_strides() == (33, None, 1)
 
   def test_2(self):
-    self.st = ShapeTracker((4, 4, 3, 3), views=[
+    self.st = ShapeTracker((
       View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
-      View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
+      View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
     assert self.st.real_strides() == (None, 18, -3, -1)
 
 class TestRealStrides(unittest.TestCase):
   def test_1(self):
-    self.st = ShapeTracker((16, 32, 4), views=[
+    self.st = ShapeTracker((
       View.create((2048,), (1,), 0, ((0, 512),)),
-      View.create((16, 32, 4), (128, 4, 1), 0, None)])
+      View.create((16, 32, 4), (128, 4, 1), 0, None)))
     st = self.st.real_strides()
     print(self.st, st)
     assert st == (None, 4, 1)
@@ -106,27 +115,27 @@ class TestRealStrides(unittest.TestCase):
 class TestRealSimplifies(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
-    self.st.simplify()
+    self.st = self.st.simplify()
     assert len(self.st.views) == 1
     print(self.st.views[-1].strides, st)
     assert self.st.views[-1].strides == st
 
   def test_1(self):
-    self.st = ShapeTracker((1, 3, 2, 11, 26, 1, 1, 3), views=[
+    self.st = ShapeTracker((
       View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
-      View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)])
+      View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)))
 
   def test_2(self):
-    self.st = ShapeTracker((8, 1, 6, 10, 28, 3, 2, 1), views=[
+    self.st = ShapeTracker((
       View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
-      View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
+      View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)))
 
 class TestIndexExpressions2d(unittest.TestCase):
 
   def setUp(self):
     shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
     offsets = [0, 1, 15, 28, 10000]
-    self.sts = [ShapeTracker(base_shape, [View.create(base_shape, offset=offset)]) for base_shape in shapes for offset in offsets]
+    self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets]
     self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
     self.shapes = [shape for shape in shapes for offset in offsets]
     self.node_exprs = []
@@ -171,36 +180,52 @@ class TestIndexExpressions2d(unittest.TestCase):
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
 
   def test_permute(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.permute((1, 0))
+      st = st.permute((1, 0))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_reshape(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.reshape((base_shape[0], 1, base_shape[1]))
+      st = st.reshape((base_shape[0], 1, base_shape[1]))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_reshape_expand(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.reshape((base_shape[0], 1, base_shape[1]))
-      st.expand((base_shape[0], base_shape[1], base_shape[1]))
+      st = st.reshape((base_shape[0], 1, base_shape[1]))
+      st = st.expand((base_shape[0], base_shape[1], base_shape[1]))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_permute_reshape_1(self): # This tests multiple views
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.permute((1, 0))
-      st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
+      st = st.permute((1, 0))
+      st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
      self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
   def test_permute_reshape_2(self):
+    new_st = []
     for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
-      st.permute((1, 0))
-      st.reshape((1, base_shape[0]//5, base_shape[1]*5))
+      st = st.permute((1, 0))
+      st = st.reshape((1, base_shape[0]//5, base_shape[1]*5))
       self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
       self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
+      new_st.append(st)
+    self.sts = new_st
 
 class TestSimplifyingShapeTracker(unittest.TestCase):
   def setUp(self):
@@ -211,14 +236,14 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
 
   # multiview simplify
   def test_expand_contract_simple(self):
-    self.st.expand((10, 10))
-    self.st.reshape((100,))
+    self.st = self.st.expand((10, 10))
+    self.st = self.st.reshape((100,))
     print(self.st.views)
     assert(len(self.st.views) == 2)
-    self.st.reshape((10, 10))
+    self.st = self.st.reshape((10, 10))
     print(self.st.views)
 
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.views)
     assert(len(self.st.views) == 1)
@@ -231,7 +256,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
     self.st.reshape((2, 5, 2, 5))
     print(self.st.views)
 
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.views)
     assert(len(self.st.views) == 1)
@@ -243,7 +268,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
     assert(len(self.st.views) == 2)
     self.st.reshape((5, 20))
 
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.views)
     assert(len(self.st.views) == 2)
@@ -387,7 +412,7 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
     self.st.reshape((1, 4))
     self.st.shrink(((0, 1), (1, 3)))
     print(self.st.st)
-    self.st.simplify()
+    self.st = self.st.simplify()
     print(self.st.st)
   def test_case_2(self):
     self.st.stride( (1, 1, -2) )
@@ -4,7 +4,7 @@ from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOp
 from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
-from tinygrad.shape.view import strides_for_shape, View
+from tinygrad.shape.view import strides_for_shape
 
 class LocalBuffer(NamedTuple):
   name: str
@@ -43,11 +43,10 @@ class Kernel:
     self.reduceop = reduceops[0] if reduceops else None
 
     # create new shapetrackers inside this kernel, we will permute them
-    self.bufs = [MemBuffer(0, self.info.dtype, (View.create(self.info.shape),))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
-    self.sts: List[ShapeTracker] = [ShapeTracker(x.views[-1].shape, views=list(x.views)) for x in self.bufs]
-    for st in self.sts: st.simplify()
+    self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
+    self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
 
-    self.mem_estimate: int = sum(x.dtype.itemsize*x.views[-1].size() for x in self.bufs)
+    self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs)
 
     # get earlybufs, before the one reduce op
     self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
@@ -67,7 +66,7 @@ class Kernel:
 
   def has_variable_shape(self) -> bool:
     for b in self.bufs:
-      if not all_int(b.views[-1].shape): return True
+      if not all_int(b.st.views[-1].shape): return True
     return False
 
   def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
@@ -260,7 +260,7 @@ class Linearizer(OptimizedKernel):
     # add a local buffer for multistage reduce. # TODO: use local alias
     if self.group_for_reduce:
       # TODO: the strides of this can be controlled
-      self.sts.append(ShapeTracker(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
+      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
      self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
       self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
@@ -24,9 +24,12 @@ class OptimizedKernel(Kernel):
 
   # apply reshape and permute to all shapetrackers
   def reshape_and_permute(self, new_shape_fxn, axis):
+    new_sts = []
     for st in self.sts:
-      if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
-      if axis is not None: st.permute(tuple(axis))
+      if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
+      if axis is not None: st = st.permute(tuple(axis))
+      new_sts.append(st)
+    self.sts = new_sts
 
   # drops the final dimension
   def upcast(self):
@@ -76,7 +79,7 @@ class OptimizedKernel(Kernel):
       else: rets[j].append((shapes[j][i], strides[j][i]))
 
     # do the reshapes
-    for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
+    for i,x in enumerate(rets): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
 
   # ******************** GPU simplifiers ********************
   def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
@@ -120,7 +123,7 @@ class OptimizedKernel(Kernel):
         stride[j] = bst
         bst *= shp[j]
 
-      self.sts.append(ShapeTracker(tuple(shp), [View.create(tuple(shp), tuple(stride))]))
+      self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
      self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
       if DEBUG >= 4: print("aliasing buffer", self.sts[i])
       self.local_alias[i] = self.bufs[-1]
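Because trackers can no longer be mutated out from under their owners, the kernel drops its defensive copies: self.sts simply aliases each buffer's tracker, and reshape_and_permute rebinds list slots instead of editing shared state. A self-contained sketch of why aliasing is now safe (the shapes are illustrative):

from tinygrad.shape.shapetracker import ShapeTracker

originals = [ShapeTracker.from_shape((4, 4)), ShapeTracker.from_shape((8, 2))]
sts = list(originals)                      # aliases, not copies
sts = [st.permute((1, 0)) for st in sts]   # rebinds the list slots only
assert originals[0].shape == (4, 4)        # the aliased originals are untouched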
@@ -7,7 +7,7 @@ import numpy as np
 from tinygrad.graph import log_op
 from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
 from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer
-from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
+from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
 
 from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer, buf_is_kernel_arg
@@ -95,11 +95,10 @@ def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
   realized_bufs = dedup([x.realized for x in op.buffers if buf_is_kernel_arg(x)])
   for x in op.buffers:
     assert x.realized, "buffer isn't realized"
-    x.st.simplify()
     if isinstance(x.realized, RawConst):
-      replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, tuple(x.st.views)))
+      replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify()))
     elif x.realized in realized_bufs:
-      replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, tuple(x.st.views)))
+      replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify()))
     else:
       raise NotImplementedError(f"not handled {x}")
   return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), realized_bufs
@@ -149,8 +148,8 @@ class LazyBuffer:
   def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self.realized else self.realized} st={self.st}>"
   @property
   def key(self):
-    if self.realized: return (self.dtype, self.realized.key, self.st.key, self.var_vals_key)
-    return (self.dtype, self.op.op, self.st.key, self.var_vals_key)
+    if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
+    return (self.dtype, self.op.op, self.st, self.var_vals_key)
 
   def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
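Dropping the separate key property (the commit message's "key isn't needed") works because a frozen dataclass gets field-based __eq__ and __hash__ for free, so the tracker itself can sit inside a cache key. A quick sketch of the property being relied on (illustrative, assuming View is also hashable):

from tinygrad.shape.shapetracker import ShapeTracker

a = ShapeTracker.from_shape((10, 10))
b = ShapeTracker.from_shape((10, 10))
assert a == b and hash(a) == hash(b)   # equal views imply equal, hashable trackers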
@@ -195,7 +194,7 @@ class LazyBuffer:
 
   @staticmethod
   def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
-    return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
+    return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
 
   # create a constant with the shape and dtype of self
   def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -204,11 +203,11 @@ class LazyBuffer:
 
   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
-    return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
 
   @staticmethod
   def fromCPU(x: np.ndarray) -> LazyBuffer:
-    return LazyBuffer("CPU", ShapeTracker(x.shape, [View.create(x.shape)]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
+    return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
 
   def toCPU(self) -> np.ndarray:
     assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
@@ -242,7 +241,7 @@ class LazyBuffer:
     # remove the buffers from any (childless) BinaryOps that feed into this
     srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
 
-    return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
+    return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
 
   def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
@@ -257,7 +256,7 @@ class LazyBuffer:
   def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
-    return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
 
   def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
@@ -282,18 +281,18 @@ class LazyBuffer:
       self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
       self.op.src[0].var_vals = self.var_vals
       return self.op.src[0].reshape(arg)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.reshape(arg), MovementOps.RESHAPE, arg)
 
   def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all(b == 0 and e == 0 for b,e in arg): return self
     if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.pad(arg), MovementOps.PAD, arg)
 
   def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.EXPAND:
       return self.op.src[0].expand(arg)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.expand(arg), MovementOps.EXPAND, arg)
 
   def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if arg == tuple(range(len(self.shape))): return self
@@ -315,19 +314,19 @@ class LazyBuffer:
     if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
      if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
         self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
-        return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
+        return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
+    return self.shuffle_and_prune_movement_ops(self.st.permute(arg), MovementOps.PERMUTE, arg)
 
   def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
     if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.shrink(arg), MovementOps.SHRINK, arg)
 
   def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
-    local_st = ShapeTracker(self.shape).stride(arg)
+    local_st = ShapeTracker.from_shape(self.shape).stride(arg)
     if self.shape == local_st.shape and local_st.contiguous: return self
     if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
-    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)
+    return self.shuffle_and_prune_movement_ops(self.st.stride(arg), MovementOps.STRIDE, arg)
 
   @property
   def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
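Note what disappears in every movement op above: the defensive copy ShapeTracker(self.st) before each call. Immutability makes sharing safe, so the copy is pure overhead. A short sketch of why (illustrative):

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4))
shared = st                      # many LazyBuffers may hold this same tracker
moved = shared.permute((1, 0))   # produces a new tracker
assert st.views == shared.views  # nothing mutated the shared one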
@@ -3,7 +3,7 @@ import time, importlib, inspect, functools, pathlib
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
-from tinygrad.shape.view import View
+from tinygrad.shape.shapetracker import ShapeTracker
 from dataclasses import dataclass
 if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
@@ -25,13 +25,13 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
 class MemBuffer:
   idx: int
   dtype: DType
-  views: Tuple[View, ...]
+  st: ShapeTracker
 
 @dataclass(frozen=True)
 class ConstBuffer:
   val: Any
   dtype: DType
-  views: Tuple[View, ...]
+  st: ShapeTracker
 
 class LazyOp:
   __slots__ = "op", "src", "arg", "buffers", "__weakref__"
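MemBuffer and ConstBuffer now carry a whole ShapeTracker instead of a bare views tuple, which is what lets downstream code ask for st.shape, st.size(), and st.contiguous directly. Constructing one as the updated flop-counter tests do:

from tinygrad.helpers import dtypes
from tinygrad.ops import MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker

buf = MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,)))
assert buf.st.shape == (4,)   # shape, size(), contiguous all come via the tracker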
@@ -101,8 +101,8 @@ Device = _Device()
 
 # **************** for Interpreted Buffers ****************
 
-def apply_shapetracker(fxn_for_op, ret, views):
-  for v in views:
+def apply_shapetracker(fxn_for_op, ret, st):
+  for v in st.views:
     real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
     real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
     # first, we apply the offset
@@ -133,7 +133,7 @@ class Interpreted:
   def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
     if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op:
       assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
-      return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.views))
+      return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.st))
     if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
       ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
     created_context = context is None
@@ -162,7 +162,7 @@ class FlopCounter:
     self.flops, ret = 0, self.flops
     return ret
 shape_fxn_for_op: Dict[Op, Callable] = {
-  LoadOps.BUFFER: lambda arg: (arg.views[-1].shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.views[-1].shape, arg.dtype, 0),
+  LoadOps.BUFFER: lambda arg: (arg.st.shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0),
   UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
   **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
   **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
@@ -241,8 +241,7 @@ class Compiled:
      for i,a in enumerate(inputs):
         # TODO: if this is contiguous it's fine
         if a == output.realized:
-          views = [x.arg.views for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1]
-          if any(len(v) > 1 or not v[0].contiguous for v in views):
+          if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1):
             output.realized = None
             break
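The contiguity check in Compiled collapses to one property call because ShapeTracker.contiguous (defined in the shapetracker.py hunk further down) already means "exactly one view, and that view has canonical strides". Illustrative:

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((16, 32, 4))
assert st.contiguous                          # one view with canonical strides
assert not st.permute((1, 0, 2)).contiguous   # same buffer, reordered strides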
@@ -1,7 +1,8 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
 import functools
-from typing import Tuple, Union, List, Optional, cast
+from dataclasses import dataclass
+from typing import Tuple, List, Optional, cast
 from tinygrad.helpers import prod, DEBUG
 from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
 from tinygrad.shape.view import View
@@ -48,7 +49,7 @@ def expr_idxs(view:View, idxs) -> Node:
 @functools.lru_cache(maxsize=None)
 def merge_views(vm2:View, vm1:View) -> Optional[View]:
   if vm2.mask: return None # this isn't supported yet
-  mst = ShapeTracker(vm1.shape, [vm2, vm1])
+  mst = ShapeTracker((vm2, vm1))
   strides = mst.real_strides()
   if None in strides: return None
   return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
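merge_views now builds its probe tracker straight from the two views; if every output stride is recoverable (no None), the pair collapses into a single View. A hedged sketch of a mergeable pair (the concrete views are illustrative, assuming a flat reinterpretation of a contiguous base merges cleanly):

from tinygrad.shape.shapetracker import ShapeTracker, merge_views
from tinygrad.shape.view import View

v2 = View.create((4, 4))        # contiguous base view
v1 = View.create((16,), (1,))   # flat reinterpretation stacked on top
assert merge_views(v2, v1) is not None   # all strides recoverable: one View survives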
@@ -63,12 +64,13 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
     acc *= d
   return Variable.sum(ret)
 
+@dataclass(frozen=True)
 class ShapeTracker:
-  __slots__ = "views"
-  def __init__(self, shape:Union[ShapeTracker, Tuple[sint, ...]], views:Optional[List[View]]=None):
-    self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View.create(shape)]
-  def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
-  def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
+  views: Tuple[View, ...]
+  def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
+
+  @staticmethod
+  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
 
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
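This is the heart of the change: the frozen dataclass plus the __post_init__ assert (the commit message's "post init assert") give two guarantees. Trackers reject non-tuple views at construction time, and any later assignment to views raises. A small demonstration sketch (not in the diff):

import dataclasses
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

st = ShapeTracker.from_shape((10, 10))
try: st.views = ()                          # frozen dataclass: assignment raises
except dataclasses.FrozenInstanceError: pass
try: ShapeTracker([View.create((10, 10))])  # a list trips the __post_init__ assert
except AssertionError: pass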
@@ -76,9 +78,6 @@ class ShapeTracker:
   @property
   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
 
-  @property
-  def key(self) -> Tuple[View, ...]: return tuple(self.views)
-
   # this is the real size (ish)
   def size(self): return self.views[-1].size()
@@ -113,13 +112,13 @@ class ShapeTracker:
       idx = expr_node(v, idx)
     return idx, valid
 
-  def simplify(self):
+  def simplify(self) -> ShapeTracker:
     if len(self.views) >= 2:
       new_view = merge_views(self.views[-2], self.views[-1])
       if new_view:
         if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
-        self.views = self.views[:-2] + [new_view]
-        self.simplify()
+        return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
     return self
 
   def expr_idxs(self, idxs=None):
     if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
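simplify is now pure as well: it returns either self (nothing to merge) or a freshly built tracker, recursing until the top two views stop merging. Replaying the earlier walkthrough with the new calling convention (illustrative):

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((10, 10)).permute((1, 0)).reshape((100,)).reshape((10, 10))
assert len(st.views) == 2   # the permute forced a second view, as shown above
st = st.simplify()          # a new, merged tracker comes back
assert len(st.views) == 1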
@@ -137,38 +136,30 @@ class ShapeTracker:
 
   # *** under this line are the movement ops ***
 
-  def pad(self, arg: Tuple[Tuple[int, int], ...]):
-    self.views[-1] = self.views[-1].pad(arg)
-    return self
+  def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
 
-  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]):
-    self.views[-1] = self.views[-1].shrink(arg)
-    return self
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
 
-  def expand(self, new_shape: Tuple[sint, ...]):
-    self.views[-1] = self.views[-1].expand(new_shape)
-    return self
+  def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
 
-  def permute(self, axis: Tuple[int, ...]):
-    self.views[-1] = self.views[-1].permute(axis)
-    return self
+  def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
 
-  def stride(self, mul: Tuple[int, ...]):
-    self.views[-1] = self.views[-1].stride(mul)
-    return self
+  def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
+    return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
 
-  def reshape(self, new_shape: Tuple[sint, ...]):
+  def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
     new_view = self.views[-1].reshape(new_shape)
     if new_view is None:
       extra_view = View.create(new_shape)
       # last chance to merge. TODO: move into View
       if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
-        self.views[-1] = merged_view
-      else:
-        self.views.append(extra_view)
-    else:
-      self.views[-1] = new_view
-    return self
+        return ShapeTracker(self.views[0:-1] + (merged_view,))
+      return ShapeTracker(self.views + (extra_view, ))
+    return ShapeTracker(self.views[0:-1] + (new_view,))
 
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 # TODO: if we remove movementops from lazy.py we can delete this
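All five single-view movement ops above share one shape: rebuild the views tuple with only the last view transformed. A hypothetical helper, not in the commit, that names the pattern:

from typing import Callable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

def apply_to_last(st: ShapeTracker, fxn: Callable[[View], View]) -> ShapeTracker:
  # hypothetical refactor: same view prefix, transformed last view
  return ShapeTracker(st.views[0:-1] + (fxn(st.views[-1]),))

permuted = apply_to_last(ShapeTracker.from_shape((4, 4)), lambda v: v.permute((1, 0)))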