Correct `find_migrate_repo` usage
The `find_migrate_repo` function was never used with an argument although the function interface presupposes an optional argument named `package`. $ grep -r "find_migrate_repo" keystone keystone/common/sql/migration.py: repo_path = find_migrate_repo() keystone/common/sql/migration.py: repo_path = find_migrate_repo() keystone/common/sql/migration.py: repo_path = find_migrate_repo() The wrong function usage makes the function redundant (only one string is needed there, all others are never executed) and forces redundant code lines in other modules. Database migration functions must take an extension module object and handle a migration repository path with `find_migrate_repo` function. ::db_sync: [version, [package]] ::db_version: [package] ::db_version_control: [version, [package]] `find_migrate_repo` raises an error when the migrate repository path doesn't exist. The error should be handled to prevent further unexpected behaviour. Closes-Bug: #1273336 Change-Id: Id62a1859a7e72139d25dfe91d29805660b4b7394
This commit is contained in:
parent
42a4b42f67
commit
7dcd6f03b8
|
@ -27,6 +27,7 @@ from keystone.common.sql import migration
|
|||
from keystone.common import utils
|
||||
from keystone import config
|
||||
from keystone import contrib
|
||||
from keystone import exception
|
||||
from keystone.openstack.common import importutils
|
||||
from keystone import token
|
||||
|
||||
|
@ -71,21 +72,25 @@ class DbSync(BaseApp):
|
|||
if not extension:
|
||||
migration.db_sync(version=version)
|
||||
else:
|
||||
package_name = "%s.%s.migrate_repo" % (contrib.__name__, extension)
|
||||
try:
|
||||
package_name = '.'.join((contrib.__name__, extension))
|
||||
package = importutils.import_module(package_name)
|
||||
repo_path = os.path.abspath(os.path.dirname(package.__file__))
|
||||
except ImportError:
|
||||
print(_("This extension does not provide migrations."))
|
||||
exit(0)
|
||||
raise ImportError(_("%s extension does not exist.")
|
||||
% package_name)
|
||||
try:
|
||||
try:
|
||||
migration.db_version_control(package=package)
|
||||
# Register the repo with the version control API
|
||||
# If it already knows about the repo, it will throw
|
||||
# an exception that we can safely ignore
|
||||
migration.db_version_control(version=None, repo_path=repo_path)
|
||||
except exceptions.DatabaseAlreadyControlledError:
|
||||
pass
|
||||
migration.db_sync(version=version, repo_path=repo_path)
|
||||
except exceptions.DatabaseAlreadyControlledError:
|
||||
pass
|
||||
|
||||
migration.db_sync(version=version, package=package)
|
||||
except exception.MigrationNotProvided as e:
|
||||
print(e)
|
||||
exit(0)
|
||||
|
||||
|
||||
class DbVersion(BaseApp):
|
||||
|
@ -106,14 +111,16 @@ class DbVersion(BaseApp):
|
|||
extension = CONF.command.extension
|
||||
if extension:
|
||||
try:
|
||||
package_name = ("%s.%s.migrate_repo" %
|
||||
(contrib.__name__, extension))
|
||||
package_name = '.'.join((contrib.__name__, extension))
|
||||
package = importutils.import_module(package_name)
|
||||
repo_path = os.path.abspath(os.path.dirname(package.__file__))
|
||||
print(migration.db_version(repo_path))
|
||||
except ImportError:
|
||||
print(_("This extension does not provide migrations."))
|
||||
exit(1)
|
||||
raise ImportError(_("%s extension does not exist.")
|
||||
% package_name)
|
||||
try:
|
||||
print(migration.db_version(package))
|
||||
except exception.MigrationNotProvided as e:
|
||||
print(e)
|
||||
exit(0)
|
||||
else:
|
||||
print(migration.db_version())
|
||||
|
||||
|
|
|
@ -48,30 +48,27 @@ def migrate_repository(version, current_version, repo_path):
|
|||
return result
|
||||
|
||||
|
||||
def db_sync(version=None, repo_path=None):
|
||||
def db_sync(version=None, package=None):
|
||||
if version is not None:
|
||||
try:
|
||||
version = int(version)
|
||||
except ValueError:
|
||||
raise Exception(_('version should be an integer'))
|
||||
if repo_path is None:
|
||||
repo_path = find_migrate_repo()
|
||||
current_version = db_version(repo_path=repo_path)
|
||||
repo_path = find_migrate_repo(package=package)
|
||||
current_version = db_version(package=package)
|
||||
return migrate_repository(version, current_version, repo_path)
|
||||
|
||||
|
||||
def db_version(repo_path=None):
|
||||
if repo_path is None:
|
||||
repo_path = find_migrate_repo()
|
||||
def db_version(package=None):
|
||||
repo_path = find_migrate_repo(package=package)
|
||||
try:
|
||||
return versioning_api.db_version(CONF.database.connection, repo_path)
|
||||
except versioning_exceptions.DatabaseNotControlledError:
|
||||
return db_version_control(0)
|
||||
return db_version_control(version=0, package=package)
|
||||
|
||||
|
||||
def db_version_control(version=None, repo_path=None):
|
||||
if repo_path is None:
|
||||
repo_path = find_migrate_repo()
|
||||
def db_version_control(version=None, package=None):
|
||||
repo_path = find_migrate_repo(package=package)
|
||||
versioning_api.version_control(CONF.database.connection, repo_path,
|
||||
version)
|
||||
return version
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from keystone.common.sql import migration
|
||||
|
@ -32,13 +31,10 @@ class TestExtensionCase(test_v3.RestfulTestCase):
|
|||
self.conf_files.append(
|
||||
tests.dirs.tests('test_associate_project_endpoint_extension.conf'))
|
||||
super(TestExtensionCase, self).setup_database()
|
||||
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
|
||||
self.EXTENSION_NAME)
|
||||
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
|
||||
package = importutils.import_module(package_name)
|
||||
self.repo_path = os.path.abspath(
|
||||
os.path.dirname(package.__file__))
|
||||
migration.db_version_control(version=None, repo_path=self.repo_path)
|
||||
migration.db_sync(version=None, repo_path=self.repo_path)
|
||||
migration.db_version_control(package=package)
|
||||
migration.db_sync(package=package)
|
||||
|
||||
def setUp(self):
|
||||
super(TestExtensionCase, self).setUp()
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
|
||||
|
@ -40,12 +39,10 @@ class FederationTests(test_v3.RestfulTestCase):
|
|||
|
||||
def setup_database(self):
|
||||
super(FederationTests, self).setup_database()
|
||||
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
|
||||
self.EXTENSION_NAME)
|
||||
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
|
||||
package = importutils.import_module(package_name)
|
||||
self.repo_path = os.path.abspath(os.path.dirname(package.__file__))
|
||||
migration.db_version_control(version=None, repo_path=self.repo_path)
|
||||
migration.db_sync(version=None, repo_path=self.repo_path)
|
||||
migration.db_version_control(package=package)
|
||||
migration.db_sync(package=package)
|
||||
|
||||
|
||||
class FederatedIdentityProviderTests(FederationTests):
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# under the License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from six.moves import urllib
|
||||
|
@ -38,12 +37,10 @@ class OAuth1Tests(test_v3.RestfulTestCase):
|
|||
|
||||
def setup_database(self):
|
||||
super(OAuth1Tests, self).setup_database()
|
||||
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
|
||||
self.EXTENSION_NAME)
|
||||
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
|
||||
package = importutils.import_module(package_name)
|
||||
self.repo_path = os.path.abspath(os.path.dirname(package.__file__))
|
||||
migration.db_version_control(version=None, repo_path=self.repo_path)
|
||||
migration.db_sync(version=None, repo_path=self.repo_path)
|
||||
migration.db_version_control(package=package)
|
||||
migration.db_sync(package=package)
|
||||
|
||||
def setUp(self):
|
||||
super(OAuth1Tests, self).setUp()
|
||||
|
|
Loading…
Reference in New Issue