Correct `find_migrate_repo` usage

The `find_migrate_repo` function was never used with an argument although the
function interface presupposes an optional argument named
`package`.

$ grep -r "find_migrate_repo" keystone

    keystone/common/sql/migration.py: repo_path = find_migrate_repo()
    keystone/common/sql/migration.py: repo_path = find_migrate_repo()
    keystone/common/sql/migration.py: repo_path = find_migrate_repo()

The wrong function usage makes the function redundant (only one
string is needed there, all others are never executed) and forces redundant
code lines in other modules.
Database migration functions must take an extension module object and handle
a migration repository path with `find_migrate_repo` function.

::db_sync: [version, [package]]
::db_version: [package]
::db_version_control: [version, [package]]

`find_migrate_repo` raises an error when the migrate repository path doesn't
exist. The error should be handled to prevent further unexpected behaviour.

Closes-Bug: #1273336
Change-Id: Id62a1859a7e72139d25dfe91d29805660b4b7394
This commit is contained in:
Ilya Pekelny 2014-01-27 19:40:54 +02:00
parent 42a4b42f67
commit 7dcd6f03b8
5 changed files with 38 additions and 44 deletions

View File

@ -27,6 +27,7 @@ from keystone.common.sql import migration
from keystone.common import utils
from keystone import config
from keystone import contrib
from keystone import exception
from keystone.openstack.common import importutils
from keystone import token
@ -71,21 +72,25 @@ class DbSync(BaseApp):
if not extension:
migration.db_sync(version=version)
else:
package_name = "%s.%s.migrate_repo" % (contrib.__name__, extension)
try:
package_name = '.'.join((contrib.__name__, extension))
package = importutils.import_module(package_name)
repo_path = os.path.abspath(os.path.dirname(package.__file__))
except ImportError:
print(_("This extension does not provide migrations."))
exit(0)
raise ImportError(_("%s extension does not exist.")
% package_name)
try:
try:
migration.db_version_control(package=package)
# Register the repo with the version control API
# If it already knows about the repo, it will throw
# an exception that we can safely ignore
migration.db_version_control(version=None, repo_path=repo_path)
except exceptions.DatabaseAlreadyControlledError:
pass
migration.db_sync(version=version, repo_path=repo_path)
except exceptions.DatabaseAlreadyControlledError:
pass
migration.db_sync(version=version, package=package)
except exception.MigrationNotProvided as e:
print(e)
exit(0)
class DbVersion(BaseApp):
@ -106,14 +111,16 @@ class DbVersion(BaseApp):
extension = CONF.command.extension
if extension:
try:
package_name = ("%s.%s.migrate_repo" %
(contrib.__name__, extension))
package_name = '.'.join((contrib.__name__, extension))
package = importutils.import_module(package_name)
repo_path = os.path.abspath(os.path.dirname(package.__file__))
print(migration.db_version(repo_path))
except ImportError:
print(_("This extension does not provide migrations."))
exit(1)
raise ImportError(_("%s extension does not exist.")
% package_name)
try:
print(migration.db_version(package))
except exception.MigrationNotProvided as e:
print(e)
exit(0)
else:
print(migration.db_version())

View File

@ -48,30 +48,27 @@ def migrate_repository(version, current_version, repo_path):
return result
def db_sync(version=None, repo_path=None):
def db_sync(version=None, package=None):
if version is not None:
try:
version = int(version)
except ValueError:
raise Exception(_('version should be an integer'))
if repo_path is None:
repo_path = find_migrate_repo()
current_version = db_version(repo_path=repo_path)
repo_path = find_migrate_repo(package=package)
current_version = db_version(package=package)
return migrate_repository(version, current_version, repo_path)
def db_version(repo_path=None):
if repo_path is None:
repo_path = find_migrate_repo()
def db_version(package=None):
repo_path = find_migrate_repo(package=package)
try:
return versioning_api.db_version(CONF.database.connection, repo_path)
except versioning_exceptions.DatabaseNotControlledError:
return db_version_control(0)
return db_version_control(version=0, package=package)
def db_version_control(version=None, repo_path=None):
if repo_path is None:
repo_path = find_migrate_repo()
def db_version_control(version=None, package=None):
repo_path = find_migrate_repo(package=package)
versioning_api.version_control(CONF.database.connection, repo_path,
version)
return version

View File

@ -12,7 +12,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import os
import uuid
from keystone.common.sql import migration
@ -32,13 +31,10 @@ class TestExtensionCase(test_v3.RestfulTestCase):
self.conf_files.append(
tests.dirs.tests('test_associate_project_endpoint_extension.conf'))
super(TestExtensionCase, self).setup_database()
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
self.EXTENSION_NAME)
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
package = importutils.import_module(package_name)
self.repo_path = os.path.abspath(
os.path.dirname(package.__file__))
migration.db_version_control(version=None, repo_path=self.repo_path)
migration.db_sync(version=None, repo_path=self.repo_path)
migration.db_version_control(package=package)
migration.db_sync(package=package)
def setUp(self):
super(TestExtensionCase, self).setUp()

View File

@ -10,7 +10,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import os
import random
import uuid
@ -40,12 +39,10 @@ class FederationTests(test_v3.RestfulTestCase):
def setup_database(self):
super(FederationTests, self).setup_database()
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
self.EXTENSION_NAME)
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
package = importutils.import_module(package_name)
self.repo_path = os.path.abspath(os.path.dirname(package.__file__))
migration.db_version_control(version=None, repo_path=self.repo_path)
migration.db_sync(version=None, repo_path=self.repo_path)
migration.db_version_control(package=package)
migration.db_sync(package=package)
class FederatedIdentityProviderTests(FederationTests):

View File

@ -13,7 +13,6 @@
# under the License.
import copy
import os
import uuid
from six.moves import urllib
@ -38,12 +37,10 @@ class OAuth1Tests(test_v3.RestfulTestCase):
def setup_database(self):
super(OAuth1Tests, self).setup_database()
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
self.EXTENSION_NAME)
package_name = '.'.join((contrib.__name__, self.EXTENSION_NAME))
package = importutils.import_module(package_name)
self.repo_path = os.path.abspath(os.path.dirname(package.__file__))
migration.db_version_control(version=None, repo_path=self.repo_path)
migration.db_sync(version=None, repo_path=self.repo_path)
migration.db_version_control(package=package)
migration.db_sync(package=package)
def setUp(self):
super(OAuth1Tests, self).setUp()