Add naive encoding classifier

This commit is contained in:
Allan Odgaard
2013-10-03 18:20:54 +02:00
parent 344fac505b
commit bd978fcb7e
4 changed files with 271 additions and 0 deletions

View File

@@ -0,0 +1,100 @@
#include "encoding.h"
#include "frequencies.capnp.h"
#include <capnp/message.h>
#include <capnp/serialize-packed.h>
static uint32_t const kCapnpClassifierFormatVersion = 1;
namespace encoding
{
std::vector<std::string> classifier_t::charsets () const
{
std::vector<std::string> res;
for(auto const& pair : _charsets)
res.emplace_back(pair.first);
return res;
}
void classifier_t::load (std::string const& path)
{
int fd = open(path.c_str(), O_RDONLY|O_CLOEXEC);
if(fd != -1)
{
capnp::PackedFdMessageReader message(kj::AutoCloseFd{fd});
auto freq = message.getRoot<Frequencies>();
if(freq.getVersion() != kCapnpClassifierFormatVersion)
{
fprintf(stderr, "skip %s version %u (expected %u)\n", path.c_str(), freq.getVersion(), kCapnpClassifierFormatVersion);
return;
}
for(auto const& src : freq.getCharsets())
{
record_t r;
for(auto const& word : src.getWords())
r.words.emplace(word.getType().getWord(), word.getCount());
for(auto const& byte : src.getBytes())
r.bytes.emplace(byte.getType().getByte(), byte.getCount());
_charsets.emplace(src.getCharset(), r);
}
for(auto& pair : _charsets)
{
for(auto const& word : pair.second.words)
{
_combined.words[word.first] += word.second;
_combined.total_words += word.second;
pair.second.total_words += word.second;
}
for(auto const& byte : pair.second.bytes)
{
_combined.bytes[byte.first] += byte.second;
_combined.total_bytes += byte.second;
pair.second.total_bytes += byte.second;
}
}
}
}
void classifier_t::save (std::string const& path) const
{
capnp::MallocMessageBuilder message;
auto freq = message.initRoot<Frequencies>();
freq.setVersion(kCapnpClassifierFormatVersion);
auto charsets = freq.initCharsets(_charsets.size());
size_t i = 0;
for(auto const& pair : _charsets)
{
auto entry = charsets[i++];
entry.setCharset(pair.first);
auto words = entry.initWords(pair.second.words.size());
size_t j = 0;
for(auto const& word : pair.second.words)
{
auto tmp = words[j++];
tmp.getType().setWord(word.first);
tmp.setCount(word.second);
}
auto bytes = entry.initBytes(pair.second.bytes.size());
j = 0;
for(auto const& byte : pair.second.bytes)
{
auto tmp = bytes[j++];
tmp.getType().setByte(byte.first);
tmp.setCount(byte.second);
}
}
int fd = open(path.c_str(), O_CREAT|O_TRUNC|O_WRONLY|O_CLOEXEC, S_IRUSR|S_IWUSR|S_IRGRP|S_IWGRP|S_IROTH);
if(fd != -1)
{
writePackedMessageToFd(fd, message);
close(fd);
}
}
} /* encoding */

View File

@@ -0,0 +1,148 @@
#ifndef ENCODING_H_3OJVUZM1
#define ENCODING_H_3OJVUZM1
#include <oak/misc.h>
namespace encoding
{
struct PUBLIC classifier_t
{
void load (std::string const& path);
void save (std::string const& path) const;
template <typename _InputIter>
void learn (_InputIter const& first, _InputIter const& last, std::string const& charset)
{
auto& r = _charsets[charset];
each_word(first, last, [&](std::string const& word){
r.words[word] += 1;
r.total_words += 1;
_combined.words[word] += 1;
_combined.total_words += 1;
for(char ch : word)
{
if(ch > 0x7F)
{
r.bytes[ch] += 1;
r.total_bytes += 1;
_combined.bytes[ch] += 1;
_combined.total_bytes += 1;
}
}
});
}
template <typename _InputIter>
double probability (_InputIter const& first, _InputIter const& last, std::string const& charset) const
{
auto record = _charsets.find(charset);
if(record == _charsets.end())
return 0;
std::set<std::string> seen;
double a = 1, b = 1;
each_word(first, last, [&](std::string const& word){
auto global = _combined.words.find(word);
if(global != _combined.words.end() && seen.find(word) == seen.end())
{
auto local = record->second.words.find(word);
if(local != record->second.words.end())
{
double pWT = local->second / (double)record->second.total_words;
double pWF = (global->second - local->second) / (double)_combined.total_words;
double p = pWT / (pWT + pWF);
a *= p;
b *= 1-p;
}
else
{
a = 0;
}
seen.insert(word);
}
else
{
for(char ch : word)
{
if(ch > 0x7F)
{
auto global = _combined.bytes.find(ch);
if(global != _combined.bytes.end())
{
auto local = record->second.bytes.find(ch);
if(local != record->second.bytes.end())
{
double pWT = local->second / (double)record->second.total_bytes;
double pWF = (global->second - local->second) / (double)_combined.total_bytes;
double p = pWT / (pWT + pWF);
a *= p;
b *= 1-p;
}
else
{
a = 0;
}
}
}
}
}
});
return (a + b) == 0 ? 0 : a / (a + b);
}
std::vector<std::string> charsets () const;
bool operator== (classifier_t const& rhs) const
{
return _charsets == rhs._charsets && _combined == rhs._combined;
}
bool operator!= (classifier_t const& rhs) const
{
return !(*this == rhs);
}
private:
template <typename _InputIter, typename _F>
static void each_word (_InputIter const& first, _InputIter const& last, _F op)
{
for(auto eow = first; eow != last; )
{
auto bow = std::find_if(eow, last, [](char ch){ return isalpha(ch) || ch > 0x7F; });
eow = std::find_if(bow, last, [](char ch){ return !isalnum(ch) && ch < 0x80; });
if(std::find_if(bow, eow, [](char ch){ return ch > 0x7F; }) != eow)
op(std::string(bow, eow));
}
}
struct record_t
{
bool operator== (record_t const& rhs) const
{
return words == rhs.words && bytes == rhs.bytes && total_words == rhs.total_words && total_bytes == rhs.total_bytes;
}
bool operator!= (record_t const& rhs) const
{
return !(*this == rhs);
}
std::map<std::string, size_t> words;
std::map<char, size_t> bytes;
size_t total_words = 0;
size_t total_bytes = 0;
};
std::map<std::string, record_t> _charsets;
record_t _combined;
};
} /* encoding */
#endif /* end of include guard: ENCODING_H_3OJVUZM1 */

View File

@@ -0,0 +1,20 @@
@0xf07cdbe73cbefea0;
struct Charset {
charset @0 :Text;
words @1 :List(Pair);
bytes @2 :List(Pair);
struct Pair {
type :union {
word @0 :Text;
byte @1 :UInt8;
}
count @2 :UInt64;
}
}
struct Frequencies {
version @0 :UInt32 = 1;
charsets @1 :List(Charset);
}