diff --git a/mistral_lib/actions/context.py b/mistral_lib/actions/context.py index c859c87..974919a 100644 --- a/mistral_lib/actions/context.py +++ b/mistral_lib/actions/context.py @@ -12,8 +12,10 @@ import warnings +from mistral_lib import serialization -class ActionContext(object): + +class ActionContext(serialization.MistralSerializable): def __init__(self, security_ctx=None, execution_ctx=None): self.security = security_ctx @@ -86,3 +88,19 @@ class ExecutionContext(object): def task_id(self): self._deprecate_task_id_warning() return self.task_execution_id + + +class ActionContextSerializer(serialization.DictBasedSerializer): + def serialize_to_dict(self, entity): + return { + 'security': vars(entity.security), + 'execution': vars(entity.execution), + } + + def deserialize_from_dict(self, entity_dict): + return ActionContext( + security_ctx=SecurityContext(**entity_dict['security']), + execution_ctx=ExecutionContext(**entity_dict['execution']) + ) + +serialization.register_serializer(ActionContext, ActionContextSerializer()) diff --git a/mistral_lib/tests/actions/test_context.py b/mistral_lib/tests/actions/test_context.py index 8eb97f6..95548c3 100644 --- a/mistral_lib/tests/actions/test_context.py +++ b/mistral_lib/tests/actions/test_context.py @@ -69,3 +69,27 @@ class TestActionsBase(tests_base.TestCase): old = getattr(ctx, deprecated) new = getattr(ctx.security, deprecated) self.assertEqual(old, new) + + +class TestActionContextSerializer(tests_base.TestCase): + + def test_serialization(self): + + ctx = _fake_context() + serialiser = context.ActionContextSerializer() + dict_ctx = serialiser.serialize_to_dict(ctx) + + self.assertEqual(dict_ctx['security'], vars(ctx.security)) + self.assertEqual(dict_ctx['execution'], vars(ctx.execution)) + + def test_deserialization(self): + + ctx = _fake_context() + serialiser = context.ActionContextSerializer() + dict_ctx = serialiser.serialize_to_dict(ctx) + ctx_2 = serialiser.deserialize_from_dict(dict_ctx) + + self.assertEqual(ctx.security.auth_uri, ctx_2.security.auth_uri) + self.assertEqual( + ctx.execution.workflow_name, + ctx_2.execution.workflow_name)