mirror of
https://github.com/royshil/obs-localvocal.git
synced 2026-01-10 04:48:02 -05:00
Add hash validation for downloading transcription models
This commit is contained in:
@@ -3,7 +3,18 @@
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <openssl/sha.h>
|
||||
|
||||
size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream)
|
||||
{
|
||||
@@ -211,6 +222,10 @@ void ModelDownloadWorker::download_model()
|
||||
emit download_error("Failed to download model file.");
|
||||
}
|
||||
fclose(fp);
|
||||
|
||||
if (!valid_hash(model_file_save_path, model_download_file.sha256)) {
|
||||
emit download_error("Downloaded model has invalid hash");
|
||||
}
|
||||
}
|
||||
curl_easy_cleanup(curl);
|
||||
emit download_finished(model_local_config_path);
|
||||
@@ -236,6 +251,60 @@ int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, cu
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool ModelDownloadWorker::valid_hash(std::string path, std::string hash)
|
||||
{
|
||||
auto calculated_hash = sha256_sum(path.c_str());
|
||||
|
||||
if (hash == "") {
|
||||
obs_log(LOG_WARNING, "No hash for model in config. Calculated hash: %s",
|
||||
calculated_hash.c_str());
|
||||
return true;
|
||||
} else if (hash == calculated_hash) {
|
||||
obs_log(LOG_INFO, "Model hash is valid");
|
||||
return true;
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Model hash mismatch. Model hash: %s, calculated hash: %s",
|
||||
hash.c_str(), calculated_hash.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::string ModelDownloadWorker::sha256_sum(const char *const path)
|
||||
{
|
||||
std::ifstream fp(path, std::ios::in | std::ios::binary);
|
||||
|
||||
if (not fp.good()) {
|
||||
std::ostringstream os;
|
||||
os << "Cannot open \"" << path << "\": " << std::strerror(errno) << ".";
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
|
||||
constexpr const std::size_t buffer_size{1 << 12};
|
||||
char buffer[buffer_size];
|
||||
|
||||
unsigned char hash[SHA256_DIGEST_LENGTH] = {0};
|
||||
|
||||
SHA256_CTX ctx;
|
||||
SHA256_Init(&ctx);
|
||||
|
||||
while (fp.good()) {
|
||||
fp.read(buffer, buffer_size);
|
||||
SHA256_Update(&ctx, buffer, fp.gcount());
|
||||
}
|
||||
|
||||
SHA256_Final(hash, &ctx);
|
||||
fp.close();
|
||||
|
||||
std::ostringstream os;
|
||||
os << std::hex << std::setfill('0');
|
||||
|
||||
for (int i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
|
||||
os << std::setw(2) << static_cast<unsigned int>(hash[i]);
|
||||
}
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
ModelDownloader::~ModelDownloader()
|
||||
{
|
||||
if (this->download_thread != nullptr) {
|
||||
|
||||
@@ -28,6 +28,8 @@ signals:
|
||||
private:
|
||||
static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
|
||||
curl_off_t ultotal, curl_off_t ulnow);
|
||||
bool valid_hash(std::string path, std::string hash);
|
||||
std::string sha256_sum(const char *const path);
|
||||
ModelInfo model_info;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user