diff --git a/oslo_config/generator.py b/oslo_config/generator.py index d3b47cf7..164bb29f 100644 --- a/oslo_config/generator.py +++ b/oslo_config/generator.py @@ -724,9 +724,14 @@ def generate(conf, output_file=None): """ conf.register_opts(_generator_opts) + own_file = False + if output_file is None: - output_file = (open(conf.output_file, 'w') - if conf.output_file else sys.stdout) + if conf.output_file: + output_file = open(conf.output_file, 'w') + own_file = True + else: + output_file = sys.stdout groups = _get_groups(_list_opts(conf.namespace)) @@ -750,6 +755,9 @@ def generate(conf, output_file=None): output_file=output_file, conf=conf) + if own_file: + output_file.close() + def main(args=None): """The main function of oslo-config-generator.""" diff --git a/oslo_config/tests/test_generator.py b/oslo_config/tests/test_generator.py index a9bc5593..2daa61c8 100644 --- a/oslo_config/tests/test_generator.py +++ b/oslo_config/tests/test_generator.py @@ -1014,6 +1014,41 @@ class GeneratorTestCase(base.BaseTestCase): self.assertFalse(mock_log.warning.called) +class GeneratorFileHandlingTestCase(base.BaseTestCase): + + def setUp(self): + super(GeneratorFileHandlingTestCase, self).setUp() + + self.conf = cfg.ConfigOpts() + self.config_fixture = config_fixture.Config(self.conf) + self.config = self.config_fixture.config + + @mock.patch.object(generator, '_get_groups') + @mock.patch.object(generator, '_list_opts') + def test_close_generated_file(self, a, b): + generator.register_cli_opts(self.conf) + self.config(output_file='somefile') + + m = mock.mock_open() + m.close = mock.Mock() + + with mock.patch.object(generator, 'open', m, create=True): + generator.generate(self.conf, output_file=None) + + m().close.assert_called_once() + + @mock.patch.object(generator, '_get_groups') + @mock.patch.object(generator, '_list_opts') + def test_not_close_external_file(self, a, b): + generator.register_cli_opts(self.conf) + self.config() + + m = mock.Mock() + generator.generate(self.conf, output_file=m) + + m().close.assert_not_called() + + class DriverOptionTestCase(base.BaseTestCase): def setUp(self):