diff --git a/ara/api/tests/tests_stats.py b/ara/api/tests/tests_stats.py index 795e78d..f51d03b 100644 --- a/ara/api/tests/tests_stats.py +++ b/ara/api/tests/tests_stats.py @@ -59,6 +59,18 @@ class StatsTestCase(APITestCase): self.assertEqual(1, len(request.data["results"])) self.assertEqual(stats.ok, request.data["results"][0]["ok"]) + def test_get_stats_by_playbook(self): + playbook = factories.PlaybookFactory() + host_one = factories.HostFactory(name="one") + host_two = factories.HostFactory(name="two") + stats = factories.StatsFactory(host=host_one, playbook=playbook, ok=9001) + factories.StatsFactory(host=host_two, playbook=playbook) + request = self.client.get("/api/v1/stats?playbook=%s" % playbook.id) + self.assertEqual(2, len(request.data["results"])) + self.assertEqual(host_one.id, request.data["results"][0]["id"]) + self.assertEqual(stats.ok, request.data["results"][0]["ok"]) + self.assertEqual(host_two.id, request.data["results"][1]["id"]) + def test_get_stats_id(self): stats = factories.StatsFactory() request = self.client.get("/api/v1/stats/%s" % stats.id) diff --git a/ara/api/views.py b/ara/api/views.py index e2e6a8f..3a9fa92 100644 --- a/ara/api/views.py +++ b/ara/api/views.py @@ -84,3 +84,5 @@ class RecordViewSet(viewsets.ModelViewSet): class StatsViewSet(viewsets.ModelViewSet): queryset = models.Stats.objects.all() serializer_class = serializers.StatsSerializer + filter_backends = (DjangoFilterBackend,) + filter_fields = ("playbook",)