Make ShapeTracker Immutable (#1909)

* ugh

* ops test pass

* fix shapetracker tests

* sym shapetracker

* shapetracker is a tuple of views now

* from_shape

* fix has variable shape

* key isn't needed

* post init assert
George Hotz authored 2023-09-24 21:09:03 +08:00, committed by GitHub
parent 45f02393f0
commit 20059dc55b
11 changed files with 143 additions and 128 deletions

View File

@@ -326,22 +326,22 @@ void E_1(float* data0) {
from tinygrad.shape.shapetracker import ShapeTracker
# create a virtual (10, 10) Tensor. this is just a shape, there's no actual tensor
a = ShapeTracker((10, 10))
a = ShapeTracker.from_shape((10, 10))
# you'll see it has one view. the (10, 1) are the strides
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
# we can permute it, and the strides change
a.permute((1,0))
a = a.permute((1,0))
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
# we can then reshape it, and the strides change again
# note how the permute stays applied
a.reshape((5,2,5,2))
a = a.reshape((5,2,5,2))
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
a.reshape((100,))
a = a.reshape((100,))
print(a) # ShapeTracker(shape=(100,), views=[
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
# View((100,), (1,), 0)])
@@ -352,7 +352,7 @@ idx, _ = a.expr_idxs()
print(idx.render()) # (((idx0%10)*10)+(idx0//10))
# of course, if we reshape it back, the indexes get simple again
a.reshape((10,10))
a = a.reshape((10,10))
idx, _ = a.expr_idxs()
print(idx.render()) # ((idx1*10)+idx0)
@@ -362,11 +362,11 @@ print(a) # ShapeTracker(shape=(10, 10), views=[
# View((10, 10), (10, 1), 0)])
# ...until we simplify it!
a.simplify()
a = a.simplify()
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
# and now we permute it back
a.permute((1,0))
a = a.permute((1,0))
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
# and it's even contiguous
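Since every movement op now returns a fresh ShapeTracker, the tutorial's reassignment style composes into chains. A minimal sketch of that chaining, using only the API this commit introduces (the shapes are illustrative):

from tinygrad.shape.shapetracker import ShapeTracker

# each movement op returns a new immutable ShapeTracker, so calls chain freely
b = ShapeTracker.from_shape((10, 10)).permute((1, 0)).reshape((5, 2, 5, 2))
print(b.shape)       # (5, 2, 5, 2)
print(len(b.views))  # 1: permute followed by this reshape still fits in one View
print(b.contiguous)  # False: the strides are no longer the default row-major ones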

View File

@@ -43,7 +43,7 @@ class ATan2(Function):
assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
self.a, self.b = a, b
ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \

View File

@@ -6,18 +6,18 @@ from tinygrad.tensor import Tensor
class TestSymbolic(unittest.TestCase):
def test_symbolic_st(self):
x = Variable("x", 1, 100)
st = ShapeTracker((x, 3))
st = ShapeTracker.from_shape((x, 3))
assert st.shape == (x, 3)
assert st.real_strides() == (3, 1)
def test_expr_idxs(self):
x = Variable("x", 1, 100)
st = ShapeTracker((x, 3))
st = ShapeTracker.from_shape((x, 3))
idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
e1, e2 = st.expr_idxs(idxs)
assert e1.render() == "((x*3)+y)"
assert e2.render() == "1"
st.permute((1, 0))
st = st.permute((1, 0))
e1, e2 = st.expr_idxs(idxs)
assert e1.render() == "((y*3)+x)"
assert e2.render() == "1"
@@ -142,8 +142,7 @@ class TestSymbolicShapeExpr(unittest.TestCase):
idx = (gidx0, lidx1, Variable.num(1))
shape = (i+1, 8, 4)
strides = (1, (i*4)+4, i+1)
view = View.create(shape, strides)
st = ShapeTracker(shape, [view])
st = ShapeTracker((View.create(shape, strides), ))
idx, valid = st.expr_idxs(idx)
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
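The same immutable ops accept symbolic dimensions, which is what these updated tests exercise. A small sketch along the same lines (the post-permute strides mirror the (3, 1) asserted above; treat the exact value as an assumption):

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable

x = Variable("x", 1, 100)             # a dimension known only to lie in [1, 100]
st = ShapeTracker.from_shape((x, 3))  # shapes may mix ints and Variables
st = st.permute((1, 0))               # returns a new tracker; st is rebound
assert st.shape == (3, x)
print(st.real_strides())              # expected (1, 3) after the permute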

View File

@@ -1,13 +1,13 @@
#!/usr/bin/env python
import unittest
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
from tinygrad.shape.view import View
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import dtypes
class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, (View.create((4,)),)))
self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, (View.create((4,)),)))
self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
def test_flops_add(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
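MemBuffer's third field is now a full ShapeTracker rather than a bare tuple of Views, so downstream code can read arg.st.shape directly (see the shape_fxn_for_op change further down). A short sketch of constructing one, mirroring the setUp above:

from tinygrad.ops import LazyOp, LoadOps, MemBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker

# the ShapeTracker travels with the buffer descriptor
buf = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
print(buf.arg.st.shape)  # (4,)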

View File

@@ -14,42 +14,51 @@ def shapetracker_getitem(st, val):
class CheckingShapeTracker:
def __init__(self, shape):
self.st = ShapeTracker(shape)
self.st = ShapeTracker.from_shape(shape)
self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
@property
def shape(self):
return self.t.shape
def simplify(self): self.st.simplify()
def simplify(self):
self.st = self.st.simplify()
return self
def reshape(self, new_shape):
self.st.reshape(new_shape)
self.st = self.st.reshape(new_shape)
self.t = self.t.reshape(new_shape)
return self
def permute(self, axis):
self.st.permute(axis)
self.st = self.st.permute(axis)
self.t = np.transpose(self.t, axis)
return self
def expand(self, new_shape):
self.st.expand(new_shape)
self.st = self.st.expand(new_shape)
self.t = np.broadcast_to(self.t, new_shape)
return self
def flip(self, axis):
self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.t = np.flip(self.t, axis)
return self
def shrink(self, arg):
self.st.shrink(arg)
self.st = self.st.shrink(arg)
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
return self
def pad(self, arg):
self.st.pad(arg)
self.st = self.st.pad(arg)
self.t = np.pad(self.t, arg, constant_values=-1)
return self
def stride(self, arg):
self.st.stride(arg)
self.st = self.st.stride(arg)
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
return self
def __getitem__(self, val):
return self.t.flatten()[val]
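CheckingShapeTracker mirrors every movement op on a numpy array of offsets, so tests can compare the symbolic tracker against ground truth. A short usage sketch, assuming the helper above and the shapetracker_getitem defined earlier in this file:

st = CheckingShapeTracker((3, 4))
st.permute((1, 0)).reshape((12,))  # each method rebinds self.st and returns self
# the value numpy sees at a flat index is the original offset, so both sides must agree
assert shapetracker_getitem(st.st, 5) == st[5]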
@@ -70,7 +79,7 @@ class CheckingShapeTracker:
class TestRealIssues(unittest.TestCase):
def test_reshape_doesnt_multiview(self):
self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
assert len(self.st.views) == 1
@@ -78,27 +87,27 @@ class TestRealDoesntSimplify(unittest.TestCase):
def tearDown(self):
st = self.st.real_strides()
print(st)
self.st.simplify()
self.st = self.st.simplify()
assert len(self.st.views) != 1
assert None in st
def test_1(self):
self.st = ShapeTracker((8, 6, 11), views=[
self.st = ShapeTracker((
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
View.create((8, 6, 11), (66, 11, 1), 0, None)])
View.create((8, 6, 11), (66, 11, 1), 0, None)))
assert self.st.real_strides() == (33, None, 1)
def test_2(self):
self.st = ShapeTracker((4, 4, 3, 3), views=[
self.st = ShapeTracker((
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
assert self.st.real_strides() == (None, 18, -3, -1)
class TestRealStrides(unittest.TestCase):
def test_1(self):
self.st = ShapeTracker((16, 32, 4), views=[
self.st = ShapeTracker((
View.create((2048,), (1,), 0, ((0, 512),)),
View.create((16, 32, 4), (128, 4, 1), 0, None)])
View.create((16, 32, 4), (128, 4, 1), 0, None)))
st = self.st.real_strides()
print(self.st, st)
assert st == (None, 4, 1)
@@ -106,27 +115,27 @@ class TestRealStrides(unittest.TestCase):
class TestRealSimplifies(unittest.TestCase):
def tearDown(self):
st = self.st.real_strides()
self.st.simplify()
self.st = self.st.simplify()
assert len(self.st.views) == 1
print(self.st.views[-1].strides, st)
assert self.st.views[-1].strides == st
def test_1(self):
self.st = ShapeTracker((1, 3, 2, 11, 26, 1, 1, 3), views=[
self.st = ShapeTracker((
View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)])
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)))
def test_2(self):
self.st = ShapeTracker((8, 1, 6, 10, 28, 3, 2, 1), views=[
self.st = ShapeTracker((
View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)))
class TestIndexExpressions2d(unittest.TestCase):
def setUp(self):
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
offsets = [0, 1, 15, 28, 10000]
self.sts = [ShapeTracker(base_shape, [View.create(base_shape, offset=offset)]) for base_shape in shapes for offset in offsets]
self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets]
self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
self.shapes = [shape for shape in shapes for offset in offsets]
self.node_exprs = []
@@ -171,36 +180,52 @@ class TestIndexExpressions2d(unittest.TestCase):
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
def test_permute(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st = st.permute((1, 0))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
new_st.append(st)
self.sts = new_st
def test_reshape(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.reshape((base_shape[0], 1, base_shape[1]))
st = st.reshape((base_shape[0], 1, base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
new_st.append(st)
self.sts = new_st
def test_reshape_expand(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.reshape((base_shape[0], 1, base_shape[1]))
st.expand((base_shape[0], base_shape[1], base_shape[1]))
st = st.reshape((base_shape[0], 1, base_shape[1]))
st = st.expand((base_shape[0], base_shape[1], base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
new_st.append(st)
self.sts = new_st
def test_permute_reshape_1(self): # This tests multiple views
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
st = st.permute((1, 0))
st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
new_st.append(st)
self.sts = new_st
def test_permute_reshape_2(self):
new_st = []
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st.reshape((1, base_shape[0]//5, base_shape[1]*5))
st = st.permute((1, 0))
st = st.reshape((1, base_shape[0]//5, base_shape[1]*5))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
new_st.append(st)
self.sts = new_st
class TestSimplifyingShapeTracker(unittest.TestCase):
def setUp(self):
@@ -211,14 +236,14 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
# multiview simplify
def test_expand_contract_simple(self):
self.st.expand((10, 10))
self.st.reshape((100,))
self.st = self.st.expand((10, 10))
self.st = self.st.reshape((100,))
print(self.st.views)
assert(len(self.st.views) == 2)
self.st.reshape((10, 10))
self.st = self.st.reshape((10, 10))
print(self.st.views)
self.st.simplify()
self.st = self.st.simplify()
print(self.st.views)
assert(len(self.st.views) == 1)
@@ -231,7 +256,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
self.st.reshape((2, 5, 2, 5))
print(self.st.views)
self.st.simplify()
self.st = self.st.simplify()
print(self.st.views)
assert(len(self.st.views) == 1)
@@ -243,7 +268,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
assert(len(self.st.views) == 2)
self.st.reshape((5, 20))
self.st.simplify()
self.st = self.st.simplify()
print(self.st.views)
assert(len(self.st.views) == 2)
@@ -387,7 +412,7 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
self.st.reshape((1, 4))
self.st.shrink(((0, 1), (1, 3)))
print(self.st.st)
self.st.simplify()
self.st = self.st.simplify()
print(self.st.st)
def test_case_2(self):
self.st.stride( (1, 1, -2) )
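TestSimplifyingShapeTracker above depends on simplify() returning a new tracker instead of mutating in place. A sketch of the expand-then-contract round trip it checks, with an illustrative starting shape:

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((10, 1))
st = st.expand((10, 10)).reshape((100,))  # flattening a 0-stride axis forces a second View
print(len(st.views))                      # 2
st = st.reshape((10, 10)).simplify()      # reshaping back lets merge_views fold them
print(len(st.views))                      # 1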

View File

@@ -4,7 +4,7 @@ from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOp
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape, View
from tinygrad.shape.view import strides_for_shape
class LocalBuffer(NamedTuple):
name: str
@@ -43,11 +43,10 @@ class Kernel:
self.reduceop = reduceops[0] if reduceops else None
# create new shapetrackers inside this kernel, we will permute them
self.bufs = [MemBuffer(0, self.info.dtype, (View.create(self.info.shape),))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
self.sts: List[ShapeTracker] = [ShapeTracker(x.views[-1].shape, views=list(x.views)) for x in self.bufs]
for st in self.sts: st.simplify()
self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
self.mem_estimate: int = sum(x.dtype.itemsize*x.views[-1].size() for x in self.bufs)
self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs)
# get earlybufs, before the one reduce op
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
@@ -67,7 +66,7 @@ class Kernel:
def has_variable_shape(self) -> bool:
for b in self.bufs:
if not all_int(b.views[-1].shape): return True
if not all_int(b.st.views[-1].shape): return True
return False
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
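With every buffer carrying a ShapeTracker, the memory estimate reduces to one expression over st.size(). A sketch of that sum, with hypothetical buffers:

from tinygrad.ops import MemBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker

bufs = [MemBuffer(i, dtypes.float32, ShapeTracker.from_shape((16, 16))) for i in range(2)]
mem_estimate = sum(x.dtype.itemsize * x.st.size() for x in bufs)
print(mem_estimate)  # 2 buffers * 256 elements * 4 bytes = 2048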

View File

@@ -260,7 +260,7 @@ class Linearizer(OptimizedKernel):
# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduce:
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))

View File

@@ -24,9 +24,12 @@ class OptimizedKernel(Kernel):
# apply reshape and permute to all shapetrackers
def reshape_and_permute(self, new_shape_fxn, axis):
new_sts = []
for st in self.sts:
if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
if axis is not None: st.permute(tuple(axis))
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
if axis is not None: st = st.permute(tuple(axis))
new_sts.append(st)
self.sts = new_sts
# drops the final dimension
def upcast(self):
@@ -76,7 +79,7 @@ class OptimizedKernel(Kernel):
else: rets[j].append((shapes[j][i], strides[j][i]))
# do the reshapes
for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
for i,x in enumerate(rets): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
# ******************** GPU simplifiers ********************
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
@@ -120,7 +123,7 @@ class OptimizedKernel(Kernel):
stride[j] = bst
bst *= shp[j]
self.sts.append(ShapeTracker(tuple(shp), [View.create(tuple(shp), tuple(stride))]))
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
self.local_alias[i] = self.bufs[-1]

View File

@@ -7,7 +7,7 @@ import numpy as np
from tinygrad.graph import log_op
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer, buf_is_kernel_arg
@@ -95,11 +95,10 @@ def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
realized_bufs = dedup([x.realized for x in op.buffers if buf_is_kernel_arg(x)])
for x in op.buffers:
assert x.realized, "buffer isn't realized"
x.st.simplify()
if isinstance(x.realized, RawConst):
replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, tuple(x.st.views)))
replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify()))
elif x.realized in realized_bufs:
replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, tuple(x.st.views)))
replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify()))
else:
raise NotImplementedError(f"not handled {x}")
return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), realized_bufs
@@ -149,8 +148,8 @@ class LazyBuffer:
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self.realized else self.realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st.key, self.var_vals_key)
return (self.dtype, self.op.op, self.st.key, self.var_vals_key)
if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
return (self.dtype, self.op.op, self.st, self.var_vals_key)
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@@ -195,7 +194,7 @@ class LazyBuffer:
@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -204,11 +203,11 @@ class LazyBuffer:
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker(x.shape, [View.create(x.shape)]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
def toCPU(self) -> np.ndarray:
assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
@@ -242,7 +241,7 @@ class LazyBuffer:
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
@@ -257,7 +256,7 @@ class LazyBuffer:
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
@@ -282,18 +281,18 @@ class LazyBuffer:
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
self.op.src[0].var_vals = self.var_vals
return self.op.src[0].reshape(arg)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
return self.shuffle_and_prune_movement_ops(self.st.reshape(arg), MovementOps.RESHAPE, arg)
def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all(b == 0 and e == 0 for b,e in arg): return self
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
return self.shuffle_and_prune_movement_ops(self.st.pad(arg), MovementOps.PAD, arg)
def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND:
return self.op.src[0].expand(arg)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)
return self.shuffle_and_prune_movement_ops(self.st.expand(arg), MovementOps.EXPAND, arg)
def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if arg == tuple(range(len(self.shape))): return self
@@ -315,19 +314,19 @@ class LazyBuffer:
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
return self.shuffle_and_prune_movement_ops(self.st.permute(arg), MovementOps.PERMUTE, arg)
def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
return self.shuffle_and_prune_movement_ops(self.st.shrink(arg), MovementOps.SHRINK, arg)
def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
local_st = ShapeTracker(self.shape).stride(arg)
local_st = ShapeTracker.from_shape(self.shape).stride(arg)
if self.shape == local_st.shape and local_st.contiguous: return self
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)
return self.shuffle_and_prune_movement_ops(self.st.stride(arg), MovementOps.STRIDE, arg)
@property
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
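All of the lazy.py edits follow from immutability: the old ShapeTracker(self.st) copy before each movement op is unnecessary because self.st.reshape(arg) can no longer mutate self.st. A sketch of why sharing is now safe:

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4))
st2 = st.reshape((16,))  # derives a new tracker...
print(st.shape)          # (4, 4): ...while the original is untouched
print(st2.shape)         # (16,)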

View File

@@ -3,7 +3,7 @@ import time, importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
from tinygrad.shape.view import View
from tinygrad.shape.shapetracker import ShapeTracker
from dataclasses import dataclass
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
@@ -25,13 +25,13 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
class MemBuffer:
idx: int
dtype: DType
views: Tuple[View, ...]
st: ShapeTracker
@dataclass(frozen=True)
class ConstBuffer:
val: Any
dtype: DType
views: Tuple[View, ...]
st: ShapeTracker
class LazyOp:
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
@@ -101,8 +101,8 @@ Device = _Device()
# **************** for Interpreted Buffers ****************
def apply_shapetracker(fxn_for_op, ret, views):
for v in views:
def apply_shapetracker(fxn_for_op, ret, st):
for v in st.views:
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
# first, we apply the offset
@@ -133,7 +133,7 @@ class Interpreted:
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op:
assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.views))
return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.st))
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
@@ -162,7 +162,7 @@ class FlopCounter:
self.flops, ret = 0, self.flops
return ret
shape_fxn_for_op: Dict[Op, Callable] = {
LoadOps.BUFFER: lambda arg: (arg.views[-1].shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.views[-1].shape, arg.dtype, 0),
LoadOps.BUFFER: lambda arg: (arg.st.shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0),
UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
**{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
@@ -241,8 +241,7 @@ class Compiled:
for i,a in enumerate(inputs):
# TODO: if this is contiguous it's fine
if a == output.realized:
views = [x.arg.views for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1]
if any(len(v) > 1 or not v[0].contiguous for v in views):
if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1):
output.realized = None
break
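Because ShapeTracker is now a frozen dataclass over a tuple of Views, it is hashable and compares by value; that is what lets MemBuffer and ConstBuffer hold it directly and lets LazyBuffer.key drop the old st.key property ("key isn't needed" in the commit message). A quick sketch:

from tinygrad.shape.shapetracker import ShapeTracker

a = ShapeTracker.from_shape((10, 10)).permute((1, 0))
b = ShapeTracker.from_shape((10, 10)).permute((1, 0))
print(a == b)              # True: field-wise equality from the frozen dataclass
print(hash(a) == hash(b))  # True: safe to use in dict keys and lru_cache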

View File

@@ -1,7 +1,8 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
from typing import Tuple, Union, List, Optional, cast
from dataclasses import dataclass
from typing import Tuple, List, Optional, cast
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
from tinygrad.shape.view import View
@@ -48,7 +49,7 @@ def expr_idxs(view:View, idxs) -> Node:
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
if vm2.mask: return None # this isn't supported yet
mst = ShapeTracker(vm1.shape, [vm2, vm1])
mst = ShapeTracker((vm2, vm1))
strides = mst.real_strides()
if None in strides: return None
return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
@@ -63,12 +64,13 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
acc *= d
return Variable.sum(ret)
@dataclass(frozen=True)
class ShapeTracker:
__slots__ = "views"
def __init__(self, shape:Union[ShapeTracker, Tuple[sint, ...]], views:Optional[List[View]]=None):
self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View.create(shape)]
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
views: Tuple[View, ...]
def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
@staticmethod
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@@ -76,9 +78,6 @@ class ShapeTracker:
@property
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
@property
def key(self) -> Tuple[View, ...]: return tuple(self.views)
# this is the real size (ish)
def size(self): return self.views[-1].size()
@@ -113,13 +112,13 @@ class ShapeTracker:
idx = expr_node(v, idx)
return idx, valid
def simplify(self):
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2:
new_view = merge_views(self.views[-2], self.views[-1])
if new_view:
if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
self.views = self.views[:-2] + [new_view]
self.simplify()
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
return self
def expr_idxs(self, idxs=None):
if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
@@ -137,38 +136,30 @@ class ShapeTracker:
# *** under this line are the movement ops ***
def pad(self, arg: Tuple[Tuple[int, int], ...]):
self.views[-1] = self.views[-1].pad(arg)
return self
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]):
self.views[-1] = self.views[-1].shrink(arg)
return self
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
def expand(self, new_shape: Tuple[sint, ...]):
self.views[-1] = self.views[-1].expand(new_shape)
return self
def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
def permute(self, axis: Tuple[int, ...]):
self.views[-1] = self.views[-1].permute(axis)
return self
def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
def stride(self, mul: Tuple[int, ...]):
self.views[-1] = self.views[-1].stride(mul)
return self
def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
def reshape(self, new_shape: Tuple[sint, ...]):
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
new_view = self.views[-1].reshape(new_shape)
if new_view is None:
extra_view = View.create(new_shape)
# last chance to merge. TODO: move into View
if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
self.views[-1] = merged_view
else:
self.views.append(extra_view)
else:
self.views[-1] = new_view
return self
return ShapeTracker(self.views[0:-1] + (merged_view,))
return ShapeTracker(self.views + (extra_view, ))
return ShapeTracker(self.views[0:-1] + (new_view,))
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
# TODO: if we remove movementops from lazy.py we can delete this
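A closing note on cost: each movement op above rebuilds only the final View and reuses the rest of the tuple, so derived trackers share structure with their parents. A sketch of that sharing (shapes illustrative):

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((10, 10)).permute((1, 0)).reshape((100,))
print(len(st.views))                # 2, as in the tutorial above
st2 = st.pad(((2, 2),))             # pad replaces only views[-1]
print(st.views[0] is st2.views[0])  # True: the first View is shared, not copied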