diff --git a/dragonflow/db/model_framework.py b/dragonflow/db/model_framework.py index 6d9a6bb47..061f22bb6 100644 --- a/dragonflow/db/model_framework.py +++ b/dragonflow/db/model_framework.py @@ -194,7 +194,7 @@ class _CommonBase(models.Base): @classmethod def dependencies(cls): - deps = [] + deps = set() for key, field in cls.iterate_over_fields(): if isinstance(field, fields.ListField): types = field.items_types @@ -203,11 +203,16 @@ class _CommonBase(models.Base): for field_type in types: try: - deps.append(field_type.get_proxied_model()) + deps.add(field_type.get_proxied_model()) except AttributeError: - pass + if issubclass(field_type, ModelBase): + # If the field is not a reference, and it is a df + # model(derived from ModelBase), it is considered as + # non-first class model. And its dependency + # will be treated as current model's dependency. + deps |= field_type.dependencies() - return set(deps) + return deps @classmethod def is_first_class(cls): diff --git a/dragonflow/tests/unit/test_model_framework.py b/dragonflow/tests/unit/test_model_framework.py index d5dcf374f..34b6743dc 100644 --- a/dragonflow/tests/unit/test_model_framework.py +++ b/dragonflow/tests/unit/test_model_framework.py @@ -135,6 +135,18 @@ class EmbeddingModel2(mf.ModelBase): emb_required = fields.EmbeddedField(EmbeddedModel, required=True) +@mf.construct_nb_db_model +class ReffingNonFirstClassModel(mf.ModelBase): + ref1 = df_fields.ReferenceField(ReffedModel) + + +@mf.register_model +@mf.construct_nb_db_model +class ReffingModel3(mf.ModelBase): + table_name = 'ReffingModel3' + ref = fields.ListField(ReffingNonFirstClassModel) + + class TestModelFramework(tests_base.BaseTestCase): def test_lookup(self): self.assertEqual(ModelTest, mf.get_model('ModelTest')) @@ -418,3 +430,11 @@ class TestModelFramework(tests_base.BaseTestCase): (emb1, emb2, emb3), embedding1.iterate_embedded_model_instances(), ) + + def test_hierarchical_dependency(self): + sorted_models = mf.iter_models_by_dependency_order() + self.assertLess( + sorted_models.index(ReffedModel), + sorted_models.index(ReffingModel3) + ) + self.assertIn(ReffedModel, ReffingModel3.dependencies())