mirror of https://github.com/tinygrad/tinygrad.git
feat: generate blend index (#11566)
@@ -564,10 +564,13 @@ class GPTDataset:
    # check for cache
    cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
    cache_path = base_path.with_name(f"{base_path.name}.{cache_hash}.index_cache")
    print(f"try loading GPTDataset from {cache_path}...")
    if cache_path.exists():
      print("cache found, loading...")
      with open(cache_path, "rb") as f:
        self.doc_idx, self.sample_idx, self.shuffle_idx = pickle.load(f)
    else:
      print("cache not found, building index...")
      self.doc_idx = self._build_doc_idx()
      self.sample_idx = self._build_sample_idx()
      self.shuffle_idx = self._build_shuffle_idx()
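
Both hunks use the same caching pattern: hash every parameter that determines the index, embed the digest in the cache file name, and pickle the index arrays to disk. Any change to the parameters changes the hash, so a stale cache is simply never found. A minimal standalone sketch of the pattern (the helper name load_or_build and its signature are illustrative, not part of the commit):

import hashlib, pickle
from pathlib import Path

def load_or_build(base_path:Path, params:tuple, suffix:str, build_fn):
  # cache key is derived from every parameter that affects the result
  cache_hash = hashlib.sha256(":".join(map(str, params)).encode()).hexdigest()
  cache_path = base_path.with_name(f"{base_path.name}.{cache_hash}.{suffix}")
  if cache_path.exists():
    with open(cache_path, "rb") as f: return pickle.load(f)
  result = build_fn()
  with open(cache_path, "wb") as f: pickle.dump(result, f)
  return result

In the diff, GPTDataset keys on (samples, seqlen, seed, shuffle) with an .index_cache suffix, and BlendedGPTDataset reuses the same key format with a .blend_cache suffix.
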
@@ -692,10 +695,47 @@ class BlendedGPTDataset:
    self.datasets = [GPTDataset(path, samples_per_blend[i], seqlen, seed + i, shuffle) for i,path in enumerate(paths)]

    # check for cache
    cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
    cache_path = paths[0].with_name(f"{paths[0].name}.{cache_hash}.blend_cache")
    print(f"try loading BlendedGPTDataset from {cache_path}...")
    if cache_path.exists():
      print("cache found, loading...")
      with open(cache_path, "rb") as f:
        self.dataset_idx, self.dataset_sample_idx = pickle.load(f)
    else:
      print("cache not found, building index...")
      self.dataset_idx, self.dataset_sample_idx = self._build_blend_idx()
      # save cache
      with open(cache_path, "wb") as f:
        pickle.dump((self.dataset_idx, self.dataset_sample_idx), f)

  def get(self, idx:int):
-   tokens = self.datasets[0][idx]
+   tokens = self.datasets[self.dataset_idx[idx]][self.dataset_sample_idx[idx]]
    return tokens
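
get() is now a two-level lookup: dataset_idx maps the global sample index to a dataset, and dataset_sample_idx gives the position within that dataset. Hypothetical values for a two-dataset blend make the mapping concrete:

# dataset_idx        = [0, 0, 1, 0]  -> which dataset serves global sample i
# dataset_sample_idx = [0, 1, 0, 2]  -> running position within that dataset
# get(2) returns datasets[1][0]; get(3) returns datasets[0][2]
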

  def _build_blend_idx(self):
    dataset_idx = np.zeros(self.samples, dtype=np.int16)
    dataset_sample_idx = np.zeros(self.samples, dtype=np.int64)

    unspent_datasets = set(range(len(self.datasets)))
    dataset_sample_counts = [0] * len(self.datasets)

    for i in tqdm(range(self.samples)):
      error_argmax, error_max = 0, 0.0
      for di in unspent_datasets:
        error = self.weights[di] * max(i, 1) - dataset_sample_counts[di]
        if error > error_max:
          error_max = error
          error_argmax = di

      dataset_idx[i] = error_argmax
      dataset_sample_idx[i] = dataset_sample_counts[error_argmax]

      dataset_sample_counts[error_argmax] += 1

    return dataset_idx, dataset_sample_idx
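
_build_blend_idx is a greedy scheduler, similar to the blending indices built in Megatron-LM's data pipeline: at each step it picks the dataset whose realized sample count lags furthest behind its target share weights[di] * i, so every prefix of the blended stream approximates the requested mixture (max(i, 1) makes the first pick go to the heaviest dataset). A self-contained sketch of the same rule, with the unspent_datasets exhaustion bookkeeping omitted, that checks the realized proportions against the weights:

import numpy as np

def build_blend_idx(weights, samples):
  dataset_idx = np.zeros(samples, dtype=np.int16)
  dataset_sample_idx = np.zeros(samples, dtype=np.int64)
  counts = [0] * len(weights)
  for i in range(samples):
    # pick the dataset furthest behind its target share weights[di] * i
    errors = [w * max(i, 1) - c for w, c in zip(weights, counts)]
    di = int(np.argmax(errors))
    dataset_idx[i] = di
    dataset_sample_idx[i] = counts[di]
    counts[di] += 1
  return dataset_idx, dataset_sample_idx

dataset_idx, _ = build_blend_idx([0.7, 0.2, 0.1], 10_000)
print(np.bincount(dataset_idx) / 10_000)  # ~[0.7, 0.2, 0.1]

Because each dataset's samples are handed out in counting order, dataset_sample_idx[i] is exactly the running count within the chosen dataset, which is what get() uses to index the underlying GPTDataset.
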

def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
  if val:
    dataset = BlendedGPTDataset([