"""Command-line entry point for training the bot's models.

Three tasks are exposed: training the NLU model, training the dialogue
model, or training both one after the other.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import warnings

from rasa_nlu.training_data import load_data
from rasa_nlu import config
from rasa_nlu.model import Trainer

import rasa_core
from rasa_core.agent import Agent
from rasa_core.policies.keras_policy import KerasPolicy
from rasa_core.policies.memoization import MemoizationPolicy


def train_nlu(
    training_data_file="data/nlu-data.md",
    config_file="nlu-config.yml",
    model_path="models/nlu/",
    model_name="current"
):
    """Train the NLU model and persist it.

    The defaults reproduce the previously hard-coded locations, so calling
    ``train_nlu()`` without arguments behaves as before.

    Args:
        training_data_file: file passed to ``load_data``.
        config_file: file passed to ``config.load`` to build the trainer.
        model_path: directory handed to ``Trainer.persist``.
        model_name: value passed as ``fixed_model_name`` when persisting.

    Returns:
        Whatever ``Trainer.persist`` returns (presumably the directory the
        model was written to -- confirm against the rasa_nlu version in use).
    """
    training_data = load_data(training_data_file)
    trainer = Trainer(config.load(config_file))
    trainer.train(training_data)
    model_directory = trainer.persist(model_path,
                                      fixed_model_name=model_name)
    return model_directory


def train_dialogue(
    domain_file="domain.yml",
    model_path="models/dialogue",
    training_data_file="data/stories.md"
):
    """Train the dialogue model and persist it.

    Args:
        domain_file: domain definition handed to ``Agent``.
        model_path: location passed to ``Agent.persist``.
        training_data_file: file passed to ``Agent.load_data``.

    Returns:
        The trained ``Agent`` instance.
    """
    agent = Agent(
        domain_file,
        policies=[MemoizationPolicy(max_history=3), KerasPolicy()]
    )

    training_data = agent.load_data(training_data_file)
    agent.train(
        training_data,
        epochs=400,
        batch_size=100,
        validation_split=0.2
    )

    agent.persist(model_path)
    return agent


def train_all():
    """Train the NLU model, then the dialogue model, with default settings.

    Returns:
        A two-element list: the result of ``train_nlu()`` followed by the
        agent returned by ``train_dialogue()``.
    """
    model_directory = train_nlu()
    agent = train_dialogue()
    return [model_directory, agent]


if __name__ == '__main__':
    warnings.filterwarnings(action='ignore', category=DeprecationWarning)
    rasa_core.utils.configure_colored_logging(loglevel="INFO")

    parser = argparse.ArgumentParser(
        description='starts the bot training')

    parser.add_argument(
        'task',
        choices=["train-nlu", "train-dialogue", "train-all"],
        help="what the bot should do?")
    task = parser.parse_args().task

    # decide what to do based on first parameter of the script
    # (argparse has already rejected anything outside ``choices``)
    if task == "train-nlu":
        train_nlu()
    elif task == "train-dialogue":
        train_dialogue()
    elif task == "train-all":
        train_all()