diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py
index 121698c97b..c82c42410d 100644
--- a/examples/mlperf/dataloader.py
+++ b/examples/mlperf/dataloader.py
@@ -564,10 +564,13 @@ class GPTDataset:
     # check for cache
     cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
     cache_path = base_path.with_name(f"{base_path.name}.{cache_hash}.index_cache")
+    print(f"try loading GPTDataset from {cache_path}...")
     if cache_path.exists():
+      print("cache found, loading...")
       with open(cache_path, "rb") as f:
         self.doc_idx, self.sample_idx, self.shuffle_idx = pickle.load(f)
     else:
+      print("cache not found, building index...")
       self.doc_idx = self._build_doc_idx()
       self.sample_idx = self._build_sample_idx()
       self.shuffle_idx = self._build_shuffle_idx()
@@ -692,10 +695,47 @@ class BlendedGPTDataset:
 
     self.datasets = [GPTDataset(path, samples_per_blend[i], seqlen, seed + i, shuffle) for i,path in enumerate(paths)]
 
+    # check for cache
+    cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
+    cache_path = paths[0].with_name(f"{paths[0].name}.{cache_hash}.blend_cache")
+    print(f"try loading BlendedGPTDataset from {cache_path}...")
+    if cache_path.exists():
+      print("cache found, loading...")
+      with open(cache_path, "rb") as f:
+        self.dataset_idx, self.dataset_sample_idx = pickle.load(f)
+    else:
+      print("cache not found, building index...")
+      self.dataset_idx, self.dataset_sample_idx = self._build_blend_idx()
+      # save cache
+      with open(cache_path, "wb") as f:
+        pickle.dump((self.dataset_idx, self.dataset_sample_idx), f)
+
   def get(self, idx:int):
-    tokens = self.datasets[0][idx]
+    tokens = self.datasets[self.dataset_idx[idx]][self.dataset_sample_idx[idx]]
     return tokens
 
+  def _build_blend_idx(self):
+    dataset_idx = np.zeros(self.samples, dtype=np.int16)
+    dataset_sample_idx = np.zeros(self.samples, dtype=np.int64)
+
+    unspent_datasets = set(range(len(self.datasets)))
+    dataset_sample_counts = [0] * len(self.datasets)
+
+    for i in tqdm(range(self.samples)):
+      error_argmax, error_max = 0, 0.0
+      for di in unspent_datasets:
+        error = self.weights[di] * max(i, 1) - dataset_sample_counts[di]
+        if error > error_max:
+          error_max = error
+          error_argmax = di
+
+      dataset_idx[i] = error_argmax
+      dataset_sample_idx[i] = dataset_sample_counts[error_argmax]
+
+      dataset_sample_counts[error_argmax] += 1
+
+    return dataset_idx, dataset_sample_idx
+
 def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
   if val:
     dataset = BlendedGPTDataset([
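
For reference, below is a minimal standalone sketch of the blend-index idea implemented by `_build_blend_idx` above: at each global sample step, pick the dataset whose realized sample count lags its target share `weight * i` the most. This is not the patch's code; the function name and the plain `argmax` over all datasets are illustrative simplifications (the patch additionally tracks an `unspent_datasets` set and only accepts strictly positive errors).

```python
import numpy as np

def build_blend_idx(weights: list[float], samples: int):
  # which dataset each global sample is drawn from, and the index within that dataset
  dataset_idx = np.zeros(samples, dtype=np.int16)
  dataset_sample_idx = np.zeros(samples, dtype=np.int64)
  counts = [0] * len(weights)
  for i in range(samples):
    # pick the dataset whose realized count lags its target share (weight * i) the most
    errors = [w * max(i, 1) - c for w, c in zip(weights, counts)]
    d = int(np.argmax(errors))
    dataset_idx[i] = d
    dataset_sample_idx[i] = counts[d]
    counts[d] += 1
  return dataset_idx, dataset_sample_idx

# e.g. a 70/30 blend of two datasets over 10 samples:
di, si = build_blend_idx([0.7, 0.3], 10)
print(di.tolist())  # [0, 1, 0, 0, 1, 0, 0, 1, 0, 0] -> 7 samples from dataset 0, 3 from dataset 1
```

The point of the per-step error term is that the blend proportions stay close to the target weights over any prefix of the sample stream, not just in aggregate, which is why `get` can then simply look up `dataset_idx[idx]` and `dataset_sample_idx[idx]` instead of always reading from `datasets[0]`.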