profiling llama + cache is_contiguous

This commit is contained in:
George Hotz
2023-03-11 08:23:21 -08:00
parent 01f39b19dc
commit 5e1380df6a
3 changed files with 16 additions and 2 deletions

View File

@@ -1,5 +1,5 @@
#!/bin/bash
mypyc --check-untyped-defs --explicit-package-bases --warn-unreachable tinygrad/shape/__init__.py tinygrad/shape/symbolic.py \
mypyc --check-untyped-defs --explicit-package-bases --warn-unreachable tinygrad/shape/shapetracker.py tinygrad/shape/symbolic.py \
tinygrad/nn/__init__.py tinygrad/helpers.py tinygrad/mlops.py tinygrad/tensor.py tinygrad/graph.py \
#tinygrad/codegen/ast.py tinygrad/codegen/gpu.py tinygrad/ops.py tinygrad/runtime/ops_metal.py
#tinygrad/runtime/ops_metal.py tinygrad/shape/__init__.py tinygrad/ops.py tinygrad/codegen/ast.py \

View File

@@ -191,6 +191,7 @@ if __name__ == "__main__":
parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
parser.add_argument('--large', action='store_true', help="Use the 13B model instead of the 7B one")
args = parser.parse_args()
chatbot = args.prompt == None
@@ -342,6 +343,10 @@ After you are done speaking, output [EOS]. You are not the User.
sys.stdout.write(outputted)
sys.stdout.flush()
if args.profile:
import cProfile, pstats
profiler = cProfile.Profile()
# chatbot loop
while 1:
# add tokens from user in chatbot mode
@@ -356,6 +361,8 @@ After you are done speaking, output [EOS]. You are not the User.
last_break = len(outputted)
for i in range(args.count):
if args.profile and i == 2: profiler.enable()
if args.timing: print("")
st = GlobalCounters.time_sum_s
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing):
@@ -379,3 +386,7 @@ After you are done speaking, output [EOS]. You are not the User.
if chatbot and outputted.endswith(end_delim): break
if not chatbot: break
if args.profile:
profiler.disable()
stats = pstats.Stats(profiler)
stats.dump_stats('out.prof')

View File

@@ -20,11 +20,14 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
ret.append((shape[i], strides[i]))
return ret
@functools.lru_cache(maxsize=None)
def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool:
  """Check whether `strides` match the strides that strides_for_shape computes for `shape`.

  A dimension of size 1 is accepted regardless of its stride. Results are
  memoized per (shape, strides) pair by the lru_cache decorator, so both
  arguments must be hashable (tuples).
  """
  for dim, actual, expected in zip(shape, strides, strides_for_shape(shape)):
    # a mismatching stride only matters when the dimension is not of size 1
    if dim != 1 and actual != expected: return False
  return True
class View:
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset
self.shape_strides = to_shape_strides(self.shape, self.strides)
self.contiguous : bool = self.offset == 0 and all(s1 == s2 or s == 1 for s,s1,s2 in zip(self.shape, self.strides, strides_for_shape(self.shape)))
self.contiguous : bool = self.offset == 0 and is_contiguous(self.shape, self.strides)
def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"