St real size (#3046)

* track the size in the lazybuffer

* shapetracker real size

* lint
Author: George Hotz
Date: 2024-01-08 14:44:53 -08:00
Committed by: GitHub
parent 1d730b8853
commit 655c6f61d3
10 changed files with 25 additions and 22 deletions
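
The gist of the change: the old ShapeTracker.size() method is split in two. A new size property returns the logical element count of the view's shape, bounding symbolic dimensions by their max, while real_size() keeps the old behavior of measuring how many elements of the backing buffer are actually addressable. A minimal sketch of the difference, assuming the post-commit API (not part of the diff):

# expand adds stride-0 axes: the logical view grows, the addressable buffer doesn't
from tinygrad.shape.shapetracker import ShapeTracker
st = ShapeTracker.from_shape((100, 100)).reshape((100, 100, 1)).expand((100, 100, 100))
print(st.size)         # 1000000: element count of the (expanded) logical view
print(st.real_size())  # 10000: only 100*100 distinct buffer elements are ever addressed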


@@ -370,7 +370,7 @@ After you are done speaking, output [EOS]. You are not Chad.
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
-param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
+param_count = sum(x.lazydata.size for x in get_parameters(llama.model))
if chatbot:
# encode pre prompt


@@ -12,7 +12,7 @@ class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
-rawbufs = [Buffer(Device.DEFAULT, si.out.st.size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.size(), x.dtype) for x in si.inputs]
+rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')


@@ -797,35 +797,35 @@ class TestGetContraction(unittest.TestCase):
class TestShapeTrackerSize(unittest.TestCase):
def test_simple_size(self):
st = ShapeTracker.from_shape((100, 100))
-self.assertEqual(st.size(), 100*100)
+self.assertEqual(st.real_size(), 100*100)
def test_expand_size(self):
st = ShapeTracker.from_shape((100, 100))
st = st.reshape((100, 100, 1))
st = st.expand((100, 100, 100))
-self.assertEqual(st.size(), 100*100)
+self.assertEqual(st.real_size(), 100*100)
def test_expand_size_flatten(self):
st = ShapeTracker.from_shape((100, 100))
st = st.reshape((100, 100, 1))
st = st.expand((100, 100, 100))
st = st.reshape((100*100*100,))
-self.assertEqual(st.size(), 100*100)
+self.assertEqual(st.real_size(), 100*100)
def test_shrink_size_axis_0(self):
st = ShapeTracker.from_shape((100, 100))
st = st.shrink(((0, 50), (0, 100)))
-self.assertEqual(st.size(), 50*100)
+self.assertEqual(st.real_size(), 50*100)
def test_shrink_size_axis_0_variable(self):
st = ShapeTracker.from_shape((100, 100))
st = st.shrink(((0, Variable("a", 0, 50)), (0, 100)))
-self.assertEqual(st.size(), 50*100)
+self.assertEqual(st.real_size(), 50*100)
def test_shrink_size_axis_1(self):
st = ShapeTracker.from_shape((100, 100))
st = st.shrink(((0, 100), (0, 50)))
-self.assertEqual(st.size(), 9950) # careful here
+self.assertEqual(st.real_size(), 9950) # careful here
if __name__ == '__main__':
unittest.main()
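
On the "# careful here" case above: shrinking only the trailing axis leaves the row stride at 100, so the furthest element the view can reach sits at flat index 99*100 + 49 = 9949, and real_size() reports one past it, 9950, even though the view logically holds just 100*50 = 5000 elements. Since real_size() sizes allocations, it must cover the strided gaps.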


@@ -326,7 +326,7 @@ class Kernel:
bst *= shp[j]
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
-self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
+self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])


@@ -206,14 +206,14 @@ class Linearizer(Kernel):
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
# define local buffers
for lb in self.local_alias.values():
-self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
+self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduce:
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
self.bufs.append(LocalBuffer("temp", self.sts[-1].size(), temp_dtype))
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size())))
self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
# kernel name (before late upcast)
self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])


@@ -64,7 +64,8 @@ def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
for x in lin.membufs: bufsts[x.idx].append(x)
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
for k,lx in bufsts.items():
-rawbufs[k] = Buffer(Device.DEFAULT, prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
+buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
+rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
assert all(r is not None for r in rawbufs)
return cast(List[Buffer], rawbufs)


@@ -5,7 +5,7 @@ from typing import Union, Optional, Any, Tuple, List, Set, Dict
from tinygrad.dtype import dtypes, DType, ImageDType
from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
-from tinygrad.shape.symbolic import sint, Variable, Node
+from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer, Device
from tinygrad.graph import log_lazybuffer
@@ -34,8 +34,7 @@ class LazyBuffer:
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None):
assert isinstance(device, str) and device == Device.canonicalize(device)
-self.device, self.st, self.dtype, self.shape = device, st, dtype, st.shape
-self.size = prod([x.max if isinstance(x, Node) else x for x in self.shape])
+self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
if base is None:
# properties on base
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
@@ -278,7 +277,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
# can only have one output buffer
# can only reduce contiguous
# max one reduceop per kernel
-if len(realized_children) > 1 or not st.contiguous or st.size() != r.st.size() or (tr in reduce_for_op and reduce_for_op[tr] != r):
+if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
forced_realize = True
break
@@ -304,7 +303,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
tr_next = next(iter(tr.children))
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1: break
-if st.size() != st_childs[0].st.size(): break
+if st.size != st_childs[0].st.size: break
st = st + st_childs[0].st
if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next
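
In lazy.py the change is mostly bookkeeping: LazyBuffer used to compute its own upper-bound element count with prod over Node.max, and now takes it directly from the new st.size property, which performs the same reduction; that is why the Node import can be dropped. A small sketch of the symbolic case, assuming the post-commit API:

# a symbolic dim is bounded by its max, matching the removed
# prod([x.max if isinstance(x, Node) else x for x in shape]) computation
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
st = ShapeTracker.from_shape((Variable("a", 1, 50), 100))
print(st.size)  # 5000 == 50*100, the upper bound used for allocation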


@@ -83,9 +83,9 @@ class FlopCounter:
return ret
InterpretedFlopCounter: Dict[Op, Callable] = {
-BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}),
+BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.real_size()}),
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
-BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), # noqa: E501
+BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
**{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
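
The flop counter's memory accounting keeps the old footprint semantics under the new real_size() name: bytes are charged per buffer as itemsize times the addressable extent. The empty-kernel stat in run_schedule below switches to the size property instead; for a contiguous output view the two coincide.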


@@ -46,5 +46,5 @@ def run_schedule(schedule:List[ScheduleItem]):
# run the function (put it in JIT)
assert all(x.realized is not None for x in si.inputs), f"can't run, some inputs aren't realized {[x for x in si.inputs if x.realized is None]}"
if prg: prg.exec([si.out.realized] + [cast(Buffer, x.realized) for x in si.inputs], si.var_vals)
else: update_stats(colored(f"empty {si.out.st.size():10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
else: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
realized_lazybuffer(si.out, GlobalCounters.kernel_count)


@@ -75,7 +75,10 @@ class ShapeTracker:
@property
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
-def size(self) -> int:
+@property
+def size(self) -> int: return prod([x.max if isinstance(x, Node) else x for x in self.views[-1].shape])
+def real_size(self) -> int:
if 0 in self.shape: return 0
ret = self.expr_idxs()[0].max
while not isinstance(ret, int): ret = ret.max # TODO: this is a while loop?!? it should be more clear what max does