Skip to content

Commit 8a60a2e

Browse files
authored
Merge pull request RasaHQ#3319 from RasaHQ/cli-arguments-train-core
rasa train core: use additional training arguments
2 parents 747683b + 0774091 commit 8a60a2e

File tree

4 files changed

+51
-9
lines changed

4 files changed

+51
-9
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ install:
1919
- python -m pip install -U pip
2020
- pip install git+https://github.com/tmbo/MITIE.git
2121
- pip install -r requirements-dev.txt
22-
- pip install -e .
22+
- pip install -e . --no-use-pep517
2323
- pip install coveralls==1.3.0
2424
- pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.0.0/en_core_web_md-2.0.0.tar.gz --no-cache-dir > jnk
2525
- python -m spacy link en_core_web_md en

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,4 @@ Fixed
5555
Store
5656
- ``rasa nlu test`` doesn't error anymore when a test file is passed with ``-u``
5757
- in interactive learning: only updates entity values if user changes annotation
58+
- ``rasa train core`` actually uses additional arguments, such as `augmentation`

rasa/cli/train.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,14 @@ def train(args: argparse.Namespace) -> Optional[Text]:
118118
get_validated_path(f, "data", DEFAULT_DATA_PATH) for f in args.data
119119
]
120120

121-
return rasa.train(domain, config, training_files, args.out, args.force)
121+
return rasa.train(
122+
domain,
123+
config,
124+
training_files,
125+
args.out,
126+
args.force,
127+
extract_additional_arguments(args),
128+
)
122129

123130

124131
def train_core(
@@ -143,7 +150,14 @@ def train_core(
143150

144151
config = get_validated_path(args.config, "config", DEFAULT_CONFIG_PATH)
145152

146-
return train_core(args.domain, config, stories, output, train_path)
153+
return train_core(
154+
args.domain,
155+
config,
156+
stories,
157+
output,
158+
train_path,
159+
extract_additional_arguments(args),
160+
)
147161
else:
148162
from rasa.core.train import do_compare_training
149163

@@ -162,3 +176,13 @@ def train_nlu(
162176
nlu_data = get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH)
163177

164178
return train_nlu(config, nlu_data, output, train_path)
179+
180+
181+
def extract_additional_arguments(args: argparse.Namespace) -> typing.Dict:
182+
return {
183+
"augmentation_factor": args.augmentation,
184+
"dump_stories": args.dump_stories,
185+
"debug_plots": args.debug_plots,
186+
"percentages": args.percentages,
187+
"runs": args.runs,
188+
}

rasa/train.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import tempfile
44
import typing
5-
from typing import Text, Optional, List, Union
5+
from typing import Text, Optional, List, Union, Dict
66

77
from rasa import model, data
88
from rasa.cli.utils import create_output_path, print_success
@@ -18,10 +18,11 @@ def train(
1818
training_files: Union[Text, List[Text]],
1919
output: Text = DEFAULT_MODELS_PATH,
2020
force_training: bool = False,
21+
kwargs: Optional[Dict] = None,
2122
) -> Optional[Text]:
2223
loop = asyncio.get_event_loop()
2324
return loop.run_until_complete(
24-
train_async(domain, config, training_files, output, force_training)
25+
train_async(domain, config, training_files, output, force_training, kwargs)
2526
)
2627

2728

@@ -31,6 +32,7 @@ async def train_async(
3132
training_files: Union[Text, List[Text]],
3233
output: Text = DEFAULT_MODELS_PATH,
3334
force_training: bool = False,
35+
kwargs: Optional[Dict] = None,
3436
) -> Optional[Text]:
3537
"""Trains a Rasa model (Core and NLU).
3638
@@ -40,6 +42,7 @@ async def train_async(
4042
training_files: Paths to the training data for Core and NLU.
4143
output: Output path.
4244
force_training: If `True` retrain model even if data has not changed.
45+
kwargs: Additional training parameters.
4346
4447
Returns:
4548
Path of the trained model archive.
@@ -69,7 +72,9 @@ async def train_async(
6972
retrain_nlu = not model.merge_model(old_nlu, target_path)
7073

7174
if force_training or retrain_core:
72-
await train_core_async(domain, config, story_directory, output, train_path)
75+
await train_core_async(
76+
domain, config, story_directory, output, train_path, kwargs
77+
)
7378
else:
7479
print (
7580
"Dialogue data / configuration did not change. "
@@ -100,16 +105,26 @@ async def train_async(
100105

101106

102107
def train_core(
103-
domain: Text, config: Text, stories: Text, output: Text, train_path: Optional[Text]
108+
domain: Text,
109+
config: Text,
110+
stories: Text,
111+
output: Text,
112+
train_path: Optional[Text],
113+
kwargs: Optional[Dict],
104114
) -> Optional[Text]:
105115
loop = asyncio.get_event_loop()
106116
return loop.run_until_complete(
107-
train_core_async(domain, config, stories, output, train_path)
117+
train_core_async(domain, config, stories, output, train_path, kwargs)
108118
)
109119

110120

111121
async def train_core_async(
112-
domain: Text, config: Text, stories: Text, output: Text, train_path: Optional[Text]
122+
domain: Text,
123+
config: Text,
124+
stories: Text,
125+
output: Text,
126+
train_path: Optional[Text] = None,
127+
kwargs: Optional[Dict] = None,
113128
) -> Optional[Text]:
114129
"""Trains a Core model.
115130
@@ -120,6 +135,7 @@ async def train_core_async(
120135
output: Output path.
121136
train_path: If `None` the model will be trained in a temporary
122137
directory, otherwise in the provided directory.
138+
kwargs: Additional training parameters.
123139
124140
Returns:
125141
If `train_path` is given it returns the path to the model archive,
@@ -136,6 +152,7 @@ async def train_core_async(
136152
stories_file=stories,
137153
output_path=os.path.join(_train_path, "core"),
138154
policy_config=config,
155+
kwargs=kwargs,
139156
)
140157

141158
if not train_path:

0 commit comments

Comments
 (0)