mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(server): stratified sampling
This commit is contained in:
@@ -143,7 +143,21 @@ class DataSamplingBlock(Block):
|
||||
)
|
||||
strata = defaultdict(list)
|
||||
for i, item in enumerate(data_to_sample):
|
||||
strata[str(item[int(input_data.stratify_key)])].append(i)
|
||||
if isinstance(item, dict):
|
||||
strata_value = item.get(input_data.stratify_key)
|
||||
elif hasattr(item, input_data.stratify_key):
|
||||
strata_value = getattr(item, input_data.stratify_key)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Stratify key '{input_data.stratify_key}' not found in item {item}"
|
||||
)
|
||||
|
||||
if strata_value is None:
|
||||
raise ValueError(
|
||||
f"Stratify value for key '{input_data.stratify_key}' is None"
|
||||
)
|
||||
|
||||
strata[str(strata_value)].append(i)
|
||||
|
||||
# Calculate the number of samples to take from each stratum
|
||||
stratum_sizes = {
|
||||
|
||||
Reference in New Issue
Block a user