diff --git a/monascastatsd/client.py b/monascastatsd/client.py index aeb03f3..dc5e611 100644 --- a/monascastatsd/client.py +++ b/monascastatsd/client.py @@ -23,7 +23,8 @@ from monascastatsd.timer import Timer class Client(object): - def __init__(self, name=None, connection=None, max_buffer_size=50, dimensions=None): + def __init__(self, name=None, host='localhost', port=8125, + connection=None, max_buffer_size=50, dimensions=None): """Initialize a Client object. >>> monascastatsd = MonascaStatsd() @@ -35,10 +36,11 @@ class Client(object): :param max_buffer_size: Maximum number of metric to buffer before sending to the server if sending metrics in batch """ + if connection is None: - self.connection = Connection(host='localhost', - port=8125, - max_buffer_size=50) + self.connection = Connection(host=host, + port=port, + max_buffer_size=max_buffer_size) else: self.connection = connection self._dimensions = dimensions diff --git a/test-requirements.txt b/test-requirements.txt index d170128..82a2134 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -4,3 +4,4 @@ hacking<0.12,>=0.11.0 # Apache-2.0 nose # LGPL nosexcover # BSD +mock>=2.0 # BSD diff --git a/tests/test_monascastatsd.py b/tests/test_monascastatsd.py index d5abb0b..cd28152 100644 --- a/tests/test_monascastatsd.py +++ b/tests/test_monascastatsd.py @@ -23,6 +23,8 @@ import unittest import monascastatsd as mstatsd +import mock + class FakeSocket(object): @@ -60,6 +62,20 @@ class TestMonascaStatsd(unittest.TestCase): def recv(self, metric_obj): return metric_obj._connection.socket.recv() + @mock.patch('monascastatsd.client.Connection') + def test_client_set_host_port(self, connection_mock): + mstatsd.Client(host='foo.bar', port=5213) + connection_mock.assert_called_once_with(host='foo.bar', + port=5213, + max_buffer_size=50) + + @mock.patch('monascastatsd.client.Connection') + def test_client_default_host_port(self, connection_mock): + mstatsd.Client() + connection_mock.assert_called_once_with(host='localhost', + port=8125, + max_buffer_size=50) + def test_counter(self): counter = self.client.get_counter(name='page.views')