From a92254322c803a75c37462bac78b4f4d766fd1fd Mon Sep 17 00:00:00 2001 From: Vlad Niculae Date: Thu, 22 Sep 2016 17:31:14 -0400 Subject: [PATCH] fix possible bug in axes normalization --- seqlearn/hmm.py | 4 ++-- seqlearn/tests/test_hmm.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/seqlearn/hmm.py b/seqlearn/hmm.py index 59df882..4de0c57 100644 --- a/seqlearn/hmm.py +++ b/seqlearn/hmm.py @@ -68,10 +68,10 @@ def fit(self, X, y, lengths): final_prob -= logsumexp(final_prob) feature_prob = np.log(safe_sparse_dot(Y.T, X) + alpha) - feature_prob -= logsumexp(feature_prob, axis=0) + feature_prob -= logsumexp(feature_prob, axis=1)[:, np.newaxis] trans_prob = np.log(count_trans(y, len(classes)) + alpha) - trans_prob -= logsumexp(trans_prob, axis=0) + trans_prob -= logsumexp(trans_prob, axis=1)[:, np.newaxis] self.coef_ = feature_prob self.intercept_init_ = init_prob diff --git a/seqlearn/tests/test_hmm.py b/seqlearn/tests/test_hmm.py index dfec148..825814e 100644 --- a/seqlearn/tests/test_hmm.py +++ b/seqlearn/tests/test_hmm.py @@ -34,10 +34,10 @@ def test_hmm(): assert_array_equal(clf.predict(X), y) n_classes = len(clf.classes_) - assert_array_almost_equal(np.ones(n_features), - np.exp(clf.coef_).sum(axis=0)) assert_array_almost_equal(np.ones(n_classes), - np.exp(clf.intercept_trans_).sum(axis=0)) + np.exp(clf.coef_).sum(axis=1)) + assert_array_almost_equal(np.ones(n_classes), + np.exp(clf.intercept_trans_).sum(axis=1)) assert_array_almost_equal(1., np.exp(clf.intercept_final_).sum()) assert_array_almost_equal(1., np.exp(clf.intercept_init_).sum())