Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
St real size (#3046)
* track the size in the lazybuffer
* shapetracker real size
* lint
This commit is contained in:
@@ -370,7 +370,7 @@ After you are done speaking, output [EOS]. You are not Chad.
|
||||
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
|
||||
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
|
||||
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
|
||||
param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
|
||||
param_count = sum(x.lazydata.size for x in get_parameters(llama.model))
|
||||
|
||||
if chatbot:
|
||||
# encode pre prompt
|
||||
|
||||
@@ -12,7 +12,7 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||
|
||||
def test_reasonable_time(self):
|
||||
si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
|
||||
rawbufs = [Buffer(Device.DEFAULT, si.out.st.size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.size(), x.dtype) for x in si.inputs]
|
||||
rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
|
||||
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
|
||||
assert tm > 0 and tm != float('inf')
|
||||
|
||||
|
||||
@@ -797,35 +797,35 @@ class TestGetContraction(unittest.TestCase):
|
||||
class TestShapeTrackerSize(unittest.TestCase):
|
||||
def test_simple_size(self):
|
||||
st = ShapeTracker.from_shape((100, 100))
|
||||
self.assertEqual(st.size(), 100*100)
|
||||
self.assertEqual(st.real_size(), 100*100)
|
||||
|
||||
def test_expand_size(self):
|
||||
st = ShapeTracker.from_shape((100, 100))
|
||||
st = st.reshape((100, 100, 1))
|
||||
st = st.expand((100, 100, 100))
|
||||
self.assertEqual(st.size(), 100*100)
|
||||
self.assertEqual(st.real_size(), 100*100)
|
||||
|
||||
def test_expand_size_flatten(self):
|
||||
st = ShapeTracker.from_shape((100, 100))
|
||||
st = st.reshape((100, 100, 1))
|
||||
st = st.expand((100, 100, 100))
|
||||
st = st.reshape((100*100*100,))
|
||||
self.assertEqual(st.size(), 100*100)
|
||||
self.assertEqual(st.real_size(), 100*100)
|
||||
|
||||
def test_shrink_size_axis_0(self):
|
||||
st = ShapeTracker.from_shape((100, 100))
|
||||
st = st.shrink(((0, 50), (0, 100)))
|
||||
self.assertEqual(st.size(), 50*100)
|
||||
self.assertEqual(st.real_size(), 50*100)
|
||||
|
||||
def test_shrink_size_axis_0_variable(self):
|
||||
st = ShapeTracker.from_shape((100, 100))
|
||||
st = st.shrink(((0, Variable("a", 0, 50)), (0, 100)))
|
||||
self.assertEqual(st.size(), 50*100)
|
||||
self.assertEqual(st.real_size(), 50*100)
|
||||
|
||||
def test_shrink_size_axis_1(self):
|
||||
st = ShapeTracker.from_shape((100, 100))
|
||||
st = st.shrink(((0, 100), (0, 50)))
|
||||
self.assertEqual(st.size(), 9950) # careful here
|
||||
self.assertEqual(st.real_size(), 9950) # careful here
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -326,7 +326,7 @@ class Kernel:
|
||||
bst *= shp[j]
|
||||
|
||||
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
|
||||
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
|
||||
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
|
||||
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
|
||||
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
|
||||
|
||||
|
||||
@@ -206,14 +206,14 @@ class Linearizer(Kernel):
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
|
||||
# define local buffers
|
||||
for lb in self.local_alias.values():
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
|
||||
# add a local buffer for multistage reduce. # TODO: use local alias
|
||||
if self.group_for_reduce:
|
||||
# TODO: the strides of this can be controlled
|
||||
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
|
||||
temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
|
||||
self.bufs.append(LocalBuffer("temp", self.sts[-1].size(), temp_dtype))
|
||||
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size())))
|
||||
self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
|
||||
self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
|
||||
|
||||
# kernel name (before late upcast)
|
||||
self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
|
||||
@@ -64,7 +64,8 @@ def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
|
||||
for x in lin.membufs: bufsts[x.idx].append(x)
|
||||
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
|
||||
for k,lx in bufsts.items():
|
||||
rawbufs[k] = Buffer(Device.DEFAULT, prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
|
||||
buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
|
||||
rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
|
||||
assert all(r is not None for r in rawbufs)
|
||||
return cast(List[Buffer], rawbufs)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Union, Optional, Any, Tuple, List, Set, Dict
|
||||
from tinygrad.dtype import dtypes, DType, ImageDType
|
||||
from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
|
||||
from tinygrad.shape.symbolic import sint, Variable, Node
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.graph import log_lazybuffer
|
||||
@@ -34,8 +34,7 @@ class LazyBuffer:
|
||||
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
base:Optional[LazyBuffer]=None):
|
||||
assert isinstance(device, str) and device == Device.canonicalize(device)
|
||||
self.device, self.st, self.dtype, self.shape = device, st, dtype, st.shape
|
||||
self.size = prod([x.max if isinstance(x, Node) else x for x in self.shape])
|
||||
self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
|
||||
if base is None:
|
||||
# properties on base
|
||||
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
|
||||
@@ -278,7 +277,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
# can only have one output buffer
|
||||
# can only reduce contiguous
|
||||
# max one reduceop per kernel
|
||||
if len(realized_children) > 1 or not st.contiguous or st.size() != r.st.size() or (tr in reduce_for_op and reduce_for_op[tr] != r):
|
||||
if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
|
||||
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
|
||||
forced_realize = True
|
||||
break
|
||||
@@ -304,7 +303,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
||||
tr_next = next(iter(tr.children))
|
||||
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
|
||||
if len(st_childs) > 1: break
|
||||
if st.size() != st_childs[0].st.size(): break
|
||||
if st.size != st_childs[0].st.size: break
|
||||
st = st + st_childs[0].st
|
||||
if not st.contiguous or tr_next.op in ReduceOps: break
|
||||
tr = tr_next
|
||||
|
||||
@@ -83,9 +83,9 @@ class FlopCounter:
|
||||
return ret
|
||||
|
||||
InterpretedFlopCounter: Dict[Op, Callable] = {
|
||||
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}),
|
||||
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.real_size()}),
|
||||
BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
|
||||
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), # noqa: E501
|
||||
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501
|
||||
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
|
||||
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
|
||||
**{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
|
||||
|
||||
@@ -46,5 +46,5 @@ def run_schedule(schedule:List[ScheduleItem]):
|
||||
# run the function (put it in JIT)
|
||||
assert all(x.realized is not None for x in si.inputs), f"can't run, some inputs aren't realized {[x for x in si.inputs if x.realized is None]}"
|
||||
if prg: prg.exec([si.out.realized] + [cast(Buffer, x.realized) for x in si.inputs], si.var_vals)
|
||||
else: update_stats(colored(f"empty {si.out.st.size():10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
|
||||
else: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
|
||||
realized_lazybuffer(si.out, GlobalCounters.kernel_count)
|
||||
|
||||
@@ -75,7 +75,10 @@ class ShapeTracker:
|
||||
@property
|
||||
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
|
||||
|
||||
def size(self) -> int:
|
||||
@property
|
||||
def size(self) -> int: return prod([x.max if isinstance(x, Node) else x for x in self.views[-1].shape])
|
||||
|
||||
def real_size(self) -> int:
|
||||
if 0 in self.shape: return 0
|
||||
ret = self.expr_idxs()[0].max
|
||||
while not isinstance(ret, int): ret = ret.max # TODO: this is a while loop?!? it should be more clear what max does
|
||||
|
||||
Reference in New Issue
Block a user