Add hash validation for downloading transcription models

This commit is contained in:
Tabby Cromarty
2025-10-19 15:49:39 +01:00
parent 9ce2321f96
commit bc9b4f0855
2 changed files with 71 additions and 0 deletions

View File

@@ -3,7 +3,18 @@
#include <obs-module.h>
#include <cstdlib>
#include <cerrno>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <string>
#include <stdexcept>
#include <openssl/sha.h>
size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream)
{
@@ -211,6 +222,10 @@ void ModelDownloadWorker::download_model()
emit download_error("Failed to download model file.");
}
fclose(fp);
if (!valid_hash(model_file_save_path, model_download_file.sha256)) {
emit download_error("Downloaded model has invalid hash");
}
}
curl_easy_cleanup(curl);
emit download_finished(model_local_config_path);
@@ -236,6 +251,60 @@ int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, cu
return 0;
}
bool ModelDownloadWorker::valid_hash(std::string path, std::string hash)
{
auto calculated_hash = sha256_sum(path.c_str());
if (hash == "") {
obs_log(LOG_WARNING, "No hash for model in config. Calculated hash: %s",
calculated_hash.c_str());
return true;
} else if (hash == calculated_hash) {
obs_log(LOG_INFO, "Model hash is valid");
return true;
} else {
obs_log(LOG_ERROR, "Model hash mismatch. Model hash: %s, calculated hash: %s",
hash.c_str(), calculated_hash.c_str());
return false;
}
}
std::string ModelDownloadWorker::sha256_sum(const char *const path)
{
std::ifstream fp(path, std::ios::in | std::ios::binary);
if (not fp.good()) {
std::ostringstream os;
os << "Cannot open \"" << path << "\": " << std::strerror(errno) << ".";
throw std::runtime_error(os.str());
}
constexpr const std::size_t buffer_size{1 << 12};
char buffer[buffer_size];
unsigned char hash[SHA256_DIGEST_LENGTH] = {0};
SHA256_CTX ctx;
SHA256_Init(&ctx);
while (fp.good()) {
fp.read(buffer, buffer_size);
SHA256_Update(&ctx, buffer, fp.gcount());
}
SHA256_Final(hash, &ctx);
fp.close();
std::ostringstream os;
os << std::hex << std::setfill('0');
for (int i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
os << std::setw(2) << static_cast<unsigned int>(hash[i]);
}
return os.str();
}
ModelDownloader::~ModelDownloader()
{
if (this->download_thread != nullptr) {

View File

@@ -28,6 +28,8 @@ signals:
private:
static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
curl_off_t ultotal, curl_off_t ulnow);
bool valid_hash(std::string path, std::string hash);
std::string sha256_sum(const char *const path);
ModelInfo model_info;
};