mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
Merge remote-tracking branch 'upstream/master' into map-local-alias
This commit is contained in:
@@ -5,8 +5,9 @@ import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
from tqdm import tqdm
|
||||
from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
|
||||
from tinygrad.nn.state import safe_load, torch_load, load_state_dict
|
||||
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
|
||||
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
|
||||
from tinygrad.helpers import Profiling, Timing, DEBUG
|
||||
|
||||
class Tokenizer:
|
||||
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
|
||||
@@ -195,7 +196,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--shard", type=int, default=1)
|
||||
parser.add_argument("--quantize", choices=["int8", "nf4"])
|
||||
parser.add_argument("--api", action="store_true")
|
||||
parser.add_argument('--seed', type=int)
|
||||
parser.add_argument("--seed", type=int)
|
||||
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||
parser.add_argument("--profile", action="store_true", help="Output profile data")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.seed is not None: Tensor.manual_seed(args.seed)
|
||||
@@ -209,6 +212,7 @@ if __name__ == "__main__":
|
||||
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
|
||||
model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device)
|
||||
param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))
|
||||
|
||||
if args.api:
|
||||
from bottle import Bottle, request, response, HTTPResponse, abort
|
||||
@@ -317,7 +321,16 @@ if __name__ == "__main__":
|
||||
last_tok = toks[-1]
|
||||
while True:
|
||||
GlobalCounters.reset()
|
||||
tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).item()
|
||||
if args.timing or args.profile: print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Profiling(enabled=args.profile):
|
||||
with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
|
||||
with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
|
||||
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
|
||||
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
|
||||
|
||||
tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P)
|
||||
tok = tok.item()
|
||||
start_pos += 1
|
||||
last_tok = tok
|
||||
if tok in tokenizer.stop_tokens: break
|
||||
|
||||
@@ -38,6 +38,7 @@ class Opt:
|
||||
class TensorCoreOptions(NamedTuple):
|
||||
axes: List[int] # the location of the original N and M axes if still in the shape
|
||||
axes_exist: List[bool] # true if the original N and M axes are still in the shape
|
||||
axis_pads: List[Tuple[int, int]]
|
||||
def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an dimension is removed
|
||||
for tc_dim in [i for i in range(2) if self.axes_exist[i]]:
|
||||
if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
|
||||
@@ -325,52 +326,59 @@ class Kernel:
|
||||
|
||||
# ******************** high level optimizers ********************
|
||||
|
||||
def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
|
||||
|
||||
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
|
||||
if mul_op.op is not BinaryOps.MUL: return None
|
||||
|
||||
def buf_index(src: LazyOp) -> Optional[int]:
|
||||
# TODO: apply tc even if the sources are not from LOAD
|
||||
if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
|
||||
try:
|
||||
if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
|
||||
except ValueError: return None
|
||||
return None
|
||||
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
|
||||
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
|
||||
if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): return None
|
||||
|
||||
axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
|
||||
if not(axis < len(axis_choices)): return None
|
||||
|
||||
s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
|
||||
axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
|
||||
if axis_pads and (opt_level < 2): return None
|
||||
self.bufs_for_tensor_core[reduceop] = (buf0, buf1)
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
return TensorCoreOptions(axes=[s0, s1, s2], axes_exist=[True, True], axis_pads=axis_pads)
|
||||
|
||||
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
|
||||
if use_tensor_cores and self.opts.has_local and self.reduceop is not None and self.reduceop.op is ReduceOps.SUM:
|
||||
for tc in self.opts.tensor_cores:
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
|
||||
|
||||
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
|
||||
if mul_op.op is not BinaryOps.MUL: continue
|
||||
|
||||
def buf_index(src: LazyOp) -> Optional[int]:
|
||||
# TODO: apply tc even if the sources are not from LOAD
|
||||
if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
|
||||
try:
|
||||
if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
|
||||
except ValueError: return None
|
||||
return None
|
||||
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
|
||||
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
|
||||
if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
|
||||
|
||||
axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
|
||||
if not(axis < len(axis_choices)): continue
|
||||
|
||||
s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
|
||||
axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
|
||||
if axis_pads and (opt_level < 2): continue
|
||||
|
||||
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
|
||||
self.tensor_core_opts = (tc_opts:=TensorCoreOptions(axes=[s0, s1], axes_exist=[True, True]))
|
||||
self.bufs_for_tensor_core[self.reduceop] = (buf0, buf1)
|
||||
tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
|
||||
if tensor_core_opts[0] is None: continue
|
||||
|
||||
# verify all reduces are exactly the same shape and strides
|
||||
assert all(x == tensor_core_opts[0] for x in tensor_core_opts)
|
||||
self.tensor_core_opts = tc_opts = tensor_core_opts[0]
|
||||
|
||||
# attempt to pad the tensor axes that require it
|
||||
try:
|
||||
for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
|
||||
for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
|
||||
except KernelOptError: continue
|
||||
self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
|
||||
self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, tc.dims[2]), append_opt=False)
|
||||
for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
|
||||
if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
|
||||
for (tc_dim, tc_amt) in tc.threads:
|
||||
self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
|
||||
|
||||
# assert tensor core
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user