feat: generate blend index (#11566)

Author: wozeparrot
Committed by: GitHub
Date: 2025-08-07 14:20:28 -04:00
Parent: 594cbdc66f
Commit: 7ae4335127


@@ -564,10 +564,13 @@ class GPTDataset:
     # check for cache
     cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
     cache_path = base_path.with_name(f"{base_path.name}.{cache_hash}.index_cache")
+    print(f"try loading GPTDataset from {cache_path}...")
     if cache_path.exists():
+      print("cache found, loading...")
       with open(cache_path, "rb") as f:
         self.doc_idx, self.sample_idx, self.shuffle_idx = pickle.load(f)
     else:
+      print("cache not found, building index...")
       self.doc_idx = self._build_doc_idx()
       self.sample_idx = self._build_sample_idx()
       self.shuffle_idx = self._build_shuffle_idx()
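
The three prints added in this hunk just make the cache lookup visible; the cache scheme itself keys the pickled index arrays by a SHA-256 of every parameter that determines them and stores the file next to the dataset. A minimal standalone sketch of that pattern, using a hypothetical load_or_build_index helper with a build callback (neither name is from the commit):

import hashlib, pickle
from pathlib import Path

def load_or_build_index(base_path:Path, samples:int, seqlen:int, seed:int, shuffle:bool, build):
  # key the cache by every parameter that influences the generated index
  cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
  cache_path = base_path.with_name(f"{base_path.name}.{cache_hash}.index_cache")
  if cache_path.exists():
    with open(cache_path, "rb") as f:
      return pickle.load(f)
  index = build()  # the expensive path: scan documents and build the index arrays
  with open(cache_path, "wb") as f:
    pickle.dump(index, f)
  return index

Note the key covers only the scalar parameters, not the contents of the data files (or, for the blend cache below, the weights), so a stale cache has to be deleted by hand if those change under identical parameters.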
@@ -692,10 +695,47 @@ class BlendedGPTDataset:
     self.datasets = [GPTDataset(path, samples_per_blend[i], seqlen, seed + i, shuffle) for i,path in enumerate(paths)]

+    # check for cache
+    cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
+    cache_path = paths[0].with_name(f"{paths[0].name}.{cache_hash}.blend_cache")
+    print(f"try loading BlendedGPTDataset from {cache_path}...")
+    if cache_path.exists():
+      print("cache found, loading...")
+      with open(cache_path, "rb") as f:
+        self.dataset_idx, self.dataset_sample_idx = pickle.load(f)
+    else:
+      print("cache not found, building index...")
+      self.dataset_idx, self.dataset_sample_idx = self._build_blend_idx()
+      # save cache
+      with open(cache_path, "wb") as f:
+        pickle.dump((self.dataset_idx, self.dataset_sample_idx), f)

   def get(self, idx:int):
-    tokens = self.datasets[0][idx]
+    tokens = self.datasets[self.dataset_idx[idx]][self.dataset_sample_idx[idx]]
     return tokens

+  def _build_blend_idx(self):
+    dataset_idx = np.zeros(self.samples, dtype=np.int16)
+    dataset_sample_idx = np.zeros(self.samples, dtype=np.int64)
+    unspent_datasets = set(range(len(self.datasets)))
+    dataset_sample_counts = [0] * len(self.datasets)
+    for i in tqdm(range(self.samples)):
+      error_argmax, error_max = 0, 0.0
+      for di in unspent_datasets:
+        error = self.weights[di] * max(i, 1) - dataset_sample_counts[di]
+        if error > error_max:
+          error_max = error
+          error_argmax = di
+      dataset_idx[i] = error_argmax
+      dataset_sample_idx[i] = dataset_sample_counts[error_argmax]
+      dataset_sample_counts[error_argmax] += 1
+    return dataset_idx, dataset_sample_idx
+
 def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
   if val:
     dataset = BlendedGPTDataset([
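
_build_blend_idx is the core of the change: a greedy error-feedback interleave, similar in spirit to Megatron-LM's dataset blending. At step i it picks the dataset whose target share (weights[di] * i) most exceeds the number of samples actually drawn from it, so realized counts track the weights to within about one sample; max(i, 1) makes the very first pick follow the largest weight instead of falling through to dataset 0 on an all-zero error. dataset_sample_idx records each dataset's running cursor, so get resolves a global index to one underlying sample, consumed in order and at most once. A self-contained sketch of the same loop, assuming weights is normalized to sum to 1 and omitting the commit's unspent_datasets bookkeeping (presumably for exhausted datasets):

import numpy as np

def build_blend_idx(weights:list[float], samples:int):
  # greedy error feedback: at step i pick the dataset whose target share
  # (weight * i) most exceeds the number of samples actually drawn from it
  dataset_idx = np.zeros(samples, dtype=np.int16)
  dataset_sample_idx = np.zeros(samples, dtype=np.int64)
  counts = [0] * len(weights)
  for i in range(samples):
    di = max(range(len(weights)), key=lambda d: weights[d] * max(i, 1) - counts[d])
    dataset_idx[i] = di
    dataset_sample_idx[i] = counts[di]  # per-dataset cursor: each sample is used in order, once
    counts[di] += 1
  return dataset_idx, dataset_sample_idx

dataset_idx, dataset_sample_idx = build_blend_idx([0.75, 0.25], 8)
print(dataset_idx.tolist())         # [0, 1, 0, 0, 0, 1, 0, 0] -- a 3:1 interleave
print(dataset_sample_idx.tolist())  # [0, 0, 1, 2, 3, 1, 4, 5]

With normalized weights the per-step errors sum to at least zero, so the pick is always well defined; ties break toward the lowest dataset index here, which matches the strict > comparison in the commit.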