remove memory peak for quantized llama (#1720)

Author: nimlgen (committed by GitHub)
Date: 2023-08-30 23:32:30 +03:00
Parent: e4eb5d55c7
Commit: b5cf274da3


@@ -262,8 +262,8 @@ class AbsmaxQuantizedLinear:
       if 'feed_forward' in name or ('attention.w') in name or name == 'output.weight':
         scale = v.abs().max(axis=1) / 127.0
         int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
-        new_tensors[name] = int8_weight.realize()
-        new_tensors[name.replace('weight', 'scale')] = scale.realize()
+        new_tensors[name] = int8_weight
+        new_tensors[name.replace('weight', 'scale')] = scale
       else:
         new_tensors[name] = v
     return new_tensors
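
For reference, the hunk above implements plain per-row absmax int8 quantization: each output row is scaled so its largest magnitude maps to 127. Below is a minimal, self-contained sketch of the same math with a round-trip check; the shape, the float32 dequantization dtype, and the import paths are assumptions, not taken from the diff:

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

v = Tensor.randn(64, 64)                       # hypothetical fp weight matrix
scale = v.abs().max(axis=1) / 127.0            # one absmax scale per output row
int8_weight = (v.T / scale).T.cast(dtype=dtypes.int8)

# Dequantize and check: per element the error is at most one quantization step.
deq = (int8_weight.cast(dtype=dtypes.float32).T * scale).T
print(np.abs(deq.numpy() - v.numpy()).max())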
@@ -287,6 +287,7 @@ class LLaMa:
     if quantize:
       weights = AbsmaxQuantizedLinear.quantize(weights)
+      for _,v in weights.items(): v.realize()
     load_state_dict(model, weights, strict=False)
     return LLaMa(model, sp_model)
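
Taken together, the two hunks defer materialization: quantize() no longer realizes each int8 tensor while building the new dict, and build() realizes everything in a single pass only after the quantized dict has replaced the fp one. Presumably, as the commit title suggests, this removes the peak where a full fp copy and a full int8 copy of the weights were resident at the same time. A sketch of the pattern, assuming tinygrad can free a source buffer once the lazy tensor derived from it is realized; quantize_lazily and the one-entry state dict are hypothetical stand-ins:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

def quantize_lazily(tensors):
  # Build int8 weights and scales as lazy graphs; nothing is materialized here.
  out = {}
  for name, v in tensors.items():
    scale = v.abs().max(axis=1) / 127.0
    out[name] = (v.T / scale).T.cast(dtype=dtypes.int8)
    out[name.replace('weight', 'scale')] = scale
  return out

weights = {'layer0.weight': Tensor.randn(64, 64)}  # hypothetical loaded state dict
weights = quantize_lazily(weights)  # the dict now holds only lazy int8/scale graphs
for _, v in weights.items(): v.realize()  # materialize one at a time; each fp source can be collected after use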