mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
profiling llama + cache is_contiguous
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
mypyc --check-untyped-defs --explicit-package-bases --warn-unreachable tinygrad/shape/__init__.py tinygrad/shape/symbolic.py \
|
||||
mypyc --check-untyped-defs --explicit-package-bases --warn-unreachable tinygrad/shape/shapetracker.py tinygrad/shape/symbolic.py \
|
||||
tinygrad/nn/__init__.py tinygrad/helpers.py tinygrad/mlops.py tinygrad/tensor.py tinygrad/graph.py \
|
||||
#tinygrad/codegen/ast.py tinygrad/codegen/gpu.py tinygrad/ops.py tinygrad/runtime/ops_metal.py
|
||||
#tinygrad/runtime/ops_metal.py tinygrad/shape/__init__.py tinygrad/ops.py tinygrad/codegen/ast.py \
|
||||
|
||||
@@ -191,6 +191,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser.add_argument('--temperature', type=float, default=0.7, help="Temperature in the softmax")
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per token")
|
||||
parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof")
|
||||
parser.add_argument('--large', action='store_true', help="Use the 13B model instead of the 7B one")
|
||||
args = parser.parse_args()
|
||||
chatbot = args.prompt == None
|
||||
@@ -342,6 +343,10 @@ After you are done speaking, output [EOS]. You are not the User.
|
||||
sys.stdout.write(outputted)
|
||||
sys.stdout.flush()
|
||||
|
||||
if args.profile:
|
||||
import cProfile, pstats
|
||||
profiler = cProfile.Profile()
|
||||
|
||||
# chatbot loop
|
||||
while 1:
|
||||
# add tokens from user in chatbot mode
|
||||
@@ -356,6 +361,8 @@ After you are done speaking, output [EOS]. You are not the User.
|
||||
|
||||
last_break = len(outputted)
|
||||
for i in range(args.count):
|
||||
if args.profile and i == 2: profiler.enable()
|
||||
|
||||
if args.timing: print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing):
|
||||
@@ -379,3 +386,7 @@ After you are done speaking, output [EOS]. You are not the User.
|
||||
if chatbot and outputted.endswith(end_delim): break
|
||||
if not chatbot: break
|
||||
|
||||
if args.profile:
|
||||
profiler.disable()
|
||||
stats = pstats.Stats(profiler)
|
||||
stats.dump_stats('out.prof')
|
||||
|
||||
@@ -20,11 +20,14 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
|
||||
ret.append((shape[i], strides[i]))
|
||||
return ret
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool:
  """Check whether `strides` agrees with the strides `strides_for_shape` yields for `shape`.

  A dimension passes when its size is 1 (its stride is ignored) or when its
  stride equals the corresponding entry of `strides_for_shape(shape)`.
  Results are memoized without a size limit, so both arguments must be
  hashable (tuples).
  """
  reference = strides_for_shape(shape)
  for dim, actual, wanted in zip(shape, strides, reference):
    # only dimensions larger than 1 have to match the reference stride
    if dim != 1 and actual != wanted: return False
  return True
|
||||
|
||||
class View:
|
||||
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
  """Store the shape, normalized strides and offset, and precompute derived fields.

  Args:
    shape: size of each dimension.
    strides: stride of each dimension; the stride of every size-1 dimension
      is replaced by 0 before being stored.
    offset: stored as-is; the view is only marked contiguous when it is 0.
  """
  # size-1 dimensions are stored with stride 0, whatever stride the caller passed
  self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset
  self.shape_strides = to_shape_strides(self.shape, self.strides)
  # single assignment through the memoized helper; the inline duplicate of the same check was redundant
  self.contiguous : bool = self.offset == 0 and is_contiguous(self.shape, self.strides)
|
||||
|
||||
def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user