274 lines
9.6 KiB
Python
274 lines
9.6 KiB
Python
# Copyright (c) 2013-2014 Hewlett-Packard Development Company, L.P.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import threading
|
|
import time
|
|
import uuid
|
|
|
|
from OpenSSL import crypto
|
|
import fixtures
|
|
import testscenarios
|
|
import testtools
|
|
|
|
import gear
|
|
from gear import tests
|
|
|
|
|
|
def iterate_timeout(max_seconds, purpose):
|
|
start = time.time()
|
|
count = 0
|
|
while (time.time() < start + max_seconds):
|
|
count += 1
|
|
yield count
|
|
time.sleep(0)
|
|
raise Exception("Timeout waiting for %s" % purpose)
|
|
|
|
|
|
class TestFunctional(tests.BaseTestCase):
|
|
scenarios = [
|
|
('no_ssl', dict(ssl=False)),
|
|
('ssl', dict(ssl=True)),
|
|
]
|
|
|
|
def setUp(self):
|
|
super(TestFunctional, self).setUp()
|
|
if self.ssl:
|
|
self.tmp_root = self.useFixture(fixtures.TempDir()).path
|
|
root_subject, root_key = self.create_cert('root')
|
|
self.create_cert('server', root_subject, root_key)
|
|
self.create_cert('client', root_subject, root_key)
|
|
self.create_cert('worker', root_subject, root_key)
|
|
self.server = gear.Server(
|
|
0,
|
|
os.path.join(self.tmp_root, 'server.key'),
|
|
os.path.join(self.tmp_root, 'server.crt'),
|
|
os.path.join(self.tmp_root, 'root.crt'))
|
|
self.client = gear.Client('client')
|
|
self.worker = gear.Worker('worker')
|
|
self.client.addServer('127.0.0.1', self.server.port,
|
|
os.path.join(self.tmp_root, 'client.key'),
|
|
os.path.join(self.tmp_root, 'client.crt'),
|
|
os.path.join(self.tmp_root, 'root.crt'))
|
|
self.worker.addServer('127.0.0.1', self.server.port,
|
|
os.path.join(self.tmp_root, 'worker.key'),
|
|
os.path.join(self.tmp_root, 'worker.crt'),
|
|
os.path.join(self.tmp_root, 'root.crt'))
|
|
else:
|
|
self.server = gear.Server(0)
|
|
self.client = gear.Client('client')
|
|
self.worker = gear.Worker('worker')
|
|
self.client.addServer('127.0.0.1', self.server.port)
|
|
self.worker.addServer('127.0.0.1', self.server.port)
|
|
|
|
self.client.waitForServer()
|
|
self.worker.waitForServer()
|
|
|
|
def create_cert(self, cn, issuer=None, signing_key=None):
|
|
key = crypto.PKey()
|
|
key.generate_key(crypto.TYPE_RSA, 1024)
|
|
|
|
cert = crypto.X509()
|
|
subject = cert.get_subject()
|
|
subject.C = "US"
|
|
subject.ST = "State"
|
|
subject.L = "Locality"
|
|
subject.O = "Org"
|
|
subject.OU = "Org Unit"
|
|
subject.CN = cn
|
|
cert.set_serial_number(1)
|
|
cert.gmtime_adj_notBefore(0)
|
|
cert.gmtime_adj_notAfter(3600)
|
|
cert.set_pubkey(key)
|
|
if issuer:
|
|
cert.set_issuer(issuer)
|
|
else:
|
|
cert.set_issuer(subject)
|
|
if signing_key:
|
|
cert.sign(signing_key, 'sha1')
|
|
else:
|
|
cert.sign(key, 'sha1')
|
|
|
|
open(os.path.join(self.tmp_root, '%s.crt' % cn), 'w').write(
|
|
crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode('utf-8'))
|
|
open(os.path.join(self.tmp_root, '%s.key' % cn), 'w').write(
|
|
crypto.dump_privatekey(crypto.FILETYPE_PEM, key).decode('utf-8'))
|
|
|
|
return (subject, key)
|
|
|
|
def test_job(self):
|
|
self.worker.registerFunction('test')
|
|
|
|
for jobcount in range(2):
|
|
job = gear.Job(b'test', b'testdata')
|
|
self.client.submitJob(job)
|
|
self.assertNotEqual(job.handle, None)
|
|
|
|
workerjob = self.worker.getJob()
|
|
self.assertEqual(workerjob.handle, job.handle)
|
|
self.assertEqual(workerjob.arguments, b'testdata')
|
|
workerjob.sendWorkData(b'workdata')
|
|
workerjob.sendWorkComplete()
|
|
|
|
for count in iterate_timeout(30, "job completion"):
|
|
if job.complete:
|
|
break
|
|
self.assertTrue(job.complete)
|
|
self.assertEqual(job.data, [b'workdata'])
|
|
|
|
def test_bg_job(self):
|
|
self.worker.registerFunction('test')
|
|
|
|
job = gear.Job(b'test', b'testdata')
|
|
self.client.submitJob(job, background=True)
|
|
self.assertNotEqual(job.handle, None)
|
|
self.client.shutdown()
|
|
del self.client
|
|
|
|
workerjob = self.worker.getJob()
|
|
self.assertEqual(workerjob.handle, job.handle)
|
|
self.assertEqual(workerjob.arguments, b'testdata')
|
|
workerjob.sendWorkData(b'workdata')
|
|
workerjob.sendWorkComplete()
|
|
|
|
def test_worker_termination(self):
|
|
def getJob():
|
|
with testtools.ExpectedException(gear.InterruptedError):
|
|
self.worker.getJob()
|
|
self.worker.registerFunction('test')
|
|
jobthread = threading.Thread(target=getJob)
|
|
jobthread.daemon = True
|
|
jobthread.start()
|
|
self.worker.stopWaitingForJobs()
|
|
|
|
def test_text_job_name(self):
|
|
self.worker.registerFunction('test')
|
|
|
|
for jobcount in range(2):
|
|
job = gear.Job('test', b'testdata')
|
|
self.client.submitJob(job)
|
|
self.assertNotEqual(job.handle, None)
|
|
|
|
workerjob = self.worker.getJob()
|
|
self.assertEqual('test', workerjob.name)
|
|
|
|
|
|
class TestFunctionalText(tests.BaseTestCase):
|
|
def setUp(self):
|
|
super(TestFunctionalText, self).setUp()
|
|
self.server = gear.Server(0)
|
|
self.client = gear.Client('client')
|
|
self.worker = gear.TextWorker('worker')
|
|
self.client.addServer('127.0.0.1', self.server.port)
|
|
self.worker.addServer('127.0.0.1', self.server.port)
|
|
self.client.waitForServer()
|
|
self.worker.waitForServer()
|
|
|
|
def test_text_job(self):
|
|
self.worker.registerFunction('test')
|
|
|
|
for jobcount in range(2):
|
|
job = gear.TextJob('test', 'testdata')
|
|
self.client.submitJob(job)
|
|
self.assertNotEqual(job.handle, None)
|
|
|
|
workerjob = self.worker.getJob()
|
|
self.assertEqual(workerjob.handle, job.handle)
|
|
self.assertEqual(workerjob.arguments, 'testdata')
|
|
workerjob.sendWorkData('workdata')
|
|
workerjob.sendWorkComplete()
|
|
|
|
for count in iterate_timeout(30, "job completion"):
|
|
if job.complete:
|
|
break
|
|
self.assertTrue(job.complete)
|
|
self.assertEqual(job.data, ['workdata'])
|
|
|
|
def test_text_job_unique(self):
|
|
self.worker.registerFunction('test')
|
|
|
|
for jobcount in range(2):
|
|
jobunique = uuid.uuid4().hex
|
|
job = gear.TextJob('test', 'testdata', unique=jobunique)
|
|
self.client.submitJob(job)
|
|
self.assertNotEqual(job.handle, None)
|
|
|
|
workerjob = self.worker.getJob()
|
|
self.assertEqual(workerjob.handle, job.handle)
|
|
self.assertEqual(workerjob.arguments, 'testdata')
|
|
workerjob.sendWorkData('workdata')
|
|
workerjob.sendWorkComplete()
|
|
|
|
for count in iterate_timeout(30, "job completion"):
|
|
if job.complete:
|
|
break
|
|
self.assertTrue(job.complete)
|
|
self.assertEqual(job.data, ['workdata'])
|
|
self.assertEqual(job.unique, jobunique)
|
|
self.assertEqual(workerjob.unique, jobunique)
|
|
|
|
def test_text_job_exception(self):
|
|
self.worker.registerFunction('test')
|
|
|
|
for jobcount in range(2):
|
|
job = gear.TextJob('test', 'testdata')
|
|
self.client.submitJob(job)
|
|
self.assertNotEqual(job.handle, None)
|
|
|
|
workerjob = self.worker.getJob()
|
|
self.assertEqual(workerjob.handle, job.handle)
|
|
self.assertEqual(workerjob.arguments, 'testdata')
|
|
workerjob.sendWorkException('work failed')
|
|
|
|
for count in iterate_timeout(30, "job completion"):
|
|
if job.complete:
|
|
break
|
|
self.assertTrue(job.complete)
|
|
self.assertEqual(job.exception, 'work failed')
|
|
|
|
def test_grab_job_after_register(self):
|
|
jobunique = uuid.uuid4().hex
|
|
job = gear.TextJob('test', 'testdata', unique=jobunique)
|
|
self.client.submitJob(job)
|
|
self.assertNotEqual(job.handle, None)
|
|
|
|
def getJob():
|
|
workerjob = self.worker.getJob()
|
|
workerjob.sendWorkComplete()
|
|
|
|
jobthread = threading.Thread(target=getJob)
|
|
jobthread.daemon = True
|
|
jobthread.start()
|
|
|
|
for count in iterate_timeout(30, "worker sleeping"):
|
|
if self.worker.active_connections[0].state == 'SLEEP':
|
|
break
|
|
self.assertEqual(1, len(self.server.normal_queue))
|
|
self.assertFalse(job.complete)
|
|
|
|
# When we register the function, the worker should send a
|
|
# grab_job packet and pick up the job and it should complete.
|
|
self.worker.registerFunction('test')
|
|
|
|
for count in iterate_timeout(30, "job completion"):
|
|
if job.complete:
|
|
break
|
|
|
|
self.assertEqual(0, len(self.server.normal_queue))
|
|
|
|
|
|
def load_tests(loader, in_tests, pattern):
|
|
return testscenarios.load_tests_apply_scenarios(loader, in_tests, pattern)
|