remove memory peak for quantized llama (#1720)
@@ -262,8 +262,8 @@ class AbsmaxQuantizedLinear:
       if 'feed_forward' in name or ('attention.w') in name or name == 'output.weight':
         scale = v.abs().max(axis=1) / 127.0
         int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
-        new_tensors[name] = int8_weight.realize()
-        new_tensors[name.replace('weight', 'scale')] = scale.realize()
+        new_tensors[name] = int8_weight
+        new_tensors[name.replace('weight', 'scale')] = scale
       else:
         new_tensors[name] = v
     return new_tensors
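The hunk above is the absmax quantization itself: each weight row is scaled so its largest absolute value maps to 127, cast to int8, and the per-row scale is stored under a matching `*.scale` key; the change only drops the per-tensor `.realize()` calls. A minimal NumPy sketch of that round trip (the array names `w`, `w_hat` and the random data are illustrative, not part of the commit):

import numpy as np

# Row-wise absmax int8 quantization, mirroring the hunk above (illustrative sketch).
w = np.random.randn(4, 8).astype(np.float32)      # stand-in for a linear layer weight
scale = np.abs(w).max(axis=1) / 127.0             # one scale per output row
int8_weight = (w.T / scale).T.astype(np.int8)     # truncating cast, values stay in [-127, 127]

# Dequantize by multiplying the per-row scale back in; error is at most one quantization step.
w_hat = (int8_weight.astype(np.float32).T * scale).T
print("max abs error:", np.abs(w - w_hat).max())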
@@ -287,6 +287,7 @@ class LLaMa:
 
     if quantize:
       weights = AbsmaxQuantizedLinear.quantize(weights)
+      for _,v in weights.items(): v.realize()
     load_state_dict(model, weights, strict=False)
 
     return LLaMa(model, sp_model)
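The second hunk moves realization out of quantize: the whole quantized state dict is built lazily first, then materialized in one pass in LLaMa.build before load_state_dict. For context on how such a weight/scale pair can be consumed downstream, here is a hedged NumPy sketch of a dequantizing matmul; the names, shapes, and random data are assumptions for illustration, not taken from examples/llama.py:

import numpy as np

# One plausible way a quantized linear layer applies a stored 'weight'/'scale' pair.
x = np.random.randn(2, 8).astype(np.float32)                   # activations (batch, in_features)
q = np.random.randint(-127, 128, size=(4, 8), dtype=np.int8)   # int8 weight (out_features, in_features)
scale = (np.random.rand(4).astype(np.float32) + 0.5) / 127.0   # per-output-row scales

# y = x @ W.T with W[i] = q[i] * scale[i]; the scale is folded in after the cast to float.
y = x @ (q.astype(np.float32).T * scale)
print(y.shape)  # (2, 4)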