diff --git a/sqlalchemy_utils/types/timezone.py b/sqlalchemy_utils/types/timezone.py index 239712c..b474066 100644 --- a/sqlalchemy_utils/types/timezone.py +++ b/sqlalchemy_utils/types/timezone.py @@ -52,9 +52,10 @@ class TimezoneType(types.TypeDecorator, ScalarCoercible): elif backend == 'pytz': try: - from pytz import tzfile, timezone + from pytz import timezone + from pytz.tzinfo import BaseTzInfo - self.python_type = tzfile.DstTzInfo + self.python_type = BaseTzInfo self._to = timezone self._from = six.text_type @@ -71,13 +72,11 @@ class TimezoneType(types.TypeDecorator, ScalarCoercible): ) def _coerce(self, value): - if value and not isinstance(value, self.python_type): + if value is not None and not isinstance(value, self.python_type): obj = self._to(value) if obj is None: raise ValueError("unknown time zone '%s'" % value) - return obj - return value def process_bind_param(self, value, dialect): diff --git a/tests/types/test_timezone.py b/tests/types/test_timezone.py index 502288f..6124688 100644 --- a/tests/types/test_timezone.py +++ b/tests/types/test_timezone.py @@ -1,7 +1,9 @@ import pytest +import pytz import sqlalchemy as sa +from dateutil.zoneinfo import getzoneinfofile_stream, tzfile, ZoneInfoFile -from sqlalchemy_utils.types import timezone +from sqlalchemy_utils.types import timezone, TimezoneType @pytest.fixture @@ -46,3 +48,48 @@ class TestTimezoneType(object): assert visitor_dateutil is not None assert visitor_pytz is not None + + +TIMEZONE_BACKENDS = ['dateutil', 'pytz'] + + +def test_can_coerce_pytz_DstTzInfo(): + tzcol = TimezoneType(backend='pytz') + tz = pytz.timezone('America/New_York') + assert isinstance(tz, pytz.tzfile.DstTzInfo) + assert tzcol._coerce(tz) is tz + + +def test_can_coerce_pytz_StaticTzInfo(): + tzcol = TimezoneType(backend='pytz') + tz = pytz.timezone('Pacific/Truk') + assert isinstance(tz, pytz.tzfile.StaticTzInfo) + assert tzcol._coerce(tz) is tz + + +@pytest.mark.parametrize('zone', pytz.all_timezones) +def test_can_coerce_string_for_pytz_zone(zone): + tzcol = TimezoneType(backend='pytz') + assert tzcol._coerce(zone).zone == zone + + +@pytest.mark.parametrize( + 'zone', ZoneInfoFile(getzoneinfofile_stream()).zones.keys()) +def test_can_coerce_string_for_dateutil_zone(zone): + tzcol = TimezoneType(backend='dateutil') + assert isinstance(tzcol._coerce(zone), tzfile) + + +@pytest.mark.parametrize('backend', TIMEZONE_BACKENDS) +def test_can_coerce_and_raise_UnknownTimeZoneError_or_ValueError(backend): + tzcol = TimezoneType(backend=backend) + with pytest.raises((ValueError, pytz.exceptions.UnknownTimeZoneError)): + tzcol._coerce('SolarSystem/Mars') + with pytest.raises((ValueError, pytz.exceptions.UnknownTimeZoneError)): + tzcol._coerce('') + + +@pytest.mark.parametrize('backend', TIMEZONE_BACKENDS) +def test_can_coerce_None(backend): + tzcol = TimezoneType(backend=backend) + assert tzcol._coerce(None) is None