|
|
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from __future__ import unicode_literals
-
- import argparse
- import warnings
-
- from rasa_nlu.training_data import load_data
- from rasa_nlu import config
- from rasa_nlu.model import Trainer
-
- import rasa_core
- from rasa_core.agent import Agent
- from rasa_core.policies.keras_policy import KerasPolicy
- from rasa_core.policies.memoization import MemoizationPolicy
-
-
def train_nlu(
    training_data_file='data/nlu-data.md',
    config_file='nlu-config.yml',
    model_path='models/nlu/'
):
    """Train the NLU model and persist it to disk.

    Args:
        training_data_file: path to the NLU training data passed to
            ``rasa_nlu.training_data.load_data``.
        config_file: path to the NLU pipeline configuration passed to
            ``rasa_nlu.config.load``.
        model_path: directory handed to ``Trainer.persist``; the model is
            stored under the fixed model name ``"current"``.

    Returns:
        Whatever ``Trainer.persist`` returns (presumably the directory the
        model was written to -- confirm against the rasa_nlu version in use).
    """
    training_data = load_data(training_data_file)
    trainer = Trainer(config.load(config_file))
    trainer.train(training_data)
    # A fixed model name means every run overwrites the same "current" model
    # instead of creating a new timestamped directory.
    model_directory = trainer.persist(model_path, fixed_model_name="current")
    return model_directory
-
-
def train_dialogue(
    domain_file="domain.yml",
    model_path="models/dialogue",
    training_data_file="data/stories.md",
    max_history=3,
    epochs=400,
    batch_size=100,
    validation_split=0.2
):
    """Train the dialogue (Rasa Core) model and persist it.

    Args:
        domain_file: domain definition used to construct the ``Agent``.
        model_path: directory passed to ``Agent.persist``.
        training_data_file: stories file loaded via ``Agent.load_data``.
        max_history: ``max_history`` for the ``MemoizationPolicy``.
        epochs: number of training epochs forwarded to ``Agent.train``.
        batch_size: batch size forwarded to ``Agent.train``.
        validation_split: fraction of data held out for validation,
            forwarded to ``Agent.train``.

    Returns:
        The trained ``Agent`` instance.
    """
    agent = Agent(
        domain_file,
        # Memoization reproduces stories seen in training exactly; the Keras
        # policy generalizes to conversations not covered by the stories.
        policies=[MemoizationPolicy(max_history=max_history), KerasPolicy()]
    )
    training_data = agent.load_data(training_data_file)
    agent.train(
        training_data,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=validation_split
    )
    agent.persist(model_path)
    return agent
-
-
def train_all():
    """Train the NLU model, then the dialogue model, using default settings.

    Returns:
        A two-item list: the result of ``train_nlu()`` followed by the agent
        returned from ``train_dialogue()``.
    """
    # List displays evaluate left to right, so NLU training runs first.
    return [train_nlu(), train_dialogue()]
-
-
if __name__ == '__main__':
    # Import the submodule explicitly: `import rasa_core` alone does not
    # guarantee that `rasa_core.utils` is loaded; it only worked because
    # another import pulled it in as a side effect.
    import rasa_core.utils

    warnings.filterwarnings(action='ignore', category=DeprecationWarning)
    rasa_core.utils.configure_colored_logging(loglevel="INFO")

    # Map each command-line task name to the function that performs it.
    tasks = {
        "train-nlu": train_nlu,
        "train-dialogue": train_dialogue,
        "train-all": train_all,
    }

    parser = argparse.ArgumentParser(
        description='starts the bot training')

    # The choices list is kept explicit (not derived from `tasks`) so the
    # order shown in --help is stable on Python 2, where dicts are unordered.
    parser.add_argument(
        'task',
        choices=["train-nlu", "train-dialogue", "train-all"],
        help="what the bot should do?")
    task = parser.parse_args().task

    # argparse has already rejected anything outside `choices`, so the
    # lookup cannot fail.
    tasks[task]()
|