Merge "Run black & isort over the codebase."

This commit is contained in:
Zuul 2018-10-02 14:31:59 +00:00 committed by Gerrit Code Review
commit 37446042af
21 changed files with 346 additions and 455 deletions

View File

@ -4,9 +4,9 @@ from django.contrib.auth.models import Group
class RecordAdmin(admin.ModelAdmin):
list_display = ('id', 'key', 'value', 'type')
search_fields = ('key', 'value', 'type')
ordering = ('key',)
list_display = ("id", "key", "value", "type")
search_fields = ("key", "value", "type")
ordering = ("key",)
admin.site.register(models.Record, RecordAdmin)

View File

@ -2,4 +2,4 @@ from django.apps import AppConfig
class ApiConfig(AppConfig):
name = 'ara.api'
name = "ara.api"

View File

@ -52,13 +52,13 @@ class FileContent(Base):
"""
class Meta:
db_table = 'file_contents'
db_table = "file_contents"
sha1 = models.CharField(max_length=40, unique=True)
contents = models.BinaryField(max_length=(2 ** 32) - 1)
def __str__(self):
return '<FileContent %s:%s>' % (self.id, self.sha1)
return "<FileContent %s:%s>" % (self.id, self.sha1)
class File(Base):
@ -68,13 +68,13 @@ class File(Base):
"""
class Meta:
db_table = 'files'
db_table = "files"
path = models.CharField(max_length=255)
content = models.ForeignKey(FileContent, on_delete=models.CASCADE, related_name='files')
content = models.ForeignKey(FileContent, on_delete=models.CASCADE, related_name="files")
def __str__(self):
return '<File %s:%s>' % (self.id, self.path)
return "<File %s:%s>" % (self.id, self.path)
class Label(Base):
@ -90,13 +90,13 @@ class Label(Base):
"""
class Meta:
db_table = 'labels'
db_table = "labels"
name = models.CharField(max_length=255)
description = models.BinaryField(max_length=(2 ** 32) - 1)
def __str__(self):
return '<Label %s: %s>' % (self.id, self.name)
return "<Label %s: %s>" % (self.id, self.name)
class Playbook(Duration):
@ -107,18 +107,18 @@ class Playbook(Duration):
"""
class Meta:
db_table = 'playbooks'
db_table = "playbooks"
name = models.CharField(max_length=255, null=True)
ansible_version = models.CharField(max_length=255)
completed = models.BooleanField(default=False)
parameters = models.BinaryField(max_length=(2 ** 32) - 1)
file = models.ForeignKey(File, on_delete=models.CASCADE, related_name='playbooks')
file = models.ForeignKey(File, on_delete=models.CASCADE, related_name="playbooks")
files = models.ManyToManyField(File)
labels = models.ManyToManyField(Label)
def __str__(self):
return '<Playbook %s>' % self.id
return "<Playbook %s>" % self.id
class Record(Base):
@ -128,16 +128,16 @@ class Record(Base):
"""
class Meta:
db_table = 'records'
unique_together = ('key', 'playbook',)
db_table = "records"
unique_together = ("key", "playbook")
key = models.CharField(max_length=255)
value = models.TextField(null=True, blank=True)
type = models.CharField(max_length=255)
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name='records')
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name="records")
def __str__(self):
return '<Record %s:%s>' % (self.id, self.key)
return "<Record %s:%s>" % (self.id, self.key)
class Play(Duration):
@ -147,21 +147,21 @@ class Play(Duration):
"""
class Meta:
db_table = 'plays'
db_table = "plays"
name = models.CharField(max_length=255, blank=True, null=True)
completed = models.BooleanField(default=False)
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name='plays')
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name="plays")
def __str__(self):
return '<Play %s:%s>' % (self.id, self.name)
return "<Play %s:%s>" % (self.id, self.name)
class Task(Duration):
"""Data about Ansible tasks."""
class Meta:
db_table = 'tasks'
db_table = "tasks"
name = models.TextField(blank=True, null=True)
action = models.TextField()
@ -170,11 +170,11 @@ class Task(Duration):
handler = models.BooleanField()
completed = models.BooleanField(default=False)
play = models.ForeignKey(Play, on_delete=models.CASCADE, related_name='tasks')
file = models.ForeignKey(File, on_delete=models.CASCADE, related_name='tasks')
play = models.ForeignKey(Play, on_delete=models.CASCADE, related_name="tasks")
file = models.ForeignKey(File, on_delete=models.CASCADE, related_name="tasks")
def __str__(self):
return '<Task %s:%s>' % (self.name, self.id)
return "<Task %s:%s>" % (self.name, self.id)
class Host(Base):
@ -183,8 +183,8 @@ class Host(Base):
"""
class Meta:
db_table = 'hosts'
unique_together = ('name', 'playbook',)
db_table = "hosts"
unique_together = ("name", "playbook")
name = models.CharField(max_length=255)
facts = models.BinaryField(max_length=(2 ** 32) - 1)
@ -197,10 +197,10 @@ class Host(Base):
# The logic for supplying aliases does not live here, it's provided by the
# clients and consumers.
alias = models.CharField(max_length=255, null=True)
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name='hosts')
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name="hosts")
def __str__(self):
return '<Host %s:%s>' % (self.id, self.name)
return "<Host %s:%s>" % (self.id, self.name)
class Stats(Base):
@ -209,11 +209,11 @@ class Stats(Base):
"""
class Meta:
db_table = 'stats'
unique_together = ('host', 'playbook',)
db_table = "stats"
unique_together = ("host", "playbook")
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name='stats')
host = models.ForeignKey(Host, on_delete=models.CASCADE, related_name='stats')
playbook = models.ForeignKey(Playbook, on_delete=models.CASCADE, related_name="stats")
host = models.ForeignKey(Host, on_delete=models.CASCADE, related_name="stats")
changed = models.IntegerField(default=0)
failed = models.IntegerField(default=0)
ok = models.IntegerField(default=0)
@ -222,10 +222,8 @@ class Stats(Base):
def __str__(self):
# Verbose because it's otherwise kind of useless
return '<Stats for {host} ({id}) in playbook {playbook}>'.format(
host=self.host.name,
id=self.host.id,
playbook=self.playbook.id
return "<Stats for {host} ({id}) in playbook {playbook}>".format(
host=self.host.name, id=self.host.id, playbook=self.playbook.id
)
@ -236,33 +234,33 @@ class Result(Duration):
"""
class Meta:
db_table = 'results'
db_table = "results"
# Ansible statuses
OK = 'ok'
FAILED = 'failed'
SKIPPED = 'skipped'
UNREACHABLE = 'unreachable'
OK = "ok"
FAILED = "failed"
SKIPPED = "skipped"
UNREACHABLE = "unreachable"
# ARA specific statuses (derived or assumed)
CHANGED = 'changed'
IGNORED = 'ignored'
UNKNOWN = 'unknown'
CHANGED = "changed"
IGNORED = "ignored"
UNKNOWN = "unknown"
STATUS = (
(OK, 'ok'),
(FAILED, 'failed'),
(SKIPPED, 'skipped'),
(UNREACHABLE, 'unreachable'),
(CHANGED, 'changed'),
(IGNORED, 'ignored'),
(UNKNOWN, 'unknown')
(OK, "ok"),
(FAILED, "failed"),
(SKIPPED, "skipped"),
(UNREACHABLE, "unreachable"),
(CHANGED, "changed"),
(IGNORED, "ignored"),
(UNKNOWN, "unknown"),
)
status = models.CharField(max_length=25, choices=STATUS, default=UNKNOWN)
# todo use a single Content table
content = models.BinaryField(max_length=(2 ** 32) - 1)
host = models.ForeignKey(Host, on_delete=models.CASCADE, related_name='results')
task = models.ForeignKey(Task, on_delete=models.CASCADE, related_name='results')
host = models.ForeignKey(Host, on_delete=models.CASCADE, related_name="results")
task = models.ForeignKey(Task, on_delete=models.CASCADE, related_name="results")
def __str__(self):
return '<Result %s, %s>' % (self.id, self.status)
return "<Result %s, %s>" % (self.id, self.status)

View File

@ -25,7 +25,7 @@ from rest_framework import serializers
DATE_FORMAT = "(iso-8601: 2016-05-06T17:20:25.749489-04:00)"
DURATION_FORMAT = "([DD] [HH:[MM:]]ss[.uuuuuu])"
logger = logging.getLogger('ara.api.serializers')
logger = logging.getLogger("ara.api.serializers")
class CompressedTextField(serializers.CharField):
@ -35,10 +35,10 @@ class CompressedTextField(serializers.CharField):
"""
def to_representation(self, obj):
return zlib.decompress(obj).decode('utf8')
return zlib.decompress(obj).decode("utf8")
def to_internal_value(self, data):
return zlib.compress(data.encode('utf8'))
return zlib.compress(data.encode("utf8"))
class CompressedObjectField(serializers.JSONField):
@ -49,10 +49,10 @@ class CompressedObjectField(serializers.JSONField):
"""
def to_representation(self, obj):
return json.loads(zlib.decompress(obj).decode('utf8'))
return json.loads(zlib.decompress(obj).decode("utf8"))
def to_internal_value(self, data):
return zlib.compress(json.dumps(data).encode('utf8'))
return zlib.compress(json.dumps(data).encode("utf8"))
class DurationSerializer(serializers.ModelSerializer):
@ -75,7 +75,7 @@ class DurationSerializer(serializers.ModelSerializer):
class FileContentSerializer(serializers.ModelSerializer):
class Meta:
model = models.FileContent
fields = '__all__'
fields = "__all__"
class FileContentField(serializers.CharField):
@ -85,22 +85,21 @@ class FileContentField(serializers.CharField):
"""
def to_representation(self, obj):
return zlib.decompress(obj.contents).decode('utf8')
return zlib.decompress(obj.contents).decode("utf8")
def to_internal_value(self, data):
contents = zlib.compress(data.encode('utf8'))
contents = zlib.compress(data.encode("utf8"))
sha1 = hashlib.sha1(contents).hexdigest()
content_file, created = models.FileContent.objects.get_or_create(sha1=sha1, defaults={
'sha1': sha1,
'contents': contents
})
content_file, created = models.FileContent.objects.get_or_create(
sha1=sha1, defaults={"sha1": sha1, "contents": contents}
)
return content_file
class FileSerializer(serializers.ModelSerializer):
class Meta:
model = models.File
fields = '__all__'
fields = "__all__"
content = FileContentField()
@ -108,24 +107,22 @@ class FileSerializer(serializers.ModelSerializer):
class HostSerializer(serializers.ModelSerializer):
class Meta:
model = models.Host
fields = '__all__'
fields = "__all__"
facts = CompressedObjectField(default=zlib.compress(json.dumps({}).encode('utf8')))
facts = CompressedObjectField(default=zlib.compress(json.dumps({}).encode("utf8")))
def get_unique_together_validators(self):
'''
"""
Hosts have a "unique together" constraint for host.name and play.id.
We want to have a "get_or_create" facility and in order to do that, we
must manage the validation during the creation, not before.
Overriding this method effectively disables this validator.
'''
"""
return []
def create(self, validated_data):
host, created = models.Host.objects.get_or_create(
name=validated_data['name'],
playbook=validated_data['playbook'],
defaults=validated_data
name=validated_data["name"], playbook=validated_data["playbook"], defaults=validated_data
)
return host
@ -133,28 +130,27 @@ class HostSerializer(serializers.ModelSerializer):
class ResultSerializer(serializers.ModelSerializer):
class Meta:
model = models.Result
fields = '__all__'
fields = "__all__"
content = CompressedObjectField(default=zlib.compress(json.dumps({}).encode('utf8')))
content = CompressedObjectField(default=zlib.compress(json.dumps({}).encode("utf8")))
class LabelSerializer(serializers.ModelSerializer):
class Meta:
model = models.Label
fields = '__all__'
fields = "__all__"
description = CompressedTextField(
default=zlib.compress(json.dumps("").encode('utf8')),
help_text='A textual description of the label'
default=zlib.compress(json.dumps("").encode("utf8")), help_text="A textual description of the label"
)
class PlaybookSerializer(DurationSerializer):
class Meta:
model = models.Playbook
fields = '__all__'
fields = "__all__"
parameters = CompressedObjectField(default=zlib.compress(json.dumps({}).encode('utf8')))
parameters = CompressedObjectField(default=zlib.compress(json.dumps({}).encode("utf8")))
file = FileSerializer()
files = FileSerializer(many=True, default=[])
hosts = HostSerializer(many=True, default=[])
@ -162,13 +158,13 @@ class PlaybookSerializer(DurationSerializer):
def create(self, validated_data):
# Create the file for the playbook
file_dict = validated_data.pop('file')
validated_data['file'] = models.File.objects.create(**file_dict)
file_dict = validated_data.pop("file")
validated_data["file"] = models.File.objects.create(**file_dict)
# Create the playbook without the file and label references for now
files = validated_data.pop('files')
hosts = validated_data.pop('hosts')
labels = validated_data.pop('labels')
files = validated_data.pop("files")
hosts = validated_data.pop("hosts")
labels = validated_data.pop("labels")
playbook = models.Playbook.objects.create(**validated_data)
# Add the files, hosts and the labels in
@ -185,7 +181,7 @@ class PlaybookSerializer(DurationSerializer):
class PlaySerializer(DurationSerializer):
class Meta:
model = models.Play
fields = '__all__'
fields = "__all__"
hosts = HostSerializer(read_only=True, many=True)
results = ResultSerializer(read_only=True, many=True)
@ -194,15 +190,14 @@ class PlaySerializer(DurationSerializer):
class TaskSerializer(DurationSerializer):
class Meta:
model = models.Task
fields = '__all__'
fields = "__all__"
tags = CompressedObjectField(
default=zlib.compress(json.dumps([]).encode('utf8')),
help_text='A JSON list containing Ansible tags'
default=zlib.compress(json.dumps([]).encode("utf8")), help_text="A JSON list containing Ansible tags"
)
class StatsSerializer(serializers.ModelSerializer):
class Meta:
model = models.Stats
fields = '__all__'
fields = "__all__"

View File

@ -22,28 +22,18 @@ from ara.api.tests import utils
# constants for things like compressed byte strings or objects
FILE_CONTENTS = '---\n# Example file'
HOST_FACTS = {
'ansible_fqdn': 'hostname',
'ansible_distribution': 'CentOS'
}
PLAYBOOK_PARAMETERS = {
'ansible_version': '2.5.5',
'inventory': '/etc/ansible/hosts'
}
RESULT_CONTENTS = {
'results': [{
'msg': 'something happened'
}]
}
LABEL_DESCRIPTION = 'label description'
TASK_TAGS = ['always', 'never']
FILE_CONTENTS = "---\n# Example file"
HOST_FACTS = {"ansible_fqdn": "hostname", "ansible_distribution": "CentOS"}
PLAYBOOK_PARAMETERS = {"ansible_version": "2.5.5", "inventory": "/etc/ansible/hosts"}
RESULT_CONTENTS = {"results": [{"msg": "something happened"}]}
LABEL_DESCRIPTION = "label description"
TASK_TAGS = ["always", "never"]
class FileContentFactory(factory.DjangoModelFactory):
class Meta:
model = models.FileContent
django_get_or_create = ('sha1',)
django_get_or_create = ("sha1",)
sha1 = utils.sha1(FILE_CONTENTS)
contents = utils.compressed_str(FILE_CONTENTS)
@ -53,7 +43,7 @@ class FileFactory(factory.DjangoModelFactory):
class Meta:
model = models.File
path = '/path/playbook.yml'
path = "/path/playbook.yml"
content = factory.SubFactory(FileContentFactory)
@ -61,7 +51,7 @@ class LabelFactory(factory.DjangoModelFactory):
class Meta:
model = models.Label
name = 'test label'
name = "test label"
description = utils.compressed_str(LABEL_DESCRIPTION)
@ -69,7 +59,7 @@ class PlaybookFactory(factory.DjangoModelFactory):
class Meta:
model = models.Playbook
ansible_version = '2.4.0'
ansible_version = "2.4.0"
completed = True
parameters = utils.compressed_obj(PLAYBOOK_PARAMETERS)
file = factory.SubFactory(FileFactory)
@ -79,7 +69,7 @@ class PlayFactory(factory.DjangoModelFactory):
class Meta:
model = models.Play
name = 'test play'
name = "test play"
completed = True
playbook = factory.SubFactory(PlaybookFactory)
@ -88,9 +78,9 @@ class TaskFactory(factory.DjangoModelFactory):
class Meta:
model = models.Task
name = 'test task'
name = "test task"
completed = True
action = 'setup'
action = "setup"
lineno = 2
handler = False
tags = utils.compressed_obj(TASK_TAGS)
@ -103,7 +93,7 @@ class HostFactory(factory.DjangoModelFactory):
model = models.Host
facts = utils.compressed_obj(HOST_FACTS)
name = 'hostname'
name = "hostname"
alias = "9f5d3ba7-e43d-4f3b-ab17-f90c39e43d07"
playbook = factory.SubFactory(PlaybookFactory)
@ -113,7 +103,7 @@ class ResultFactory(factory.DjangoModelFactory):
model = models.Result
content = utils.compressed_obj(RESULT_CONTENTS)
status = 'ok'
status = "ok"
host = factory.SubFactory(HostFactory)
task = factory.SubFactory(TaskFactory)

View File

@ -25,33 +25,28 @@ from ara.api.tests import utils
class FileTestCase(APITestCase):
def test_file_factory(self):
file_content = factories.FileContentFactory()
file = factories.FileFactory(path='/path/playbook.yml', content=file_content)
self.assertEqual(file.path, '/path/playbook.yml')
file = factories.FileFactory(path="/path/playbook.yml", content=file_content)
self.assertEqual(file.path, "/path/playbook.yml")
self.assertEqual(file.content.sha1, file_content.sha1)
def test_file_serializer(self):
serializer = serializers.FileSerializer(data={
'path': '/path/playbook.yml',
'content': factories.FILE_CONTENTS
})
serializer = serializers.FileSerializer(data={"path": "/path/playbook.yml", "content": factories.FILE_CONTENTS})
serializer.is_valid()
file = serializer.save()
file.refresh_from_db()
self.assertEqual(file.content.sha1, utils.sha1(factories.FILE_CONTENTS))
def test_create_file_with_same_content_create_only_one_file_content(self):
serializer = serializers.FileSerializer(data={
'path': '/path/1/playbook.yml',
'content': factories.FILE_CONTENTS
})
serializer = serializers.FileSerializer(
data={"path": "/path/1/playbook.yml", "content": factories.FILE_CONTENTS}
)
serializer.is_valid()
file_content = serializer.save()
file_content.refresh_from_db()
serializer2 = serializers.FileSerializer(data={
'path': '/path/2/playbook.yml',
'content': factories.FILE_CONTENTS
})
serializer2 = serializers.FileSerializer(
data={"path": "/path/2/playbook.yml", "content": factories.FILE_CONTENTS}
)
serializer2.is_valid()
file_content = serializer2.save()
file_content.refresh_from_db()
@ -61,52 +56,46 @@ class FileTestCase(APITestCase):
def test_create_file(self):
self.assertEqual(0, models.File.objects.count())
request = self.client.post('/api/v1/files', {
'path': '/path/playbook.yml',
'content': factories.FILE_CONTENTS
})
request = self.client.post("/api/v1/files", {"path": "/path/playbook.yml", "content": factories.FILE_CONTENTS})
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.File.objects.count())
def test_get_no_files(self):
request = self.client.get('/api/v1/files')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/files")
self.assertEqual(0, len(request.data["results"]))
def test_get_files(self):
file = factories.FileFactory()
request = self.client.get('/api/v1/files')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(file.path, request.data['results'][0]['path'])
request = self.client.get("/api/v1/files")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(file.path, request.data["results"][0]["path"])
def test_get_file(self):
file = factories.FileFactory()
request = self.client.get('/api/v1/files/%s' % file.id)
self.assertEqual(file.path, request.data['path'])
request = self.client.get("/api/v1/files/%s" % file.id)
self.assertEqual(file.path, request.data["path"])
def test_update_file(self):
file = factories.FileFactory()
self.assertNotEqual('/path/new_playbook.yml', file.path)
request = self.client.put('/api/v1/files/%s' % file.id, {
"path": "/path/new_playbook.yml",
'content': '# playbook'
})
self.assertNotEqual("/path/new_playbook.yml", file.path)
request = self.client.put(
"/api/v1/files/%s" % file.id, {"path": "/path/new_playbook.yml", "content": "# playbook"}
)
self.assertEqual(200, request.status_code)
file_updated = models.File.objects.get(id=file.id)
self.assertEqual('/path/new_playbook.yml', file_updated.path)
self.assertEqual("/path/new_playbook.yml", file_updated.path)
def test_partial_update_file(self):
file = factories.FileFactory()
self.assertNotEqual('/path/new_playbook.yml', file.path)
request = self.client.patch('/api/v1/files/%s' % file.id, {
"path": "/path/new_playbook.yml",
})
self.assertNotEqual("/path/new_playbook.yml", file.path)
request = self.client.patch("/api/v1/files/%s" % file.id, {"path": "/path/new_playbook.yml"})
self.assertEqual(200, request.status_code)
file_updated = models.File.objects.get(id=file.id)
self.assertEqual('/path/new_playbook.yml', file_updated.path)
self.assertEqual("/path/new_playbook.yml", file_updated.path)
def test_delete_file(self):
file = factories.FileFactory()
self.assertEqual(1, models.File.objects.all().count())
request = self.client.delete('/api/v1/files/%s' % file.id)
request = self.client.delete("/api/v1/files/%s" % file.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.File.objects.all().count())

View File

@ -22,5 +22,5 @@ from ara.api.tests import factories
class FileContentTestCase(APITestCase):
def test_file_content_factory(self):
file_content = factories.FileContentFactory(sha1='413a2f16b8689267b7d0c2e10cdd19bf3e54208d')
self.assertEqual(file_content.sha1, '413a2f16b8689267b7d0c2e10cdd19bf3e54208d')
file_content = factories.FileContentFactory(sha1="413a2f16b8689267b7d0c2e10cdd19bf3e54208d")
self.assertEqual(file_content.sha1, "413a2f16b8689267b7d0c2e10cdd19bf3e54208d")

View File

@ -24,95 +24,77 @@ from ara.api.tests import utils
class HostTestCase(APITestCase):
def test_host_factory(self):
host = factories.HostFactory(name='testhost')
self.assertEqual(host.name, 'testhost')
host = factories.HostFactory(name="testhost")
self.assertEqual(host.name, "testhost")
def test_host_serializer(self):
playbook = factories.PlaybookFactory()
serializer = serializers.HostSerializer(data={
'name': 'serializer',
'playbook': playbook.id
})
serializer = serializers.HostSerializer(data={"name": "serializer", "playbook": playbook.id})
serializer.is_valid()
host = serializer.save()
host.refresh_from_db()
self.assertEqual(host.name, 'serializer')
self.assertEqual(host.name, "serializer")
self.assertEqual(host.playbook.id, playbook.id)
def test_host_serializer_compress_facts(self):
playbook = factories.PlaybookFactory()
serializer = serializers.HostSerializer(data={
'name': 'compress',
'facts': factories.HOST_FACTS,
'playbook': playbook.id,
})
serializer = serializers.HostSerializer(
data={"name": "compress", "facts": factories.HOST_FACTS, "playbook": playbook.id}
)
serializer.is_valid()
host = serializer.save()
host.refresh_from_db()
self.assertEqual(host.facts, utils.compressed_obj(factories.HOST_FACTS))
def test_host_serializer_decompress_facts(self):
host = factories.HostFactory(
facts=utils.compressed_obj(factories.HOST_FACTS)
)
host = factories.HostFactory(facts=utils.compressed_obj(factories.HOST_FACTS))
serializer = serializers.HostSerializer(instance=host)
self.assertEqual(serializer.data['facts'], factories.HOST_FACTS)
self.assertEqual(serializer.data["facts"], factories.HOST_FACTS)
def test_get_no_hosts(self):
request = self.client.get('/api/v1/hosts')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/hosts")
self.assertEqual(0, len(request.data["results"]))
def test_get_hosts(self):
host = factories.HostFactory()
request = self.client.get('/api/v1/hosts')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(host.name, request.data['results'][0]['name'])
request = self.client.get("/api/v1/hosts")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(host.name, request.data["results"][0]["name"])
def test_delete_host(self):
host = factories.HostFactory()
self.assertEqual(1, models.Host.objects.all().count())
request = self.client.delete('/api/v1/hosts/%s' % host.id)
request = self.client.delete("/api/v1/hosts/%s" % host.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.Host.objects.all().count())
def test_create_host(self):
playbook = factories.PlaybookFactory()
self.assertEqual(0, models.Host.objects.count())
request = self.client.post('/api/v1/hosts', {
'name': 'create',
'playbook': playbook.id
})
request = self.client.post("/api/v1/hosts", {"name": "create", "playbook": playbook.id})
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Host.objects.count())
def test_post_same_host_for_a_playbook(self):
playbook = factories.PlaybookFactory()
self.assertEqual(0, models.Host.objects.count())
request = self.client.post('/api/v1/hosts', {
'name': 'create',
'playbook': playbook.id
})
request = self.client.post("/api/v1/hosts", {"name": "create", "playbook": playbook.id})
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Host.objects.count())
request = self.client.post('/api/v1/hosts', {
'name': 'create',
'playbook': playbook.id
})
request = self.client.post("/api/v1/hosts", {"name": "create", "playbook": playbook.id})
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Host.objects.count())
def test_partial_update_host(self):
host = factories.HostFactory()
self.assertNotEqual('foo', host.name)
request = self.client.patch('/api/v1/hosts/%s' % host.id, {
'name': 'foo'
})
self.assertNotEqual("foo", host.name)
request = self.client.patch("/api/v1/hosts/%s" % host.id, {"name": "foo"})
self.assertEqual(200, request.status_code)
host_updated = models.Host.objects.get(id=host.id)
self.assertEqual('foo', host_updated.name)
self.assertEqual("foo", host_updated.name)
def test_get_host(self):
host = factories.HostFactory()
request = self.client.get('/api/v1/hosts/%s' % host.id)
self.assertEqual(host.name, request.data['name'])
request = self.client.get("/api/v1/hosts/%s" % host.id)
self.assertEqual(host.name, request.data["name"])

View File

@ -24,72 +24,60 @@ from ara.api.tests import utils
class LabelTestCase(APITestCase):
def test_label_factory(self):
label = factories.LabelFactory(name='factory')
self.assertEqual(label.name, 'factory')
label = factories.LabelFactory(name="factory")
self.assertEqual(label.name, "factory")
def test_label_serializer(self):
serializer = serializers.LabelSerializer(data={
'name': 'serializer',
})
serializer = serializers.LabelSerializer(data={"name": "serializer"})
serializer.is_valid()
label = serializer.save()
label.refresh_from_db()
self.assertEqual(label.name, 'serializer')
self.assertEqual(label.name, "serializer")
def test_label_serializer_compress_description(self):
serializer = serializers.LabelSerializer(data={
'name': 'compress',
'description': factories.LABEL_DESCRIPTION
})
serializer = serializers.LabelSerializer(data={"name": "compress", "description": factories.LABEL_DESCRIPTION})
serializer.is_valid()
label = serializer.save()
label.refresh_from_db()
self.assertEqual(label.description, utils.compressed_str(factories.LABEL_DESCRIPTION))
def test_label_serializer_decompress_description(self):
label = factories.LabelFactory(
description=utils.compressed_str(factories.LABEL_DESCRIPTION)
)
label = factories.LabelFactory(description=utils.compressed_str(factories.LABEL_DESCRIPTION))
serializer = serializers.LabelSerializer(instance=label)
self.assertEqual(serializer.data['description'], factories.LABEL_DESCRIPTION)
self.assertEqual(serializer.data["description"], factories.LABEL_DESCRIPTION)
def test_create_label(self):
self.assertEqual(0, models.Label.objects.count())
request = self.client.post('/api/v1/labels', {
'name': 'compress',
'description': factories.LABEL_DESCRIPTION
})
request = self.client.post("/api/v1/labels", {"name": "compress", "description": factories.LABEL_DESCRIPTION})
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Label.objects.count())
def test_get_no_labels(self):
request = self.client.get('/api/v1/labels')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/labels")
self.assertEqual(0, len(request.data["results"]))
def test_get_labels(self):
label = factories.LabelFactory()
request = self.client.get('/api/v1/labels')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(label.name, request.data['results'][0]['name'])
request = self.client.get("/api/v1/labels")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(label.name, request.data["results"][0]["name"])
def test_get_label(self):
label = factories.LabelFactory()
request = self.client.get('/api/v1/labels/%s' % label.id)
self.assertEqual(label.name, request.data['name'])
request = self.client.get("/api/v1/labels/%s" % label.id)
self.assertEqual(label.name, request.data["name"])
def test_partial_update_label(self):
label = factories.LabelFactory()
self.assertNotEqual('updated', label.name)
request = self.client.patch('/api/v1/labels/%s' % label.id, {
'name': 'updated'
})
self.assertNotEqual("updated", label.name)
request = self.client.patch("/api/v1/labels/%s" % label.id, {"name": "updated"})
self.assertEqual(200, request.status_code)
label_updated = models.Label.objects.get(id=label.id)
self.assertEqual('updated', label_updated.name)
self.assertEqual("updated", label_updated.name)
def test_delete_label(self):
label = factories.LabelFactory()
self.assertEqual(1, models.Label.objects.all().count())
request = self.client.delete('/api/v1/labels/%s' % label.id)
request = self.client.delete("/api/v1/labels/%s" % label.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.Label.objects.all().count())

View File

@ -25,67 +25,57 @@ from ara.api.tests import factories
class PlayTestCase(APITestCase):
def test_play_factory(self):
play = factories.PlayFactory(name='play factory')
self.assertEqual(play.name, 'play factory')
play = factories.PlayFactory(name="play factory")
self.assertEqual(play.name, "play factory")
def test_play_serializer(self):
playbook = factories.PlaybookFactory()
serializer = serializers.PlaySerializer(data={
'name': 'serializer',
'completed': True,
'playbook': playbook.id
})
serializer = serializers.PlaySerializer(data={"name": "serializer", "completed": True, "playbook": playbook.id})
serializer.is_valid()
play = serializer.save()
play.refresh_from_db()
self.assertEqual(play.name, 'serializer')
self.assertEqual(play.name, "serializer")
def test_get_no_plays(self):
request = self.client.get('/api/v1/plays')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/plays")
self.assertEqual(0, len(request.data["results"]))
def test_get_plays(self):
play = factories.PlayFactory()
request = self.client.get('/api/v1/plays')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(play.name, request.data['results'][0]['name'])
request = self.client.get("/api/v1/plays")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(play.name, request.data["results"][0]["name"])
def test_delete_play(self):
play = factories.PlayFactory()
self.assertEqual(1, models.Play.objects.all().count())
request = self.client.delete('/api/v1/plays/%s' % play.id)
request = self.client.delete("/api/v1/plays/%s" % play.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.Play.objects.all().count())
def test_create_play(self):
playbook = factories.PlaybookFactory()
self.assertEqual(0, models.Play.objects.count())
request = self.client.post('/api/v1/plays', {
'name': 'create',
'completed': False,
'playbook': playbook.id
})
request = self.client.post("/api/v1/plays", {"name": "create", "completed": False, "playbook": playbook.id})
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Play.objects.count())
def test_partial_update_play(self):
play = factories.PlayFactory()
self.assertNotEqual('update', play.name)
request = self.client.patch('/api/v1/plays/%s' % play.id, {
'name': 'update',
})
self.assertNotEqual("update", play.name)
request = self.client.patch("/api/v1/plays/%s" % play.id, {"name": "update"})
self.assertEqual(200, request.status_code)
play_updated = models.Play.objects.get(id=play.id)
self.assertEqual('update', play_updated.name)
self.assertEqual("update", play_updated.name)
def test_get_play(self):
play = factories.PlayFactory()
request = self.client.get('/api/v1/plays/%s' % play.id)
self.assertEqual(play.name, request.data['name'])
request = self.client.get("/api/v1/plays/%s" % play.id)
self.assertEqual(play.name, request.data["name"])
def test_get_play_duration(self):
started = timezone.now()
ended = started + datetime.timedelta(hours=1)
play = factories.PlayFactory(started=started, ended=ended)
request = self.client.get('/api/v1/plays/%s' % play.id)
self.assertEqual(request.data['duration'], datetime.timedelta(0, 3600))
request = self.client.get("/api/v1/plays/%s" % play.id)
self.assertEqual(request.data["duration"], datetime.timedelta(0, 3600))

View File

@ -26,96 +26,85 @@ from ara.api.tests import utils
class PlaybookTestCase(APITestCase):
def test_playbook_factory(self):
playbook = factories.PlaybookFactory(ansible_version='2.4.0')
self.assertEqual(playbook.ansible_version, '2.4.0')
playbook = factories.PlaybookFactory(ansible_version="2.4.0")
self.assertEqual(playbook.ansible_version, "2.4.0")
def test_playbook_serializer(self):
serializer = serializers.PlaybookSerializer(data={
'name': 'serializer-playbook',
'ansible_version': '2.4.0',
'file': {
'path': '/path/playbook.yml',
'content': factories.FILE_CONTENTS
serializer = serializers.PlaybookSerializer(
data={
"name": "serializer-playbook",
"ansible_version": "2.4.0",
"file": {"path": "/path/playbook.yml", "content": factories.FILE_CONTENTS},
}
})
)
serializer.is_valid()
playbook = serializer.save()
playbook.refresh_from_db()
self.assertEqual(playbook.name, 'serializer-playbook')
self.assertEqual(playbook.ansible_version, '2.4.0')
self.assertEqual(playbook.name, "serializer-playbook")
self.assertEqual(playbook.ansible_version, "2.4.0")
def test_playbook_serializer_compress_parameters(self):
serializer = serializers.PlaybookSerializer(data={
'ansible_version': '2.4.0',
'file': {
'path': '/path/playbook.yml',
'content': factories.FILE_CONTENTS
},
'parameters': factories.PLAYBOOK_PARAMETERS
})
serializer = serializers.PlaybookSerializer(
data={
"ansible_version": "2.4.0",
"file": {"path": "/path/playbook.yml", "content": factories.FILE_CONTENTS},
"parameters": factories.PLAYBOOK_PARAMETERS,
}
)
serializer.is_valid()
playbook = serializer.save()
playbook.refresh_from_db()
self.assertEqual(
playbook.parameters, utils.compressed_obj(factories.PLAYBOOK_PARAMETERS)
)
self.assertEqual(playbook.parameters, utils.compressed_obj(factories.PLAYBOOK_PARAMETERS))
def test_playbook_serializer_decompress_parameters(self):
playbook = factories.PlaybookFactory(
parameters=utils.compressed_obj(factories.PLAYBOOK_PARAMETERS)
)
playbook = factories.PlaybookFactory(parameters=utils.compressed_obj(factories.PLAYBOOK_PARAMETERS))
serializer = serializers.PlaybookSerializer(instance=playbook)
self.assertEqual(serializer.data['parameters'], factories.PLAYBOOK_PARAMETERS)
self.assertEqual(serializer.data["parameters"], factories.PLAYBOOK_PARAMETERS)
def test_get_no_playbooks(self):
request = self.client.get('/api/v1/playbooks')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/playbooks")
self.assertEqual(0, len(request.data["results"]))
def test_get_playbooks(self):
playbook = factories.PlaybookFactory()
request = self.client.get('/api/v1/playbooks')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(playbook.ansible_version, request.data['results'][0]['ansible_version'])
request = self.client.get("/api/v1/playbooks")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(playbook.ansible_version, request.data["results"][0]["ansible_version"])
def test_delete_playbook(self):
playbook = factories.PlaybookFactory()
self.assertEqual(1, models.Playbook.objects.all().count())
request = self.client.delete('/api/v1/playbooks/%s' % playbook.id)
request = self.client.delete("/api/v1/playbooks/%s" % playbook.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.Playbook.objects.all().count())
def test_create_playbook(self):
self.assertEqual(0, models.Playbook.objects.count())
request = self.client.post('/api/v1/playbooks', {
"ansible_version": "2.4.0",
'file': {
'path': '/path/playbook.yml',
'content': factories.FILE_CONTENTS
}
})
request = self.client.post(
"/api/v1/playbooks",
{"ansible_version": "2.4.0", "file": {"path": "/path/playbook.yml", "content": factories.FILE_CONTENTS}},
)
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Playbook.objects.count())
def test_partial_update_playbook(self):
playbook = factories.PlaybookFactory()
self.assertNotEqual('2.3.0', playbook.ansible_version)
request = self.client.patch('/api/v1/playbooks/%s' % playbook.id, {
"ansible_version": "2.3.0",
})
self.assertNotEqual("2.3.0", playbook.ansible_version)
request = self.client.patch("/api/v1/playbooks/%s" % playbook.id, {"ansible_version": "2.3.0"})
self.assertEqual(200, request.status_code)
playbook_updated = models.Playbook.objects.get(id=playbook.id)
self.assertEqual('2.3.0', playbook_updated.ansible_version)
self.assertEqual("2.3.0", playbook_updated.ansible_version)
def test_get_playbook(self):
playbook = factories.PlaybookFactory()
request = self.client.get('/api/v1/playbooks/%s' % playbook.id)
self.assertEqual(playbook.ansible_version, request.data['ansible_version'])
request = self.client.get("/api/v1/playbooks/%s" % playbook.id)
self.assertEqual(playbook.ansible_version, request.data["ansible_version"])
def test_get_playbook_duration(self):
started = timezone.now()
ended = started + datetime.timedelta(hours=1)
playbook = factories.PlaybookFactory(started=started, ended=ended)
request = self.client.get('/api/v1/playbooks/%s' % playbook.id)
self.assertEqual(request.data['duration'], datetime.timedelta(0, 3600))
request = self.client.get("/api/v1/playbooks/%s" % playbook.id)
self.assertEqual(request.data["duration"], datetime.timedelta(0, 3600))
# TODO: Add tests for incrementally updating files

View File

@ -25,27 +25,24 @@ class PlaybookFileTestCase(APITestCase):
def test_create_a_file_and_a_playbook_directly(self):
self.assertEqual(0, models.Playbook.objects.all().count())
self.assertEqual(0, models.File.objects.all().count())
self.client.post('/api/v1/playbooks', {
'ansible_version': '2.4.0',
'file': {
'path': '/path/playbook.yml',
'content': factories.FILE_CONTENTS
self.client.post(
"/api/v1/playbooks",
{
"ansible_version": "2.4.0",
"file": {"path": "/path/playbook.yml", "content": factories.FILE_CONTENTS},
"files": [{"path": "/path/host", "content": "Another file"}],
},
'files': [{
'path': '/path/host',
'content': 'Another file'
}],
})
)
self.assertEqual(1, models.Playbook.objects.all().count())
self.assertEqual(2, models.File.objects.all().count())
def test_create_file_to_a_playbook(self):
playbook = factories.PlaybookFactory()
self.assertEqual(1, models.File.objects.all().count())
self.client.post('/api/v1/playbooks/%s/files' % playbook.id, {
'path': '/path/playbook.yml',
'content': factories.FILE_CONTENTS
})
self.client.post(
"/api/v1/playbooks/%s/files" % playbook.id,
{"path": "/path/playbook.yml", "content": factories.FILE_CONTENTS},
)
self.assertEqual(2, models.File.objects.all().count())
self.assertEqual(1, models.FileContent.objects.all().count())
@ -53,14 +50,12 @@ class PlaybookFileTestCase(APITestCase):
playbook = factories.PlaybookFactory()
number_playbooks = models.File.objects.all().count()
number_file_contents = models.FileContent.objects.all().count()
content = '# %s' % factories.FILE_CONTENTS
self.client.post('/api/v1/playbooks/%s/files' % playbook.id, {
'path': '/path/1/playbook.yml',
'content': content
})
self.client.post('/api/v1/playbooks/%s/files' % playbook.id, {
'path': '/path/2/playbook.yml',
'content': content
})
content = "# %s" % factories.FILE_CONTENTS
self.client.post(
"/api/v1/playbooks/%s/files" % playbook.id, {"path": "/path/1/playbook.yml", "content": content}
)
self.client.post(
"/api/v1/playbooks/%s/files" % playbook.id, {"path": "/path/2/playbook.yml", "content": content}
)
self.assertEqual(number_playbooks + 2, models.File.objects.all().count())
self.assertEqual(number_file_contents + 1, models.FileContent.objects.all().count())

View File

@ -24,32 +24,26 @@ from ara.api.tests import utils
class ResultTestCase(APITestCase):
def test_result_factory(self):
result = factories.ResultFactory(status='failed')
self.assertEqual(result.status, 'failed')
result = factories.ResultFactory(status="failed")
self.assertEqual(result.status, "failed")
def test_result_serializer(self):
host = factories.HostFactory()
task = factories.TaskFactory()
serializer = serializers.ResultSerializer(data={
'status': 'skipped',
'host': host.id,
'task': task.id
})
serializer = serializers.ResultSerializer(data={"status": "skipped", "host": host.id, "task": task.id})
serializer.is_valid()
result = serializer.save()
result.refresh_from_db()
self.assertEqual(result.status, 'skipped')
self.assertEqual(result.status, "skipped")
self.assertEqual(result.host.id, host.id)
self.assertEqual(result.task.id, task.id)
def test_result_serializer_compress_content(self):
host = factories.HostFactory()
task = factories.TaskFactory()
serializer = serializers.ResultSerializer(data={
'host': host.id,
'task': task.id,
'content': factories.RESULT_CONTENTS
})
serializer = serializers.ResultSerializer(
data={"host": host.id, "task": task.id, "content": factories.RESULT_CONTENTS}
)
serializer.is_valid()
result = serializer.save()
result.refresh_from_db()
@ -58,22 +52,22 @@ class ResultTestCase(APITestCase):
def test_result_serializer_decompress_content(self):
result = factories.ResultFactory(content=utils.compressed_obj(factories.RESULT_CONTENTS))
serializer = serializers.ResultSerializer(instance=result)
self.assertEqual(serializer.data['content'], factories.RESULT_CONTENTS)
self.assertEqual(serializer.data["content"], factories.RESULT_CONTENTS)
def test_get_no_results(self):
request = self.client.get('/api/v1/results')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/results")
self.assertEqual(0, len(request.data["results"]))
def test_get_results(self):
result = factories.ResultFactory()
request = self.client.get('/api/v1/results')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(result.status, request.data['results'][0]['status'])
request = self.client.get("/api/v1/results")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(result.status, request.data["results"][0]["status"])
def test_delete_result(self):
result = factories.ResultFactory()
self.assertEqual(1, models.Result.objects.all().count())
request = self.client.delete('/api/v1/results/%s' % result.id)
request = self.client.delete("/api/v1/results/%s" % result.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.Result.objects.all().count())
@ -81,26 +75,21 @@ class ResultTestCase(APITestCase):
host = factories.HostFactory()
task = factories.TaskFactory()
self.assertEqual(0, models.Result.objects.count())
request = self.client.post('/api/v1/results', {
'status': 'ok',
'host': host.id,
'task': task.id,
'content': factories.RESULT_CONTENTS
})
request = self.client.post(
"/api/v1/results", {"status": "ok", "host": host.id, "task": task.id, "content": factories.RESULT_CONTENTS}
)
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Result.objects.count())
def test_partial_update_result(self):
result = factories.ResultFactory()
self.assertNotEqual('unreachable', result.status)
request = self.client.patch('/api/v1/results/%s' % result.id, {
'status': 'unreachable'
})
self.assertNotEqual("unreachable", result.status)
request = self.client.patch("/api/v1/results/%s" % result.id, {"status": "unreachable"})
self.assertEqual(200, request.status_code)
result_updated = models.Result.objects.get(id=result.id)
self.assertEqual('unreachable', result_updated.status)
self.assertEqual("unreachable", result_updated.status)
def test_get_result(self):
result = factories.ResultFactory()
request = self.client.get('/api/v1/results/%s' % result.id)
self.assertEqual(result.status, request.data['status'])
request = self.client.get("/api/v1/results/%s" % result.id)
self.assertEqual(result.status, request.data["status"])

View File

@ -23,13 +23,7 @@ from ara.api.tests import factories
class StatsTestCase(APITestCase):
def test_stats_factory(self):
stats = factories.StatsFactory(
changed=2,
failed=1,
ok=3,
skipped=2,
unreachable=1
)
stats = factories.StatsFactory(changed=2, failed=1, ok=3, skipped=2, unreachable=1)
self.assertEqual(stats.changed, 2)
self.assertEqual(stats.failed, 1)
self.assertEqual(stats.ok, 3)
@ -39,11 +33,7 @@ class StatsTestCase(APITestCase):
def test_stats_serializer(self):
playbook = factories.PlaybookFactory()
host = factories.HostFactory()
serializer = serializers.StatsSerializer(data=dict(
playbook=playbook.id,
host=host.id,
ok=9001
))
serializer = serializers.StatsSerializer(data=dict(playbook=playbook.id, host=host.id, ok=9001))
serializer.is_valid()
stats = serializer.save()
stats.refresh_from_db()
@ -55,35 +45,29 @@ class StatsTestCase(APITestCase):
playbook = factories.PlaybookFactory()
host = factories.HostFactory()
self.assertEqual(0, models.Stats.objects.count())
request = self.client.post('/api/v1/stats', dict(
playbook=playbook.id,
host=host.id,
ok=9001
))
request = self.client.post("/api/v1/stats", dict(playbook=playbook.id, host=host.id, ok=9001))
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Stats.objects.count())
def test_get_no_stats(self):
request = self.client.get('/api/v1/stats')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/stats")
self.assertEqual(0, len(request.data["results"]))
def test_get_stats(self):
stats = factories.StatsFactory()
request = self.client.get('/api/v1/stats')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(stats.ok, request.data['results'][0]['ok'])
request = self.client.get("/api/v1/stats")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(stats.ok, request.data["results"][0]["ok"])
def test_get_stats_id(self):
stats = factories.StatsFactory()
request = self.client.get('/api/v1/stats/%s' % stats.id)
self.assertEqual(stats.ok, request.data['ok'])
request = self.client.get("/api/v1/stats/%s" % stats.id)
self.assertEqual(stats.ok, request.data["ok"])
def test_partial_update_stats(self):
stats = factories.StatsFactory()
self.assertNotEqual(9001, stats.ok)
request = self.client.patch('/api/v1/stats/%s' % stats.id, dict(
ok=9001
))
request = self.client.patch("/api/v1/stats/%s" % stats.id, dict(ok=9001))
self.assertEqual(200, request.status_code)
stats_updated = models.Stats.objects.get(id=stats.id)
self.assertEqual(9001, stats_updated.ok)
@ -91,6 +75,6 @@ class StatsTestCase(APITestCase):
def test_delete_stats(self):
stats = factories.StatsFactory()
self.assertEqual(1, models.Stats.objects.all().count())
request = self.client.delete('/api/v1/stats/%s' % stats.id)
request = self.client.delete("/api/v1/stats/%s" % stats.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.Stats.objects.all().count())

View File

@ -26,39 +26,43 @@ from ara.api.tests import utils
class TaskTestCase(APITestCase):
def test_task_factory(self):
task = factories.TaskFactory(name='factory')
self.assertEqual(task.name, 'factory')
task = factories.TaskFactory(name="factory")
self.assertEqual(task.name, "factory")
def test_task_serializer(self):
play = factories.PlayFactory()
file = factories.FileFactory()
serializer = serializers.TaskSerializer(data={
'name': 'serializer',
'action': 'test',
'lineno': 2,
'completed': True,
'handler': False,
'play': play.id,
'file': file.id
})
serializer = serializers.TaskSerializer(
data={
"name": "serializer",
"action": "test",
"lineno": 2,
"completed": True,
"handler": False,
"play": play.id,
"file": file.id,
}
)
serializer.is_valid()
task = serializer.save()
task.refresh_from_db()
self.assertEqual(task.name, 'serializer')
self.assertEqual(task.name, "serializer")
def test_task_serializer_compress_tags(self):
play = factories.PlayFactory()
file = factories.FileFactory()
serializer = serializers.TaskSerializer(data={
'name': 'compress',
'action': 'test',
'lineno': 2,
'completed': True,
'handler': False,
'play': play.id,
'file': file.id,
'tags': factories.TASK_TAGS
})
serializer = serializers.TaskSerializer(
data={
"name": "compress",
"action": "test",
"lineno": 2,
"completed": True,
"handler": False,
"play": play.id,
"file": file.id,
"tags": factories.TASK_TAGS,
}
)
serializer.is_valid()
task = serializer.save()
task.refresh_from_db()
@ -67,22 +71,22 @@ class TaskTestCase(APITestCase):
def test_task_serializer_decompress_tags(self):
task = factories.TaskFactory(tags=utils.compressed_obj(factories.TASK_TAGS))
serializer = serializers.TaskSerializer(instance=task)
self.assertEqual(serializer.data['tags'], factories.TASK_TAGS)
self.assertEqual(serializer.data["tags"], factories.TASK_TAGS)
def test_get_no_tasks(self):
request = self.client.get('/api/v1/tasks')
self.assertEqual(0, len(request.data['results']))
request = self.client.get("/api/v1/tasks")
self.assertEqual(0, len(request.data["results"]))
def test_get_tasks(self):
task = factories.TaskFactory()
request = self.client.get('/api/v1/tasks')
self.assertEqual(1, len(request.data['results']))
self.assertEqual(task.name, request.data['results'][0]['name'])
request = self.client.get("/api/v1/tasks")
self.assertEqual(1, len(request.data["results"]))
self.assertEqual(task.name, request.data["results"][0]["name"])
def test_delete_task(self):
task = factories.TaskFactory()
self.assertEqual(1, models.Task.objects.all().count())
request = self.client.delete('/api/v1/tasks/%s' % task.id)
request = self.client.delete("/api/v1/tasks/%s" % task.id)
self.assertEqual(204, request.status_code)
self.assertEqual(0, models.Task.objects.all().count())
@ -90,36 +94,37 @@ class TaskTestCase(APITestCase):
play = factories.PlayFactory()
file = factories.FileFactory()
self.assertEqual(0, models.Task.objects.count())
request = self.client.post('/api/v1/tasks', {
'name': 'create',
'action': 'test',
'lineno': 2,
'handler': False,
'completed': True,
'play': play.id,
'file': file.id
})
request = self.client.post(
"/api/v1/tasks",
{
"name": "create",
"action": "test",
"lineno": 2,
"handler": False,
"completed": True,
"play": play.id,
"file": file.id,
},
)
self.assertEqual(201, request.status_code)
self.assertEqual(1, models.Task.objects.count())
def test_partial_update_task(self):
task = factories.TaskFactory()
self.assertNotEqual('update', task.name)
request = self.client.patch('/api/v1/tasks/%s' % task.id, {
'name': 'update'
})
self.assertNotEqual("update", task.name)
request = self.client.patch("/api/v1/tasks/%s" % task.id, {"name": "update"})
self.assertEqual(200, request.status_code)
task_updated = models.Task.objects.get(id=task.id)
self.assertEqual('update', task_updated.name)
self.assertEqual("update", task_updated.name)
def test_get_task(self):
task = factories.TaskFactory()
request = self.client.get('/api/v1/tasks/%s' % task.id)
self.assertEqual(task.name, request.data['name'])
request = self.client.get("/api/v1/tasks/%s" % task.id)
self.assertEqual(task.name, request.data["name"])
def test_get_task_duration(self):
started = timezone.now()
ended = started + datetime.timedelta(hours=1)
task = factories.TaskFactory(started=started, ended=ended)
request = self.client.get('/api/v1/tasks/%s' % task.id)
self.assertEqual(request.data['duration'], datetime.timedelta(0, 3600))
request = self.client.get("/api/v1/tasks/%s" % task.id)
self.assertEqual(request.data["duration"], datetime.timedelta(0, 3600))

View File

@ -24,19 +24,19 @@ def compressed_obj(obj):
"""
Returns a zlib compressed representation of an object
"""
return zlib.compress(json.dumps(obj).encode('utf-8'))
return zlib.compress(json.dumps(obj).encode("utf-8"))
def compressed_str(obj):
"""
Returns a zlib compressed representation of a string
"""
return zlib.compress(obj.encode('utf-8'))
return zlib.compress(obj.encode("utf-8"))
def sha1(obj):
"""
Returns the sha1 of a compressed string or an object
"""
contents = zlib.compress(obj.encode('utf8'))
contents = zlib.compress(obj.encode("utf8"))
return hashlib.sha1(contents).hexdigest()

View File

@ -19,15 +19,15 @@ from rest_framework_extensions.routers import ExtendedDefaultRouter
from ara.api import views
router = ExtendedDefaultRouter(trailing_slash=False)
router.register('labels', views.LabelViewSet, base_name='label')
router.register('plays', views.PlayViewSet, base_name='play')
router.register('tasks', views.TaskViewSet, base_name='task')
router.register('hosts', views.HostViewSet, base_name='host')
router.register('results', views.ResultViewSet, base_name='result')
router.register('files', views.FileViewSet, base_name='file')
router.register('stats', views.StatsViewSet, base_name='stats')
router.register("labels", views.LabelViewSet, base_name="label")
router.register("plays", views.PlayViewSet, base_name="play")
router.register("tasks", views.TaskViewSet, base_name="task")
router.register("hosts", views.HostViewSet, base_name="host")
router.register("results", views.ResultViewSet, base_name="result")
router.register("files", views.FileViewSet, base_name="file")
router.register("stats", views.StatsViewSet, base_name="stats")
playbook_routes = router.register('playbooks', views.PlaybookViewSet, base_name='playbook')
playbook_routes.register('files', views.PlaybookFilesDetail, base_name='file', parents_query_lookups=['playbooks'])
playbook_routes = router.register("playbooks", views.PlaybookViewSet, base_name="playbook")
playbook_routes.register("files", views.PlaybookFilesDetail, base_name="file", parents_query_lookups=["playbooks"])
urlpatterns = router.urls

View File

@ -36,7 +36,7 @@ class PlaybookFilesDetail(NestedViewSetMixin, viewsets.ModelViewSet):
serializer_class = serializers.FileSerializer
def perform_create(self, serializer):
playbook = models.Playbook.objects.get(pk=self.get_parents_query_dict()['playbooks'])
playbook = models.Playbook.objects.get(pk=self.get_parents_query_dict()["playbooks"])
with transaction.atomic(savepoint=False):
instance = serializer.save()
playbook.files.add(instance)

View File

@ -2,5 +2,5 @@ from django.contrib import admin
class AraAdminSite(admin.AdminSite):
site_header = 'Administration'
index_title = 'Administration Ara'
site_header = "Administration"
index_title = "Administration Ara"

View File

@ -2,4 +2,4 @@ from django.contrib.admin.apps import AdminConfig
class AraAdminConfig(AdminConfig):
default_site = 'ara.server.admin.AraAdminSite'
default_site = "ara.server.admin.AraAdminSite"

View File

@ -2,7 +2,4 @@ from django.urls import include, path
from django.contrib import admin
urlpatterns = [
path('api/v1/', include('ara.api.urls')),
path('admin/', admin.site.urls),
]
urlpatterns = [path("api/v1/", include("ara.api.urls")), path("admin/", admin.site.urls)]