Merge remote-tracking branch 'upstream/master' into map-local-alias

qazal committed 2024-05-28 23:44:41 +03:00
2 changed files with 56 additions and 35 deletions

examples/llama3.py

@@ -5,8 +5,9 @@ import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from tqdm import tqdm
 from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.helpers import Profiling, Timing, DEBUG

 class Tokenizer:
   pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
@@ -195,7 +196,9 @@ if __name__ == "__main__":
   parser.add_argument("--shard", type=int, default=1)
   parser.add_argument("--quantize", choices=["int8", "nf4"])
   parser.add_argument("--api", action="store_true")
-  parser.add_argument('--seed', type=int)
+  parser.add_argument("--seed", type=int)
+  parser.add_argument("--timing", action="store_true", help="Print timing per token")
+  parser.add_argument("--profile", action="store_true", help="Output profile data")
   args = parser.parse_args()
   if args.seed is not None: Tensor.manual_seed(args.seed)
@@ -209,6 +212,7 @@ if __name__ == "__main__":
   device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
   model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device)
+  param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))

   if args.api:
     from bottle import Bottle, request, response, HTTPResponse, abort
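
The new param_bytes line totals the model's parameter footprint: for each parameter tensor it multiplies the element count (x.lazydata.size) by the bytes per element (x.dtype.itemsize). A minimal standalone sketch of the same arithmetic, with hypothetical llama-like weight shapes standing in for get_parameters(model):

    # sketch only: hypothetical weight shapes, not read from a real model
    from math import prod
    shapes = [(4096, 4096), (4096, 14336), (128256, 4096)]
    itemsize = 2  # e.g. 2 bytes per element for float16/bfloat16
    param_bytes = sum(prod(s) * itemsize for s in shapes)
    print(f"{param_bytes*1e-9:.2f} GB of parameters")  # total bytes the weights occupy
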
@@ -317,7 +321,16 @@ if __name__ == "__main__":
     last_tok = toks[-1]
     while True:
       GlobalCounters.reset()
-      tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).item()
+      if args.timing or args.profile: print("")
+      st = GlobalCounters.time_sum_s
+      with Profiling(enabled=args.profile):
+        with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
+          with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                      f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                      (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
+            tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P)
+          tok = tok.item()
       start_pos += 1
       last_tok = tok
       if tok in tokenizer.stop_tokens: break
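
The nested Timing contexts report throughput in their on_exit callbacks. Timing hands the callback the elapsed time in nanoseconds, which makes the unit math fall out neatly: 1e9/x is tokens per second, and bytes divided by nanoseconds is already GB/s, so GlobalCounters.global_mem/x and param_bytes/x need no further conversion. A small sketch of that math with made-up per-token numbers:

    # sketch only: hypothetical counters mirroring the on_exit arithmetic
    x = 25_000_000      # elapsed ns for one token (25 ms)
    global_mem = 16e9   # bytes moved this step (stand-in for GlobalCounters.global_mem)
    param_bytes = 16e9  # stand-in for the param_bytes sum above
    print(f"{1e9/x:.2f} tok/s")               # 40.00 tok/s
    print(f"{global_mem/x:.2f} GB/s")         # bytes per ns == GB/s, no conversion needed
    print(f"param {param_bytes/x:.2f} GB/s")  # effective weight-read bandwidth
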

tinygrad/codegen/kernel.py

@@ -38,6 +38,7 @@ class Opt:
 class TensorCoreOptions(NamedTuple):
   axes: List[int] # the location of the original N and M axes if still in the shape
   axes_exist: List[bool] # true if the original N and M axes are still in the shape
+  axis_pads: List[Tuple[int, int]]
   def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
     for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
       if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
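
The new axis_pads field records (axis, multiple) pairs that still need padding before the tensor-core shape ops run, while fix_axes keeps the tracked N/M axis positions valid as dimensions are removed. A self-contained sketch of that bookkeeping, with TCOpts as a hypothetical stand-in for TensorCoreOptions:

    from typing import List, NamedTuple, Tuple

    class TCOpts(NamedTuple):
      axes: List[int]                   # positions of the chosen n, m, k axes
      axes_exist: List[bool]            # whether the N and M axes still exist
      axis_pads: List[Tuple[int, int]]  # (axis, multiple) pairs needing PADTO
      def fix_axes(self, removed_axis: int):
        # every tracked N/M axis to the right of the removed dim shifts left by one
        for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
          if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1

    opts = TCOpts(axes=[1, 3, 5], axes_exist=[True, True], axis_pads=[])
    opts.fix_axes(0)   # removing dim 0 shifts N and M down
    print(opts.axes)   # [0, 2, 5] -- the loop only adjusts the first two entries
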
@@ -325,52 +326,59 @@ class Kernel:
   # ******************** high level optimizers ********************

+  def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
+    has_cast = tc.dtype_in != tc.dtype_out
+    if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
+
+    mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
+    if mul_op.op is not BinaryOps.MUL: return None
+
+    def buf_index(src:LazyOp) -> Optional[int]:
+      # TODO: apply tc even if the sources are not from LOAD
+      if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+      try:
+        if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+      except ValueError: return None
+      return None
+    if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
+
+    buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
+    axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
+    axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
+    if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
+
+    axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
+    if not(axis < len(axis_choices)): return None
+
+    s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+    axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
+    if axis_pads and (opt_level < 2): return None
+    self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
+    if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
+    return TensorCoreOptions(axes=[s0, s1, s2], axes_exist=[True, True], axis_pads=axis_pads)
+
   def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
     if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
       for tc in self.opts.tensor_cores:
-        has_cast = tc.dtype_in != tc.dtype_out
-        if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
-        mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
-        if mul_op.op is not BinaryOps.MUL: continue
-        def buf_index(src:LazyOp) -> Optional[int]:
-          # TODO: apply tc even if the sources are not from LOAD
-          if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
-          try:
-            if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
-          except ValueError: return None
-          return None
-        if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
-        buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
-        axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
-        axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
-        if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
-        axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
-        if not(axis < len(axis_choices)): continue
-        s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
-        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
-        if axis_pads and (opt_level < 2): continue
-        # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
-        self.tensor_core_opts = (tc_opts:=TensorCoreOptions(axes=[s0, s1], axes_exist=[True, True]))
-        self.bufs_for_tensor_core[self.reduceop] = (buf0, buf1)
+        tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
+        if tensor_core_opts[0] is None: continue
+        # verify all reduces are exactly the same shape and strides
+        assert all(x == tensor_core_opts[0] for x in tensor_core_opts)
+        self.tensor_core_opts = tc_opts = tensor_core_opts[0]
         # attempt to pad the tensor axes that require it
         try:
-          for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
+          for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-        self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
+        self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
         for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
           if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
         for (tc_dim, tc_amt) in tc.threads:
           self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
-        # assert tensor core
-        if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
         if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
         return True
     return False
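
For intuition on where axis_pads comes from: _create_tc_opts flags each chosen axis whose size is not a multiple of the corresponding tensor-core dimension, and _apply_tc_opt then tries PADTO on every flagged axis before unrolling. A toy version of that divisibility check, using hypothetical axis sizes and a 16x16x8 tensor core:

    # sketch only: hypothetical shapes, mirroring the axis_pads comprehension above
    full_shape = [60, 64, 100]   # sizes of the chosen n, m, k axes
    tc_dims = [16, 16, 8]        # tensor-core dims for N, M, K
    s0, s1, s2 = 0, 1, 2         # axis indices, as picked from axis_choices
    axis_pads = [(x, tc_dims[i]) for i, x in enumerate([s0, s1, s2]) if full_shape[x] % tc_dims[i] != 0]
    print(axis_pads)  # [(0, 16), (2, 8)]: axes 0 and 2 need padding, axis 1 already fits
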