move llama getenv("JIT") from models to examples (#2671)

The Transformer class has a jit param, so we should use that in the caller.
chenyu
2023-12-07 12:43:22 -05:00
committed by GitHub
parent fd21eced74
commit 539b00a645
2 changed files with 3 additions and 2 deletions
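
The reasoning above is plain parameter injection: the example script reads the environment once and passes the result down, and the model only looks at what its constructor received. A minimal self-contained sketch of that shape (the Transformer here is a stand-in, not the real class):

  import os

  class Transformer:
    # the model takes the decision as a constructor argument and never calls getenv
    def __init__(self, jit: bool = True):
      self.forward_jit = "jitted-forward" if jit else None  # placeholder for the jitted callable

  # the caller, not the model, consults the environment
  jit = bool(int(os.getenv("JIT", "1")))
  model = Transformer(jit=jit)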


@@ -155,7 +155,8 @@ class LLaMa:
     sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
     assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
-    model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)
+    jit = bool(getenv("JIT", 1))
+    model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT, jit=jit) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT, jit=jit)
     if model_path.is_dir():
       weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
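
A note on the default in bool(getenv("JIT", 1)): tinygrad's getenv helper falls back to the default when the variable is unset, so the jit stays on unless JIT=0 is set explicitly. A small sketch of that behavior with a hypothetical stand-in (getenv_stub is not the real helper; the real one also coerces to the default's type):

  import os

  def getenv_stub(key: str, default: int = 0) -> int:
    # hypothetical stand-in: read the variable, fall back to the default, coerce to int
    return int(os.getenv(key, default))

  os.environ.pop("JIT", None)
  assert bool(getenv_stub("JIT", 1)) is True   # unset -> jit enabled (default 1)
  os.environ["JIT"] = "0"
  assert bool(getenv_stub("JIT", 1)) is False  # JIT=0 -> jit disabled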


@@ -119,7 +119,7 @@ class Transformer:
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit and getenv("JIT", 1):
+    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
       assert start_pos > 0
       return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
     return self.forward(tokens, start_pos, temperature)
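
The model-side check can drop getenv entirely because the constructor now owns the decision: when jit is False, forward_jit is simply never created. A minimal sketch of the __init__ logic this relies on (the jit-handling line is an assumption about the model, not part of this diff, and the TinyJit import path has moved between tinygrad versions):

  from tinygrad.jit import TinyJit  # assumption: the import path differs across tinygrad versions

  class Transformer:
    def __init__(self, max_context=1024, jit=True, **kwargs):  # other params elided
      self.max_context = max_context
      # when jit is False this stays None, so __call__ always falls through to self.forward
      self.forward_jit = TinyJit(self.forward) if jit else None

    def forward(self, tokens, start_pos, temperature):
      ...  # the real decode step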