diff --git a/dragonflow/cli/df_model.py b/dragonflow/cli/df_model.py index 9d7b65c47..5724d3f15 100644 --- a/dragonflow/cli/df_model.py +++ b/dragonflow/cli/df_model.py @@ -31,6 +31,9 @@ class ModelsPrinter(object): def __init__(self, fh): self._output = fh + def _print(self, *args, **kwargs): + print(*args, file=self._output, **kwargs) + def output_start(self): """ Called once on the beginning of the processing. @@ -125,16 +128,16 @@ class PlaintextPrinter(ModelsPrinter): super(PlaintextPrinter, self).__init__(fh) def model_start(self, model_name): - print('-------------', file=self._output) - print('{}'.format(model_name), file=self._output) - print('-------------', file=self._output) + self._print('-------------') + self._print('{}'.format(model_name)) + self._print('-------------') def model_end(self, model_name): - print('', file=self._output) + self._print('') def fields_start(self): - print('Fields', file=self._output) - print('------', file=self._output) + self._print('Fields') + self._print('------') def handle_field(self, field_name, field_type, is_required, is_embedded, is_single, restrictions): @@ -143,26 +146,26 @@ class PlaintextPrinter(ModelsPrinter): name=field_name, type=field_type, restriction=restriction_str, required=', Required' if is_required else '', - to_many=', One' if is_single else ', Many', - embedded=', Embedded' if is_embedded else ''), - file=self._output) + to_many=', Multi' if not is_single else '', + embedded=', Embedded' if is_embedded else '')) def indexes_start(self): - print('Indexes', file=self._output) - print('-------', file=self._output) + self._print('Indexes') + self._print('-------') def handle_index(self, index_name): - print('{}'.format(index_name), file=self._output) + self._print('{}'.format(index_name)) def events_start(self): - print('Events', file=self._output) - print('------', file=self._output) + self._print('Events') + self._print('------') def handle_event(self, event_name): - print('{}'.format(event_name), file=self._output) + self._print('{}'.format(event_name)) class UMLPrinter(ModelsPrinter): + """PlantUML format printer""" def __init__(self, fh): super(UMLPrinter, self).__init__(fh) self._model = '' @@ -170,8 +173,8 @@ class UMLPrinter(ModelsPrinter): self._dependencies = set() def output_start(self): - print('@startuml', file=self._output) - print('hide circle', file=self._output) + self._print('@startuml') + self._print('hide circle') def _output_relations(self): for (dst, src, name, is_single, is_embedded) in self._dependencies: @@ -180,21 +183,20 @@ class UMLPrinter(ModelsPrinter): connector_str = ' *-- ' if is_single else '"1" *-- "*"' else: connector_str = ' o-- ' if is_single else ' o-- "*"' - print('{dest} {connector} {src} : {field_name}'.format( + self._print('{dest} {connector} {src} : {field_name}'.format( dest=dst, connector=connector_str, src=src, - field_name=name), - file=self._output) + field_name=name)) def output_end(self): self._output_relations() - print('@enduml', file=self._output) + self._print('@enduml') def model_start(self, model_name): self._model = model_name - print('class {} {{'.format(model_name), file=self._output) + self._print('class {} {{'.format(model_name)) def model_end(self, model_name): - print('}', file=self._output) + self._print('}') self._processed.add(model_name) self._model = '' @@ -202,23 +204,22 @@ class UMLPrinter(ModelsPrinter): is_single, restrictions): restriction_str = ' {}'.format(restrictions) if restrictions else '' name = '{}'.format(field_name) if is_required else field_name - print(' +{name} : {type} {restriction}'.format( - name=name, type=field_type, restriction=restriction_str), - file=self._output) + self._print(' +{name} : {type} {restriction}'.format( + name=name, type=field_type, restriction=restriction_str)) self._dependencies.add((self._model, field_type, field_name, is_single, is_embedded)) def indexes_start(self): - print(' .. Indexes ..', file=self._output) + self._print(' .. Indexes ..') def handle_index(self, index_name): - print(' {}'.format(index_name), file=self._output) + self._print(' {}'.format(index_name)) def events_start(self): - print(' == Events ==', file=self._output) + self._print(' == Events ==') def handle_event(self, event_name): - print(' {}'.format(event_name), file=self._output) + self._print(' {}'.format(event_name)) class DfModelParser(object):