2
2
import os
3
3
import tempfile
4
4
import typing
5
- from typing import Text , Optional , List , Union
5
+ from typing import Text , Optional , List , Union , Dict
6
6
7
7
from rasa import model , data
8
8
from rasa .cli .utils import create_output_path , print_success
@@ -18,10 +18,11 @@ def train(
18
18
training_files : Union [Text , List [Text ]],
19
19
output : Text = DEFAULT_MODELS_PATH ,
20
20
force_training : bool = False ,
21
+ kwargs : Optional [Dict ] = None ,
21
22
) -> Optional [Text ]:
22
23
loop = asyncio .get_event_loop ()
23
24
return loop .run_until_complete (
24
- train_async (domain , config , training_files , output , force_training )
25
+ train_async (domain , config , training_files , output , force_training , kwargs )
25
26
)
26
27
27
28
@@ -31,6 +32,7 @@ async def train_async(
31
32
training_files : Union [Text , List [Text ]],
32
33
output : Text = DEFAULT_MODELS_PATH ,
33
34
force_training : bool = False ,
35
+ kwargs : Optional [Dict ] = None ,
34
36
) -> Optional [Text ]:
35
37
"""Trains a Rasa model (Core and NLU).
36
38
@@ -40,6 +42,7 @@ async def train_async(
40
42
training_files: Paths to the training data for Core and NLU.
41
43
output: Output path.
42
44
force_training: If `True` retrain model even if data has not changed.
45
+ kwargs: Additional training parameters.
43
46
44
47
Returns:
45
48
Path of the trained model archive.
@@ -69,7 +72,9 @@ async def train_async(
69
72
retrain_nlu = not model .merge_model (old_nlu , target_path )
70
73
71
74
if force_training or retrain_core :
72
- await train_core_async (domain , config , story_directory , output , train_path )
75
+ await train_core_async (
76
+ domain , config , story_directory , output , train_path , kwargs
77
+ )
73
78
else :
74
79
print (
75
80
"Dialogue data / configuration did not change. "
@@ -100,16 +105,26 @@ async def train_async(
100
105
101
106
102
107
def train_core (
103
- domain : Text , config : Text , stories : Text , output : Text , train_path : Optional [Text ]
108
+ domain : Text ,
109
+ config : Text ,
110
+ stories : Text ,
111
+ output : Text ,
112
+ train_path : Optional [Text ],
113
+ kwargs : Optional [Dict ],
104
114
) -> Optional [Text ]:
105
115
loop = asyncio .get_event_loop ()
106
116
return loop .run_until_complete (
107
- train_core_async (domain , config , stories , output , train_path )
117
+ train_core_async (domain , config , stories , output , train_path , kwargs )
108
118
)
109
119
110
120
111
121
async def train_core_async (
112
- domain : Text , config : Text , stories : Text , output : Text , train_path : Optional [Text ]
122
+ domain : Text ,
123
+ config : Text ,
124
+ stories : Text ,
125
+ output : Text ,
126
+ train_path : Optional [Text ] = None ,
127
+ kwargs : Optional [Dict ] = None ,
113
128
) -> Optional [Text ]:
114
129
"""Trains a Core model.
115
130
@@ -120,6 +135,7 @@ async def train_core_async(
120
135
output: Output path.
121
136
train_path: If `None` the model will be trained in a temporary
122
137
directory, otherwise in the provided directory.
138
+ kwargs: Additional training parameters.
123
139
124
140
Returns:
125
141
If `train_path` is given it returns the path to the model archive,
@@ -136,6 +152,7 @@ async def train_core_async(
136
152
stories_file = stories ,
137
153
output_path = os .path .join (_train_path , "core" ),
138
154
policy_config = config ,
155
+ kwargs = kwargs ,
139
156
)
140
157
141
158
if not train_path :
0 commit comments