From c235223c07a83a09eac4c904ba9351a66e36d665 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Wed, 29 May 2024 04:10:27 +0800
Subject: [PATCH 1/2] refactor tc_opt creation (#4765)

* move reduceop loop

* this is more mergable code

add assert

* integrate s2
---
 tinygrad/codegen/kernel.py | 72 +++++++++++++++++++++-----------------
 1 file changed, 40 insertions(+), 32 deletions(-)

diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 15920956f9..5eeda2742e 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -37,6 +37,7 @@ class Opt:
 class TensorCoreOptions(NamedTuple):
   axes: List[int] # the location of the original N and M axes if still in the shape
   axes_exist: List[bool] # true if the original N and M axes are still in the shape
+  axis_pads: List[Tuple[int, int]]
   def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an dimension is removed
     for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
       if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
@@ -324,52 +325,59 @@ class Kernel:
 
   # ******************** high level optimizers ********************
 
+  def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
+    has_cast = tc.dtype_in != tc.dtype_out
+    if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
+
+    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
+    if mul_op.op is not BinaryOps.MUL: return None
+
+    def buf_index(src: LazyOp) -> Optional[int]:
+      # TODO: apply tc even if the sources are not from LOAD
+      if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+      try:
+        if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+      except ValueError: return None
+      return None
+    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
+
+    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+    if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+
+    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+    if not(axis < len(axis_choices)): return None
+
+    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+    axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
+    if axis_pads and (opt_level < 2): return None
+    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
+    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+    return TensorCoreOptions(axes=[s0, s1, s2], axes_exist=[True, True], axis_pads=axis_pads)
+
   def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
     if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
       for tc in self.opts.tensor_cores:
-        has_cast = tc.dtype_in != tc.dtype_out
-        if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
-
-        mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
-        if mul_op.op is not BinaryOps.MUL: continue
-
-        def buf_index(src: LazyOp) -> Optional[int]:
-          # TODO: apply tc even if the sources are not from LOAD
-          if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
-          try:
-            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
-          except ValueError: return None
-          return None
-        if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
-
-        buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
-        axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
-        axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
-        if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
-
-        axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
-        if not(axis < len(axis_choices)): continue
-
-        s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
-        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
-        if axis_pads and (opt_level < 2): continue
-        # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
-        self.tensor_core_opts = (tc_opts:=TensorCoreOptions(axes=[s0, s1], axes_exist=[True, True]))
-        self.bufs_for_tensor_core[self.reduceop] = (buf0, buf1)
+        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
+        if tensor_core_opts[0] is None: continue
+
+        # verify all reduces are exactly the same shape and strides
+        assert all(x == tensor_core_opts[0] for x in tensor_core_opts)
+        self.tensor_core_opts = tc_opts = tensor_core_opts[0]
 
         # attempt to pad the tensor axes that require it
         try:
-          for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-        self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
+        self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
         for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
           if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
         for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
 
         # assert tensor core
-        if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
         if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
         return True
     return False

From 7624ad3ddd7fc6071431cdea5acf9166860bee28 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 28 May 2024 16:24:44 -0400
Subject: [PATCH 2/2] add --timing and --profile to llama3 example (#4767)

---
 examples/llama3.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/examples/llama3.py b/examples/llama3.py
index 25e1b74e20..58cc23e3c0 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -5,8 +5,9 @@ import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from tqdm import tqdm
 from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.helpers import Profiling, Timing, DEBUG
 
 class Tokenizer:
   pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
@@ -195,7 +196,9 @@ if __name__ == "__main__":
   parser.add_argument("--shard", type=int, default=1)
   parser.add_argument("--quantize", choices=["int8", "nf4"])
   parser.add_argument("--api", action="store_true")
-  parser.add_argument('--seed', type=int)
+  parser.add_argument("--seed", type=int)
+  parser.add_argument("--timing", action="store_true", help="Print timing per token")
+  parser.add_argument("--profile", action="store_true", help="Output profile data")
   args = parser.parse_args()
 
   if args.seed is not None: Tensor.manual_seed(args.seed)
@@ -209,6 +212,7 @@ if __name__ == "__main__":
 
   device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
   model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device)
+  param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))
 
   if args.api:
     from bottle import Bottle, request, response, HTTPResponse, abort
@@ -317,7 +321,16 @@ if __name__ == "__main__":
     last_tok = toks[-1]
     while True:
       GlobalCounters.reset()
-      tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).item()
+      if args.timing or args.profile: print("")
+      st = GlobalCounters.time_sum_s
+      with Profiling(enabled=args.profile):
+        with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
+          with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                      f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                      (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
+
+            tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P)
+          tok = tok.item()
       start_pos += 1
       last_tok = tok
       if tok in tokenizer.stop_tokens: break
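
A minimal usage sketch for the flags added in [PATCH 2/2], assuming the example's existing --model argument points at locally downloaded Llama 3 weights (the path below is a placeholder; all other arguments keep their defaults):

  # per-token timing: prints a "total" line with tok/s, GB/s, and param GB/s for each generated token
  python3 examples/llama3.py --model /path/to/llama3/weights --timing

  # per-token profiler output via tinygrad.helpers.Profiling
  python3 examples/llama3.py --model /path/to/llama3/weights --profile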