You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

104 lines
4.1 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A convenience wrapper around tf.test.TestCase to enable TPU tests."""
  16. import os
  17. import tensorflow as tf
  18. from tensorflow.contrib import tpu
  19. flags = tf.app.flags
  20. flags.DEFINE_bool('tpu_test', False, 'Whether to configure test for TPU.')
  21. FLAGS = flags.FLAGS
  22. class TestCase(tf.test.TestCase):
  23. """Extends tf.test.TestCase to optionally allow running tests on TPU."""
  24. def execute_tpu(self, graph_fn, inputs):
  25. """Constructs the graph, executes it on TPU and returns the result.
  26. Args:
  27. graph_fn: a callable that constructs the tensorflow graph to test. The
  28. arguments of this function should correspond to `inputs`.
  29. inputs: a list of numpy arrays to feed input to the computation graph.
  30. Returns:
  31. A list of numpy arrays or a scalar returned from executing the tensorflow
  32. graph.
  33. """
  34. with self.test_session(graph=tf.Graph()) as sess:
  35. placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
  36. tpu_computation = tpu.rewrite(graph_fn, placeholders)
  37. sess.run(tpu.initialize_system())
  38. sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
  39. tf.local_variables_initializer()])
  40. materialized_results = sess.run(tpu_computation,
  41. feed_dict=dict(zip(placeholders, inputs)))
  42. sess.run(tpu.shutdown_system())
  43. if (hasattr(materialized_results, '__len__') and
  44. len(materialized_results) == 1 and
  45. (isinstance(materialized_results, list) or
  46. isinstance(materialized_results, tuple))):
  47. materialized_results = materialized_results[0]
  48. return materialized_results
  49. def execute_cpu(self, graph_fn, inputs):
  50. """Constructs the graph, executes it on CPU and returns the result.
  51. Args:
  52. graph_fn: a callable that constructs the tensorflow graph to test. The
  53. arguments of this function should correspond to `inputs`.
  54. inputs: a list of numpy arrays to feed input to the computation graph.
  55. Returns:
  56. A list of numpy arrays or a scalar returned from executing the tensorflow
  57. graph.
  58. """
  59. with self.test_session(graph=tf.Graph()) as sess:
  60. placeholders = [tf.placeholder_with_default(v, v.shape) for v in inputs]
  61. results = graph_fn(*placeholders)
  62. sess.run([tf.global_variables_initializer(), tf.tables_initializer(),
  63. tf.local_variables_initializer()])
  64. materialized_results = sess.run(results, feed_dict=dict(zip(placeholders,
  65. inputs)))
  66. if (hasattr(materialized_results, '__len__') and
  67. len(materialized_results) == 1 and
  68. (isinstance(materialized_results, list) or
  69. isinstance(materialized_results, tuple))):
  70. materialized_results = materialized_results[0]
  71. return materialized_results
  72. def execute(self, graph_fn, inputs):
  73. """Constructs the graph, creates a test session and returns the results.
  74. The graph is executed either on TPU or CPU based on the `tpu_test` flag.
  75. Args:
  76. graph_fn: a callable that constructs the tensorflow graph to test. The
  77. arguments of this function should correspond to `inputs`.
  78. inputs: a list of numpy arrays to feed input to the computation graph.
  79. Returns:
  80. A list of numpy arrays or a scalar returned from executing the tensorflow
  81. graph.
  82. """
  83. if FLAGS.tpu_test:
  84. return self.execute_tpu(graph_fn, inputs)
  85. else:
  86. return self.execute_cpu(graph_fn, inputs)