Make the test change database serializable

This adds methods to save and restore the test change database.
This can be used in a future change to perform an upgrade test
where we save and restore the fake gerrit (and other drivers)
state across an upgrade.

Some minor changes are made to FakeGerritChange to avoid storing
non-pickleable values.  Similarly, the Gerrit reporter was using
ZuulConfigKey objects as review categories.  That's fine since
they behave like strings everywhere we use them, but they aren't
pickleable, and we don't need the extra line context information
after loading the config, so we discard it and convert them to
strings.

Change-Id: Ifa404203e414932f50811291ad56a661f2875af0
This commit is contained in:
James E. Blair 2024-03-27 15:45:17 -07:00
parent 629f48e291
commit ea933f6b3f
9 changed files with 85 additions and 63 deletions

View File

@ -25,11 +25,12 @@ import itertools
import json
import logging
import os
import pickle
import random
import re
from collections import defaultdict, namedtuple
from queue import Queue
from typing import Callable, Optional, Generator, List, Dict
from typing import Generator, List
from unittest.case import skipIf
import zlib
@ -68,7 +69,7 @@ from kazoo.exceptions import NoNodeError
from zuul import model
from zuul.model import (
BuildRequest, Change, MergeRequest, WebInfo, HoldRequest
BuildRequest, MergeRequest, WebInfo, HoldRequest
)
from zuul.driver.zuul import ZuulDriver
@ -215,15 +216,35 @@ def registerProjects(source_name, client, config):
client.addProjectByName(project)
class FakeChangeDB:
def __init__(self):
# A dictionary of server -> dict as below
self.servers = {}
def getServerChangeDB(self, server):
"""Returns a dictionary for the specified server; key -> Change. The
key is driver dependent, but typically the change/PR/MR id.
"""
return self.servers.setdefault(server, {})
def save(self, path):
with open(path, 'wb') as f:
pickle.dump(self.servers, f, pickle.HIGHEST_PROTOCOL)
def load(self, path):
with open(path, 'rb') as f:
self.servers = pickle.load(f)
class StatException(Exception):
# Used by assertReportedStat
pass
class GerritDriverMock(GerritDriver):
def __init__(self, registry, changes: Dict[str, Dict[str, Change]],
upstream_root: str, additional_event_queues, poller_events,
add_cleanup: Callable[[Callable[[], None]], None]):
def __init__(self, registry, changes, upstream_root,
additional_event_queues, poller_events, add_cleanup):
super(GerritDriverMock, self).__init__()
self.registry = registry
self.changes = changes
@ -233,7 +254,8 @@ class GerritDriverMock(GerritDriver):
self.add_cleanup = add_cleanup
def getConnection(self, name, config):
db = self.changes.setdefault(config['server'], {})
server = config['server']
db = self.changes.getServerChangeDB(server)
poll_event = self.poller_events.setdefault(name, threading.Event())
ref_event = self.poller_events.setdefault(name + '-ref',
threading.Event())
@ -251,10 +273,8 @@ class GerritDriverMock(GerritDriver):
class GithubDriverMock(GithubDriver):
def __init__(self, registry, changes: Dict[str, Dict[str, Change]],
config: ConfigParser, upstream_root: str,
additional_event_queues,
git_url_with_auth: bool):
def __init__(self, registry, changes, config, upstream_root,
additional_event_queues, git_url_with_auth):
super(GithubDriverMock, self).__init__()
self.registry = registry
self.changes = changes
@ -265,7 +285,7 @@ class GithubDriverMock(GithubDriver):
def getConnection(self, name, config):
server = config.get('server', 'github.com')
db = self.changes.setdefault(server, {})
db = self.changes.getServerChangeDB(server)
connection = tests.fakegithub.FakeGithubConnection(
self, name, config,
changes_db=db,
@ -278,8 +298,8 @@ class GithubDriverMock(GithubDriver):
class PagureDriverMock(PagureDriver):
def __init__(self, registry, changes: Dict[str, Dict[str, Change]],
upstream_root: str, additional_event_queues):
def __init__(self, registry, changes, upstream_root,
additional_event_queues):
super(PagureDriverMock, self).__init__()
self.registry = registry
self.changes = changes
@ -288,7 +308,7 @@ class PagureDriverMock(PagureDriver):
def getConnection(self, name, config):
server = config.get('server', 'pagure.io')
db = self.changes.setdefault(server, {})
db = self.changes.getServerChangeDB(server)
connection = tests.fakepagure.FakePagureConnection(
self, name, config,
changes_db=db,
@ -298,8 +318,7 @@ class PagureDriverMock(PagureDriver):
class GitlabDriverMock(GitlabDriver):
def __init__(self, registry, changes: Dict[str, Dict[str, Change]],
config: ConfigParser, upstream_root: str,
def __init__(self, registry, changes, config, upstream_root,
additional_event_queues):
super(GitlabDriverMock, self).__init__()
self.registry = registry
@ -310,7 +329,7 @@ class GitlabDriverMock(GitlabDriver):
def getConnection(self, name, config):
server = config.get('server', 'gitlab.com')
db = self.changes.setdefault(server, {})
db = self.changes.getServerChangeDB(server)
connection = tests.fakegitlab.FakeGitlabConnection(
self, name, config,
changes_db=db,
@ -565,7 +584,7 @@ class FakeBuild(object):
"""
for change in changes:
hostname = change.source.canonical_hostname
hostname = change.source_hostname
path = os.path.join(self.jobdir.src_root, hostname, change.project)
try:
repo = git.Repo(path)
@ -1431,12 +1450,9 @@ class WebProxyFixture(fixtures.Fixture):
class ZuulWebFixture(fixtures.Fixture):
def __init__(self,
changes: Dict[str, Dict[str, Change]], config: ConfigParser,
additional_event_queues, upstream_root: str,
poller_events, git_url_with_auth: bool,
add_cleanup: Callable[[Callable[[], None]], None],
test_root: str, info: Optional[WebInfo] = None):
def __init__(self, changes, config, additional_event_queues,
upstream_root, poller_events, git_url_with_auth,
add_cleanup, test_root, info=None):
super(ZuulWebFixture, self).__init__()
self.config = config
self.connections = TestConnectionRegistry(
@ -2123,7 +2139,7 @@ class ZuulTestCase(BaseTestCase):
gerritsource.GerritSource.replication_retry_interval = 0.5
gerritconnection.GerritEventConnector.delay = 0.0
self.changes: Dict[str, Dict[str, Change]] = {}
self.changes = FakeChangeDB()
self.additional_event_queues = []
self.zk_client = ZooKeeperClient.fromConfig(self.config)
@ -3461,6 +3477,14 @@ class ZuulTestCase(BaseTestCase):
request.node_expiration = node_hold_expiration
self.sched_zk_nodepool.storeHoldRequest(request)
def saveChangeDB(self):
path = os.path.join(self.test_root, "changes.data")
self.changes.save(path)
def loadChangeDB(self):
path = os.path.join(self.test_root, "changes.data")
self.changes.load(path)
class AnsibleZuulTestCase(ZuulTestCase):
"""ZuulTestCase but with an actual ansible executor running"""

View File

@ -50,8 +50,8 @@ class FakeGerritChange(object):
status='NEW', upstream_root=None, files={},
parent=None, merge_parents=None, merge_files=None,
topic=None, empty=False):
self.gerrit = gerrit
self.source = gerrit
self.source_hostname = gerrit.canonical_hostname
self.gerrit_baseurl = gerrit.baseurl
self.reported = 0
self.queried = 0
self.patchsets = []
@ -87,7 +87,7 @@ class FakeGerritChange(object):
'subject': subject,
'submitRecords': [],
'hashtags': [],
'url': '%s/%s' % (self.gerrit.baseurl.rstrip('/'), number)}
'url': '%s/%s' % (self.gerrit_baseurl.rstrip('/'), number)}
if topic:
self.data['topic'] = topic

View File

@ -59,8 +59,8 @@ class FakeGithubPullRequest(object):
If the `files` argument is provided it must be a dictionary of
file names OR FakeFile instances -> content.
"""
self.github = github
self.source = github
self.source_hostname = github.canonical_hostname
self.github_server = github.server
self.number = number
self.project = project
self.branch = branch
@ -84,7 +84,8 @@ class FakeGithubPullRequest(object):
self.is_merged = False
self.merge_message = None
self.state = 'open'
self.url = 'https://%s/%s/pull/%s' % (github.server, project, number)
self.url = 'https://%s/%s/pull/%s' % (self.github_server,
project, number)
self.base_sha = base_sha
self.pr_ref = self._createPRRef(base_sha=base_sha)
self._addCommitToRepo(files=files)
@ -143,7 +144,7 @@ class FakeGithubPullRequest(object):
# A PR comment has an additional 'pull_request' key in the issue data
data['issue']['pull_request'] = {
'url': 'http://%s/api/v3/repos/%s/pull/%s' % (
self.github.server, self.project, self.number)
self.github_server, self.project, self.number)
}
return (name, data)
@ -950,29 +951,17 @@ class FakePull(object):
self._fake_pull_request.reviews.append(review)
return review
@property
def head(self):
client = FakeGithubClient(
data=self._fake_pull_request.github.github_data)
repo = client.repo_from_project(self._fake_pull_request.project)
return repo.commit(self._fake_pull_request.head_sha)
def commits(self):
# since we don't know all commits of a pr we just return here a list
# with the head_sha as the only commit
return [self.head]
def as_dict(self):
pr = self._fake_pull_request
connection = pr.github
server = pr.github_server
data = {
'number': pr.number,
'title': pr.subject,
'url': 'https://%s/api/v3/%s/pulls/%s' % (
connection.server, pr.project, pr.number
server, pr.project, pr.number
),
'html_url': 'https://%s/%s/pull/%s' % (
connection.server, pr.project, pr.number
server, pr.project, pr.number
),
'updated_at': pr.updated_at,
'base': {
@ -1144,7 +1133,7 @@ class FakeGithubSession(object):
project, pr_number = match.groups()
project = urllib.parse.unquote(project)
pr = self.client._data.pull_requests[int(pr_number)]
conn = pr.github
conn = self.client._data.fake_github_connection
# record that this got reported
self.client._data.reports.append(
@ -1217,12 +1206,13 @@ class FakeBranchProtectionRule:
self.require_codeowners_review = False
class FakeGithubData(object):
def __init__(self, pull_requests):
class FakeGithubData:
def __init__(self, pull_requests, fake_github_connection):
self.pull_requests = pull_requests
self.repos = {}
self.reports = []
self.fail_check_run_creation = False
self.fake_github_connection = fake_github_connection
def __repr__(self):
return ("pull_requests:%s repos:%s reports:%s "
@ -1503,7 +1493,7 @@ class FakeGithubConnection(githubconnection.GithubConnection):
self.merge_failure = False
self.merge_not_allowed_count = 0
self.github_data = FakeGithubData(changes_db)
self.github_data = FakeGithubData(changes_db, self)
self._github_client_manager.github_data = self.github_data
self.git_url_with_auth = git_url_with_auth

View File

@ -454,8 +454,8 @@ class FakeGitlabMergeRequest(object):
def __init__(self, gitlab, number, project, branch,
subject, upstream_root, files=[], description='',
base_sha=None):
self.gitlab = gitlab
self.source = gitlab
self.source_hostname = gitlab.canonical_hostname
self.gitlab_server = gitlab.server
self.number = number
self.project = project
self.branch = branch
@ -474,7 +474,7 @@ class FakeGitlabMergeRequest(object):
self.labels = []
self.notes = []
self.url = "https://%s/%s/merge_requests/%s" % (
self.gitlab.server, self.project, self.number)
self.gitlab_server, self.project, self.number)
self.base_sha = base_sha
self.approved = False
self.blocking_discussions_resolved = True
@ -619,10 +619,10 @@ class FakeGitlabMergeRequest(object):
self.mergeMergeRequest()
return self.getMergeRequestEvent(action='merge')
def getMergeRequestMergedPushEvent(self, added_files=None,
def getMergeRequestMergedPushEvent(self, gitlab, added_files=None,
removed_files=None,
modified_files=None):
return self.gitlab.getPushEvent(
return gitlab.getPushEvent(
project=self.project,
branch='refs/heads/%s' % self.branch,
before=random_sha1(),

View File

@ -38,8 +38,8 @@ class FakePagurePullRequest(object):
def __init__(self, pagure, number, project, branch,
subject, upstream_root, files={}, number_of_commits=1,
initial_comment=None):
self.pagure = pagure
self.source = pagure
self.source_hostname = pagure.canonical_hostname
self.pagure_server = pagure.server
self.number = number
self.project = project
self.branch = branch
@ -61,7 +61,7 @@ class FakePagurePullRequest(object):
self.upstream_root = upstream_root
self.cached_merge_status = 'MERGE'
self.url = "https://%s/%s/pull-request/%s" % (
self.pagure.server, self.project, self.number)
self.pagure_server, self.project, self.number)
self.is_merged = False
self.pr_ref = self._createPRRef()
self._addCommitInPR(files=files)

View File

@ -60,7 +60,7 @@ class TestGerritCRD(ZuulTestCase):
A.setDependsOn(AM1, 1)
AM1.setDependsOn(AM2, 1)
url = url_fmt.format(baseurl=B.gerrit.baseurl.rstrip('/'),
url = url_fmt.format(baseurl=B.gerrit_baseurl.rstrip('/'),
project=B.project,
change_no=B.number,
change_id=B.data['id'])
@ -281,7 +281,7 @@ class TestGerritCRD(ZuulTestCase):
# A Depends-On: B
url = url_fmt.format(baseurl=B.gerrit.baseurl.rstrip('/'),
url = url_fmt.format(baseurl=B.gerrit_baseurl.rstrip('/'),
project=B.project,
change_no=B.number)
A.data['commitMessage'] = '%s\n\nDepends-On: %s\n' % (

View File

@ -263,7 +263,8 @@ class TestGitlabDriver(ZuulTestCase):
state1 = self.scheds.first.sched.local_layout_state.get("tenant-one")
self.fake_gitlab.emitEvent(A.getMergeRequestMergedEvent())
self.fake_gitlab.emitEvent(A.getMergeRequestMergedPushEvent())
self.fake_gitlab.emitEvent(A.getMergeRequestMergedPushEvent(
self.fake_gitlab))
self.waitUntilSettled()
self.assertEqual(2, len(self.history))
self.assertHistory([{'name': 'project-post-job'},
@ -282,6 +283,7 @@ class TestGitlabDriver(ZuulTestCase):
state1 = self.scheds.first.sched.local_layout_state.get("tenant-one")
self.fake_gitlab.emitEvent(A.getMergeRequestMergedEvent())
self.fake_gitlab.emitEvent(A.getMergeRequestMergedPushEvent(
self.fake_gitlab,
modified_files=['.zuul.yaml']))
self.waitUntilSettled()
self.assertEqual(2, len(self.history))

View File

@ -534,6 +534,12 @@ class TestScheduler(ZuulTestCase):
for build in self.history:
self.assertTrue(build.parameters['zuul']['voting'])
# TODO: remove after we have tests that really exercise this;
# for now this verifies we can save and load the changedb (if
# popuplated by gerrit changes).
self.saveChangeDB()
self.loadChangeDB()
def test_zk_profile(self):
command_socket = self.scheds.first.sched.config.get(
'scheduler', 'command_socket')

View File

@ -35,7 +35,7 @@ class GerritReporter(BaseReporter):
self._create_comment = action.pop('comment', True)
self._submit = action.pop('submit', False)
self._checks_api = action.pop('checks-api', None)
self._labels = action
self._labels = {str(k): v for k, v in action.items()}
def __repr__(self):
return f"<GerritReporter: {self._action}>"