You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

71 lines
1.8 KiB

6 years ago
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from __future__ import unicode_literals
  5. import argparse
  6. import warnings
  7. from rasa_nlu.training_data import load_data
  8. from rasa_nlu import config
  9. from rasa_nlu.model import Trainer
  10. import rasa_core
  11. from rasa_core.agent import Agent
  12. from rasa_core.policies.keras_policy import KerasPolicy
  13. from rasa_core.policies.memoization import MemoizationPolicy
  14. def train_nlu():
  15. training_data = load_data('data/nlu-data.md')
  16. trainer = Trainer(config.load("nlu-config.yml"))
  17. trainer.train(training_data)
  18. model_directory = trainer.persist('models/nlu/', fixed_model_name="current")
  19. return model_directory
  20. def train_dialogue(
  21. domain_file="domain.yml",
  22. model_path="models/dialogue",
  23. training_data_file="data/stories.md"
  24. ):
  25. agent = Agent(
  26. domain_file,
  27. policies=[MemoizationPolicy(max_history=3), KerasPolicy()]
  28. )
  29. training_data = agent.load_data(training_data_file)
  30. agent.train(
  31. training_data,
  32. epochs=400,
  33. batch_size=100,
  34. validation_split=0.2
  35. )
  36. agent.persist(model_path)
  37. return agent
  38. def train_all():
  39. model_directory = train_nlu()
  40. agent = train_dialogue()
  41. return [model_directory, agent]
  42. if __name__ == '__main__':
  43. warnings.filterwarnings(action='ignore', category=DeprecationWarning)
  44. rasa_core.utils.configure_colored_logging(loglevel="INFO")
  45. parser = argparse.ArgumentParser(
  46. description='starts the bot training')
  47. parser.add_argument(
  48. 'task',
  49. choices=["train-nlu", "train-dialogue", "train-all"],
  50. help="what the bot should do?")
  51. task = parser.parse_args().task
  52. # decide what to do based on first parameter of the script
  53. if task == "train-nlu":
  54. train_nlu()
  55. elif task == "train-dialogue":
  56. train_dialogue()
  57. elif task == "train-all":
  58. train_all()