@@ -136,13 +136,15 @@ void istft_impl(const float* in_data,
136
136
OPENVINO_ASSERT (fft_results_dim == static_cast <size_t >((frame_size / 2 ) + 1 ));
137
137
138
138
const auto frame_size_dim = static_cast <size_t >(frame_size);
139
- const auto frame_size_dim_shape = ov::Shape{frame_size_dim};
140
- const auto frame_size_dim_shape_out = ov::Shape{frame_size_dim, 2 };
141
139
const auto fft_out_shape = ov::Shape{fft_results_dim, 2 };
142
140
143
141
const auto window_length = window_shape[0 ] < frame_size_dim ? window_shape[0 ] : frame_size_dim;
144
142
std::vector<float > pad_window (frame_size, 0 );
145
143
std::copy (window, window + window_shape[0 ], pad_window.begin () + (frame_size_dim - window_length) / 2 );
144
+ std::vector<float > pow_window (frame_size, 0 );
145
+ std::transform (pad_window.begin (), pad_window.end (), pow_window.begin (), [](float win_val) {
146
+ return win_val * win_val;
147
+ });
146
148
147
149
std::vector<float > data_t (shape_size (data_shape));
148
150
const auto stft_transp_out_shape = ov::Shape{batch_size, num_frames, fft_out_shape[0 ], fft_out_shape[1 ]};
@@ -187,9 +189,7 @@ void istft_impl(const float* in_data,
187
189
size_t batch_out_start = batch * signal_length;
188
190
189
191
const auto in_frame_start = batch_in_start + frame_idx * fft_out_shape_size;
190
-
191
192
const auto out_frame_start = batch_out_start + frame_idx * frame_step;
192
- const auto out_frame_end = out_frame_start + frame_size;
193
193
194
194
std::vector<float > frame_signal (frame_size);
195
195
rdft_executor->execute (data_t .data () + in_frame_start,
@@ -203,25 +203,20 @@ void istft_impl(const float* in_data,
203
203
{1 },
204
204
{1 });
205
205
206
- std::transform (frame_signal.begin (),
207
- frame_signal.end (),
208
- mid_result.begin () + out_frame_start,
209
- mid_result.begin () + out_frame_start,
210
- std::plus<>());
211
-
212
- std::transform (window_sum.begin () + out_frame_start,
213
- window_sum.begin () + out_frame_end,
214
- pad_window.begin (),
215
- window_sum.begin () + out_frame_start,
216
- std::plus<>());
206
+ // Overlap Add
207
+ float * mid_result_sum = mid_result.data () + out_frame_start;
208
+ float * window_frame_sum = window_sum.data () + out_frame_start;
209
+ for (size_t i = 0 ; i < frame_signal.size (); ++i) {
210
+ mid_result_sum[i] += frame_signal[i] * pad_window[i];
211
+ window_frame_sum[i] += pow_window[i];
212
+ }
217
213
}
218
214
float * result = mid_result.data () + (batch * signal_length);
219
215
std::transform (result,
220
216
result + signal_length,
221
217
window_sum.begin () + batch * signal_length,
222
218
result,
223
219
postprocess_func);
224
-
225
220
const auto result_start = result + margin;
226
221
std::copy (result_start, result_start + copy_end, final_result + batch * final_signal_length);
227
222
});
0 commit comments