mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move llama getenv("JIT") from models to examples (#2671)
The Transformer class already has a `jit` parameter, so the caller should pass that instead of the model reading getenv("JIT") itself.
This commit is contained in:
@@ -155,7 +155,8 @@ class LLaMa:
|
||||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
|
||||
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)
|
||||
jit = bool(getenv("JIT", 1))
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT, jit=jit) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT, jit=jit)
|
||||
|
||||
if model_path.is_dir():
|
||||
weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
||||
|
||||
@@ -119,7 +119,7 @@ class Transformer:
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
|
||||
# TODO: better way to handle the first call vs. the rest?
|
||||
if tokens.shape[0:2] == (1,1) and self.forward_jit and getenv("JIT", 1):
|
||||
if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
|
||||
assert start_pos > 0
|
||||
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
|
||||
return self.forward(tokens, start_pos, temperature)
|
||||
|
||||
Reference in New Issue
Block a user