-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b9d849b
commit b3d67a5
Showing
13 changed files
with
23,113 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import copy | ||
|
||
import six | ||
|
||
from chainer import configuration | ||
from chainer.dataset import convert | ||
from chainer.dataset import iterator as iterator_module | ||
from chainer import functions as F | ||
from chainer import function | ||
from chainer import link | ||
from chainer import reporter as reporter_module | ||
from chainer.training import extensions | ||
from chainer.training import extension | ||
from chainer import cuda | ||
import numpy as np | ||
import scipy.sparse as sp | ||
import pdb | ||
|
||
class MyEvaluator(extensions.Evaluator): | ||
|
||
trigger = 1, 'epoch' | ||
default_name = 'validation' | ||
priority = extension.PRIORITY_WRITER | ||
|
||
name = None | ||
|
||
def __init__(self, iterator, target, class_dim, converter=convert.concat_examples, | ||
device=None, eval_hook=None, eval_func=None): | ||
if isinstance(iterator, iterator_module.Iterator): | ||
iterator = {'main': iterator} | ||
self._iterators = iterator | ||
|
||
if isinstance(target, link.Link): | ||
target = {'main': target} | ||
self._targets = target | ||
|
||
self.converter = converter | ||
self.device = device | ||
self.eval_hook = eval_hook | ||
self.eval_func = eval_func | ||
self.class_dim = class_dim | ||
|
||
def evaluate(self): | ||
|
||
iterator = self._iterators['main'] | ||
eval_func = self.eval_func or self._targets['main'] | ||
|
||
if self.eval_hook: | ||
self.eval_hook(self) | ||
|
||
if hasattr(iterator, 'reset'): | ||
iterator.reset() | ||
it = iterator | ||
else: | ||
it = copy.copy(iterator) | ||
|
||
summary = reporter_module.DictSummary() | ||
|
||
for batch in it: | ||
observation = {} | ||
with reporter_module.report_scope(observation): | ||
row_idx, col_idx, val_idx = [], [], [] | ||
x = cuda.to_gpu(np.array([i[0] for i in batch])) | ||
labels = [l[1] for l in batch] | ||
for i in range(len(labels)): | ||
l_list = list(set(labels[i])) # remove duplicate cateories to avoid double count | ||
for y in l_list: | ||
row_idx.append(i) | ||
col_idx.append(y) | ||
val_idx.append(1) | ||
m = len(labels) | ||
n = self.class_dim | ||
t = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n), dtype=np.int8).todense() | ||
t = cuda.to_gpu(t) | ||
|
||
with function.no_backprop_mode(): | ||
#pdb.set_trace() | ||
loss = F.sigmoid_cross_entropy(eval_func(x), t) | ||
summary.add({MyEvaluator.default_name + '/main/loss':loss}) | ||
summary.add(observation) | ||
|
||
return summary.compute_mean() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import six | ||
import numpy as np | ||
import chainer | ||
import chainer.functions as F | ||
import chainer.links as L | ||
from chainer import cuda, training, reporter | ||
from chainer.datasets import get_mnist | ||
from chainer.training import trainer, extensions | ||
from chainer.dataset import convert | ||
from chainer.dataset import iterator as iterator_module | ||
from chainer.datasets import get_mnist | ||
from chainer import optimizer as optimizer_module | ||
import scipy.sparse as sp | ||
import pdb | ||
|
||
class MyUpdater(training.StandardUpdater): | ||
def __init__(self, iterator, optimizer, class_dim, converter=convert.concat_examples, | ||
device=None, loss_func=None): | ||
if isinstance(iterator, iterator_module.Iterator): | ||
iterator = {'main': iterator} | ||
self._iterators = iterator | ||
|
||
if not isinstance(optimizer, dict): | ||
optimizer = {'main': optimizer} | ||
self._optimizers = optimizer | ||
|
||
if device is not None and device >= 0: | ||
for optimizer in six.itervalues(self._optimizers): | ||
optimizer.target.to_gpu(device) | ||
|
||
self.converter = converter | ||
self.loss_func = loss_func | ||
self.device = device | ||
self.iteration = 0 | ||
self.class_dim = class_dim | ||
|
||
def update_core(self): | ||
batch = self._iterators['main'].next() | ||
|
||
x = chainer.cuda.to_gpu(np.array([i[0] for i in batch])) | ||
labels = [l[1] for l in batch] | ||
row_idx, col_idx, val_idx = [], [], [] | ||
for i in range(len(labels)): | ||
l_list = list(set(labels[i])) # remove duplicate cateories to avoid double count | ||
for y in l_list: | ||
row_idx.append(i) | ||
col_idx.append(y) | ||
val_idx.append(1) | ||
m = len(labels) | ||
n = self.class_dim | ||
t = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n), dtype=np.int8).todense() | ||
|
||
|
||
t = chainer.cuda.to_gpu(t) | ||
|
||
optimizer = self._optimizers['main'] | ||
optimizer.target.cleargrads() | ||
loss = F.sigmoid_cross_entropy(optimizer.target(x), t) | ||
chainer.reporter.report({'main/loss':loss}) | ||
loss.backward() | ||
optimizer.update() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,112 @@ | ||
# HFT-CNN | ||
to be prepared. | ||
HFT-CNN | ||
== | ||
このコードでは次の4種類のCNNモデルを用いた文書分類ができます: | ||
* Flat モデル : 階層構造を利用せずに学習 | ||
* Without Fine-tuning (WoFt) モデル : 階層構造を利用するがFine-tuningは利用せずに学習 | ||
* Hierarchical Fine-Tuning (HFT) モデル : 階層構造とFine-tuningを利用して学習 | ||
* XML-CNN モデル ([Liu+ '17](http://nyc.lti.cs.cmu.edu/yiming/Publications/jliu-sigir17.pdf)) : Liuら'17 の提案したモデル | ||
|
||
このコードを用いる際には次の論文をご参照ください: | ||
|
||
**HFT-CNN: Learning Hierarchical Category Structure for Multi-label Short Text Categorization** Kazuya Shimura, Jiyi Li and Fumiyo Fukumoto. EMNLP, 2018. | ||
|
||
|
||
### 各モデルの特徴 | ||
|
||
| 特徴\手法 | Flatモデル | WoFtモデル | HFTモデル | XML-CNNモデル | | ||
|-----------------------:|:-------------:|:-------------:|:-------------:|:-------------------:| | ||
| Hierarchycal Structure | | ✔ | ✔ | | | ||
| Fine-tuning | | ✔ | ✔ | | | ||
| Pooling Type | 1-max pooling | 1-max pooling | 1-max pooling | dynamic max pooling | | ||
| Compact Representation | | | | ✔ | | ||
|
||
## Requirements | ||
このコードを実行するために必要なライブラリのうち、代表的なものを次に示します。 | ||
* Python 3.5.4 以降 | ||
* Chainer 4.0.0 以降 ([chainer](http://chainer.org/)) | ||
* CuPy 4.0.0 以降 ([cupy](https://cupy.chainer.org/)) | ||
|
||
注意: | ||
* 現在のコードのバージョンでは**GPU**を利用することが前提となっています。 | ||
* コードを実行するために必要なライブラリの詳細はrequirements.txtをご参照ください。 | ||
|
||
## Installation | ||
* このページの **clone or download** からコードをダウンロード | ||
* requirements.txtに書かれたライブラリをインストールし、実行環境を構築 | ||
* もし必要であれば、次の手順でAnaconda([anaconda](https://www.anaconda.com/enterprise/))による仮想環境を構築 | ||
1. [Anacondaのダウンロードページ](https://www.anaconda.com/download/)から自分の環境にあったものをインストール | ||
* 例: Linux(x86アーキテクチャ, 64bit)にインストールする場合: | ||
1. wget https://repo.continuum.io/archive/Anaconda3-5.1.0-Linux-x86_64.sh | ||
1. bash Anaconda3-5.1.0-Linux-x86_64.sh | ||
|
||
でインストールできます。 | ||
3. Anacondaをインストール後、仮想環境を構築 | ||
```conda env create -f=hft_cnn_env.yml``` | ||
4. ```source activate hft_cnn_env``` で仮想環境に切り替え | ||
5. この環境内でHFT-CNNのコードを実行することが可能 | ||
|
||
|
||
## Quick-start | ||
exmaple.shを実行することでFlatモデルを用いたサンプル文書(Amazon商品レビュー)の自動分類を試すことができます: | ||
``` | ||
bash example.sh | ||
-------------------------------------------------- | ||
Loading data... | ||
Loading train data: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 465927/465927 [00:18<00:00, 24959.42it/s] | ||
Loading valid data: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24522/24522 [00:00<00:00, 27551.44it/s] | ||
Loading test data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 153025/153025 [00:05<00:00, 27051.62it/s] | ||
-------------------------------------------------- | ||
Loading Word embedings... | ||
``` | ||
学習後の結果はCNNディレクトリに保存されます. | ||
* RESULT : テストデータを分類した結果 | ||
* PARAMS : 学習後のCNNのパラメータ | ||
* LOG : 学習のログファイル | ||
|
||
### 学習モデルの変更 | ||
```example.sh```内の ```ModelType``` を変更することで学習するモデルを変更することができます | ||
``` | ||
## Network Type (XML-CNN, CNN-Flat, CNN-Hierarchy, CNN-fine-tuning or Pre-process) | ||
ModelType=XML-CNN | ||
``` | ||
注意: | ||
* CNN-Hierarchy, CNN-fine-tuningを選択する場合には**Pre-process**で学習をしてから学習を行ってください | ||
* Pre-processでは階層構造の第1階層目のみを学習し、CNNのパラメータを保存します | ||
* このときに保存されたパラメータはCNN-Hierarchy, CNN-fine-tuningの両タイプで共有されます | ||
|
||
### 単語の分散表現について | ||
このコードでは単語の分散表現に[fastText](https://github.com/facebookresearch/fastText)の学習結果を利用しています. | ||
|
||
```example.sh```内の```EmbeddingWeightsPath```に単語埋め込み層の初期値として利用したいfastTextの```bin```ファイルを指定することができます。 | ||
|
||
fastTextの```bin```ファイルを用意していない場合、英語Wikipediaコーパスを用いた単語の分散表現が[chakin](https://github.com/chakki-works/chakin)を用いて自動的にダウンロードされます。 | ||
|
||
コードに手を加えず```example.sh```を実行した場合にはWord_embeddingディレクトリに```wiki.en.vec```がダウンロードされ、これが利用されます。 | ||
``` | ||
## Embedding Weights Type (fastText .bin and .vec) | ||
EmbeddingWeightsPath=./Word_embedding/ | ||
``` | ||
|
||
|
||
|
||
|
||
## 新しいデータでモデルを学習 | ||
### データについて | ||
#### 種類 | ||
必要な文書データは3種類です: | ||
* 訓練データ : CNNを学習させるために必要なデータ | ||
* 評価データ: CNNの汎化性能を検証するために必要なデータ | ||
* テストデータ : CNNを用いて分類したいデータ | ||
|
||
評価データは各エポックごとにCNNの汎化誤差を評価する際に用いられ、学習の継続によって過学習が起きた場合にEarly Stoppingを行います. また保存されるCNNのパラメータは汎化誤差が最も小さい時のエポックのものが保存されます. | ||
|
||
#### 形式 | ||
|
||
### 文書データが階層構造を有する場合 | ||
|
||
## License | ||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.