fix(server): stratified sampling

This commit is contained in:
Nicholas Tindle
2024-08-14 20:42:46 -05:00
parent 3c662af1ba
commit 51aaaf6ddc

View File

@@ -143,7 +143,21 @@ class DataSamplingBlock(Block):
)
strata = defaultdict(list)
for i, item in enumerate(data_to_sample):
strata[str(item[int(input_data.stratify_key)])].append(i)
if isinstance(item, dict):
strata_value = item.get(input_data.stratify_key)
elif hasattr(item, input_data.stratify_key):
strata_value = getattr(item, input_data.stratify_key)
else:
raise ValueError(
f"Stratify key '{input_data.stratify_key}' not found in item {item}"
)
if strata_value is None:
raise ValueError(
f"Stratify value for key '{input_data.stratify_key}' is None"
)
strata[str(strata_value)].append(i)
# Calculate the number of samples to take from each stratum
stratum_sizes = {