fix potential bug reading model data into a small size optimized string which could lead to memory corruption. In an SSO string, you can't write data to &str[0] and expect it to work well.

Also added a small wrapper function to more safely read model data without having to get the sizeof right. I tested this on the tiny, base and large models; there was no change in behaviour.
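As a rough illustration of the two patterns the commit adopts (a read_safe wrapper that deduces the byte count from the destination type, and reading length-prefixed strings through a temporary char buffer instead of writing into &str[0]), here is a minimal standalone sketch. The file name model.bin, main() and the vocabulary loop are invented for illustration; only the read_safe template and the buffer-based string read mirror the diff below.

#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Reads exactly sizeof(T) bytes into dest; the size is deduced from the
// destination type, so callers can no longer pass a mismatched sizeof.
template<typename T>
static void read_safe(std::ifstream & fin, T & dest) {
    fin.read(reinterpret_cast<char *>(&dest), sizeof(T));
}

int main() {
    // "model.bin" is a hypothetical file name used only for this sketch.
    std::ifstream fin("model.bin", std::ios::binary);
    if (!fin) {
        return 1;
    }

    int32_t n_vocab = 0;
    read_safe(fin, n_vocab); // replaces fin.read((char *) &n_vocab, sizeof(n_vocab));

    for (int32_t i = 0; i < n_vocab; i++) {
        uint32_t len = 0;
        read_safe(fin, len);

        // Read the length-prefixed word into a plain char buffer first,
        // then copy it into the string, instead of fin.read(&word[0], len).
        std::vector<char> tmp(len);
        fin.read(tmp.data(), tmp.size());
        std::string word(tmp.data(), tmp.size());
        (void)word; // a real loader would store the word in its vocab maps
    }
    return 0;
}

Because sizeof(T) comes from the destination variable itself, a copy-pasted read call cannot silently pair a field with the wrong size, and the temporary buffer keeps the stream write away from the string's internal (possibly SSO) storage.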
berthubert authored and ggerganov committed Dec 10, 2022
1 parent 603f97b commit d1da35d
Showing 1 changed file with 33 additions and 24 deletions.
whisper.cpp
@@ -429,6 +429,12 @@ struct whisper_context {
int32_t exp_n_audio_ctx; // 0 - use default
};

template<typename T>
static void read_safe(std::ifstream& fin, T& dest)
{
fin.read((char*)& dest, sizeof(T));
}

// load the model from a ggml file
//
// file format:
@@ -455,7 +461,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
read_safe(fin, magic);
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
@@ -466,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
auto & hparams = model.hparams;

fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
read_safe(fin, hparams.n_vocab);
read_safe(fin, hparams.n_audio_ctx);
read_safe(fin, hparams.n_audio_state);
read_safe(fin, hparams.n_audio_head);
read_safe(fin, hparams.n_audio_layer);
read_safe(fin, hparams.n_text_ctx);
read_safe(fin, hparams.n_text_state);
read_safe(fin, hparams.n_text_head);
read_safe(fin, hparams.n_text_layer);
read_safe(fin, hparams.n_mels);
read_safe(fin, hparams.f16);

assert(hparams.n_text_state == hparams.n_audio_state);

@@ -524,8 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
auto & filters = wctx.model.filters;

fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
read_safe(fin, filters.n_mel);
read_safe(fin, filters.n_fft);

filters.data.resize(filters.n_mel * filters.n_fft);
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@@ -534,7 +540,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
read_safe(fin, n_vocab);

//if (n_vocab != model.hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -545,10 +551,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
read_safe(fin, len);

word.resize(len);
fin.read((char *) word.data(), len);
std::vector<char> tmp(len); // create a buffer
fin.read( &tmp[0], tmp.size() ); // read to buffer
word.assign(&tmp[0], tmp.size());

vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
@@ -998,9 +1005,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t length;
int32_t ftype;

fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
read_safe(fin, n_dims);
read_safe(fin, length);
read_safe(fin, ftype);

if (fin.eof()) {
break;
@@ -1009,12 +1016,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t nelements = 1;
int32_t ne[3] = { 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
read_safe(fin, ne[i]);
nelements *= ne[i];
}

std::string name(length, 0);
fin.read(&name[0], length);
std::string name;
std::vector<char> tmp(length); // create a buffer
fin.read( &tmp[0], tmp.size() ); // read to buffer
name.assign(&tmp[0], tmp.size());

if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
