Skip to content

Commit

Permalink
add files
Browse files Browse the repository at this point in the history
  • Loading branch information
ShimShim46 committed Sep 21, 2018
1 parent b9d849b commit b3d67a5
Show file tree
Hide file tree
Showing 13 changed files with 23,113 additions and 2 deletions.
82 changes: 82 additions & 0 deletions MyEvaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import copy

import six

from chainer import configuration
from chainer.dataset import convert
from chainer.dataset import iterator as iterator_module
from chainer import functions as F
from chainer import function
from chainer import link
from chainer import reporter as reporter_module
from chainer.training import extensions
from chainer.training import extension
from chainer import cuda
import numpy as np
import scipy.sparse as sp
import pdb

class MyEvaluator(extensions.Evaluator):

trigger = 1, 'epoch'
default_name = 'validation'
priority = extension.PRIORITY_WRITER

name = None

def __init__(self, iterator, target, class_dim, converter=convert.concat_examples,
device=None, eval_hook=None, eval_func=None):
if isinstance(iterator, iterator_module.Iterator):
iterator = {'main': iterator}
self._iterators = iterator

if isinstance(target, link.Link):
target = {'main': target}
self._targets = target

self.converter = converter
self.device = device
self.eval_hook = eval_hook
self.eval_func = eval_func
self.class_dim = class_dim

def evaluate(self):

iterator = self._iterators['main']
eval_func = self.eval_func or self._targets['main']

if self.eval_hook:
self.eval_hook(self)

if hasattr(iterator, 'reset'):
iterator.reset()
it = iterator
else:
it = copy.copy(iterator)

summary = reporter_module.DictSummary()

for batch in it:
observation = {}
with reporter_module.report_scope(observation):
row_idx, col_idx, val_idx = [], [], []
x = cuda.to_gpu(np.array([i[0] for i in batch]))
labels = [l[1] for l in batch]
for i in range(len(labels)):
l_list = list(set(labels[i])) # remove duplicate cateories to avoid double count
for y in l_list:
row_idx.append(i)
col_idx.append(y)
val_idx.append(1)
m = len(labels)
n = self.class_dim
t = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n), dtype=np.int8).todense()
t = cuda.to_gpu(t)

with function.no_backprop_mode():
#pdb.set_trace()
loss = F.sigmoid_cross_entropy(eval_func(x), t)
summary.add({MyEvaluator.default_name + '/main/loss':loss})
summary.add(observation)

return summary.compute_mean()
62 changes: 62 additions & 0 deletions MyUpdater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import six
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda, training, reporter
from chainer.datasets import get_mnist
from chainer.training import trainer, extensions
from chainer.dataset import convert
from chainer.dataset import iterator as iterator_module
from chainer.datasets import get_mnist
from chainer import optimizer as optimizer_module
import scipy.sparse as sp
import pdb

class MyUpdater(training.StandardUpdater):
def __init__(self, iterator, optimizer, class_dim, converter=convert.concat_examples,
device=None, loss_func=None):
if isinstance(iterator, iterator_module.Iterator):
iterator = {'main': iterator}
self._iterators = iterator

if not isinstance(optimizer, dict):
optimizer = {'main': optimizer}
self._optimizers = optimizer

if device is not None and device >= 0:
for optimizer in six.itervalues(self._optimizers):
optimizer.target.to_gpu(device)

self.converter = converter
self.loss_func = loss_func
self.device = device
self.iteration = 0
self.class_dim = class_dim

def update_core(self):
batch = self._iterators['main'].next()

x = chainer.cuda.to_gpu(np.array([i[0] for i in batch]))
labels = [l[1] for l in batch]
row_idx, col_idx, val_idx = [], [], []
for i in range(len(labels)):
l_list = list(set(labels[i])) # remove duplicate cateories to avoid double count
for y in l_list:
row_idx.append(i)
col_idx.append(y)
val_idx.append(1)
m = len(labels)
n = self.class_dim
t = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n), dtype=np.int8).todense()


t = chainer.cuda.to_gpu(t)

optimizer = self._optimizers['main']
optimizer.target.cleargrads()
loss = F.sigmoid_cross_entropy(optimizer.target(x), t)
chainer.reporter.report({'main/loss':loss})
loss.backward()
optimizer.update()

114 changes: 112 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,112 @@
# HFT-CNN
to be prepared.
HFT-CNN
==
このコードでは次の4種類のCNNモデルを用いた文書分類ができます:
* Flat モデル : 階層構造を利用せずに学習
* Without Fine-tuning (WoFt) モデル : 階層構造を利用するがFine-tuningは利用せずに学習
* Hierarchical Fine-Tuning (HFT) モデル : 階層構造とFine-tuningを利用して学習
* XML-CNN モデル ([Liu+ '17](http://nyc.lti.cs.cmu.edu/yiming/Publications/jliu-sigir17.pdf)) : Liuら'17 の提案したモデル

このコードを用いる際には次の論文をご参照ください:

**HFT-CNN: Learning Hierarchical Category Structure for Multi-label Short Text Categorization** Kazuya Shimura, Jiyi Li and Fumiyo Fukumoto. EMNLP, 2018.


### 各モデルの特徴

| 特徴\手法 | Flatモデル | WoFtモデル | HFTモデル | XML-CNNモデル |
|-----------------------:|:-------------:|:-------------:|:-------------:|:-------------------:|
| Hierarchycal Structure | ||| |
| Fine-tuning | ||| |
| Pooling Type | 1-max pooling | 1-max pooling | 1-max pooling | dynamic max pooling |
| Compact Representation | | | ||

## Requirements
このコードを実行するために必要なライブラリのうち、代表的なものを次に示します。
* Python 3.5.4 以降
* Chainer 4.0.0 以降 ([chainer](http://chainer.org/))
* CuPy 4.0.0 以降 ([cupy](https://cupy.chainer.org/))

注意:
* 現在のコードのバージョンでは**GPU**を利用することが前提となっています。
* コードを実行するために必要なライブラリの詳細はrequirements.txtをご参照ください。

## Installation
* このページの **clone or download** からコードをダウンロード
* requirements.txtに書かれたライブラリをインストールし、実行環境を構築
* もし必要であれば、次の手順でAnaconda([anaconda](https://www.anaconda.com/enterprise/))による仮想環境を構築
1. [Anacondaのダウンロードページ](https://www.anaconda.com/download/)から自分の環境にあったものをインストール
* 例: Linux(x86アーキテクチャ, 64bit)にインストールする場合:
1. wget https://repo.continuum.io/archive/Anaconda3-5.1.0-Linux-x86_64.sh
1. bash Anaconda3-5.1.0-Linux-x86_64.sh

でインストールできます。
3. Anacondaをインストール後、仮想環境を構築
```conda env create -f=hft_cnn_env.yml```
4. ```source activate hft_cnn_env``` で仮想環境に切り替え
5. この環境内でHFT-CNNのコードを実行することが可能


## Quick-start
exmaple.shを実行することでFlatモデルを用いたサンプル文書(Amazon商品レビュー)の自動分類を試すことができます:
```
bash example.sh
--------------------------------------------------
Loading data...
Loading train data: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 465927/465927 [00:18<00:00, 24959.42it/s]
Loading valid data: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24522/24522 [00:00<00:00, 27551.44it/s]
Loading test data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 153025/153025 [00:05<00:00, 27051.62it/s]
--------------------------------------------------
Loading Word embedings...
```
学習後の結果はCNNディレクトリに保存されます.
* RESULT : テストデータを分類した結果
* PARAMS : 学習後のCNNのパラメータ
* LOG : 学習のログファイル

### 学習モデルの変更
```example.sh```内の ```ModelType``` を変更することで学習するモデルを変更することができます
```
## Network Type (XML-CNN, CNN-Flat, CNN-Hierarchy, CNN-fine-tuning or Pre-process)
ModelType=XML-CNN
```
注意:
* CNN-Hierarchy, CNN-fine-tuningを選択する場合には**Pre-process**で学習をしてから学習を行ってください
* Pre-processでは階層構造の第1階層目のみを学習し、CNNのパラメータを保存します
* このときに保存されたパラメータはCNN-Hierarchy, CNN-fine-tuningの両タイプで共有されます

### 単語の分散表現について
このコードでは単語の分散表現に[fastText](https://github.com/facebookresearch/fastText)の学習結果を利用しています.

```example.sh```内の```EmbeddingWeightsPath```に単語埋め込み層の初期値として利用したいfastTextの```bin```ファイルを指定することができます。

fastTextの```bin```ファイルを用意していない場合、英語Wikipediaコーパスを用いた単語の分散表現が[chakin](https://github.com/chakki-works/chakin)を用いて自動的にダウンロードされます。

コードに手を加えず```example.sh```を実行した場合にはWord_embeddingディレクトリに```wiki.en.vec```がダウンロードされ、これが利用されます。
```
## Embedding Weights Type (fastText .bin and .vec)
EmbeddingWeightsPath=./Word_embedding/
```




## 新しいデータでモデルを学習
### データについて
#### 種類
必要な文書データは3種類です:
* 訓練データ : CNNを学習させるために必要なデータ
* 評価データ: CNNの汎化性能を検証するために必要なデータ
* テストデータ : CNNを用いて分類したいデータ

評価データは各エポックごとにCNNの汎化誤差を評価する際に用いられ、学習の継続によって過学習が起きた場合にEarly Stoppingを行います. また保存されるCNNのパラメータは汎化誤差が最も小さい時のエポックのものが保存されます.

#### 形式

### 文書データが階層構造を有する場合

## License





Loading

0 comments on commit b3d67a5

Please sign in to comment.