diff --git a/sgdml/__init__.py b/sgdml/__init__.py index 4c2e3d1..9eb651f 100644 --- a/sgdml/__init__.py +++ b/sgdml/__init__.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -__version__ = '0.5.3' +__version__ = '0.5.4' MAX_PRINT_WIDTH = 100 LOG_LEVELNAME_WIDTH = 7 # do not modify diff --git a/sgdml/predict.py b/sgdml/predict.py index 4768226..ebf4220 100644 --- a/sgdml/predict.py +++ b/sgdml/predict.py @@ -1278,11 +1278,10 @@ def predict(self, R=None, return_E=True): _predict_wo_wkr_starts_stops, self.wkr_starts_stops ) ) - - if R is not None: # Not in train mode. TODO: better set y_std to zero - E_F *= self.std - F = E_F[:, 1:] - E = E_F[:, 0] + self.c + + E_F *= self.std + F = E_F[:, 1:] + E = E_F[:, 0] + self.c ret = (F,) if return_E: