From 9449f7e665d5a605c62f299c41613e7c9bc437ff Mon Sep 17 00:00:00 2001 From: Robert Date: Fri, 22 Oct 2021 12:38:25 +0200 Subject: [PATCH 1/3] feat: Add user view set --- userausfall/rest_api/serializers.py | 12 ++++-- userausfall/rest_api/tests/__init__.py | 5 ++- userausfall/rest_api/tests/auth.py | 26 +----------- userausfall/rest_api/tests/trust_bridges.py | 9 ++++- userausfall/rest_api/tests/userausfall.py | 5 +++ userausfall/rest_api/tests/users.py | 44 +++++++++++++++++++++ userausfall/rest_api/urls.py | 5 ++- userausfall/rest_api/views.py | 9 ++++- 8 files changed, 80 insertions(+), 35 deletions(-) create mode 100644 userausfall/rest_api/tests/users.py diff --git a/userausfall/rest_api/serializers.py b/userausfall/rest_api/serializers.py index 94225d1..5392b10 100644 --- a/userausfall/rest_api/serializers.py +++ b/userausfall/rest_api/serializers.py @@ -7,17 +7,23 @@ from userausfall.views import get_authenticated_user class UserSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = User - fields = ["username"] + fields = ["id", "trust_bridge", "url"] + + +class TrustBridgeUserSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = User + fields = ["id", "url", "username"] # the UniqueValidator for username prevents successful validation for existing users extra_kwargs = {"username": {"validators": []}} class TrustBridgeSerializer(serializers.HyperlinkedModelSerializer): - trust_giver = UserSerializer() + trust_giver = TrustBridgeUserSerializer() class Meta: model = TrustBridge - fields = ["is_trusted", "trust_giver"] + fields = ["id", "is_trusted", "trust_giver", "url"] read_only_fields = ["is_trusted"] def create(self, validated_data): diff --git a/userausfall/rest_api/tests/__init__.py b/userausfall/rest_api/tests/__init__.py index f7cf7fe..abbee82 100644 --- a/userausfall/rest_api/tests/__init__.py +++ b/userausfall/rest_api/tests/__init__.py @@ -1,2 +1,3 @@ -from .auth import * # noqa: F401, F403 -from .trust_bridges import * # noqa: F401, F403 +from .auth import AuthenticationTestCase # noqa: F401, F403 +from .trust_bridges import TrustBridgeTestCase # noqa: F401, F403 +from .users import UserTestCase # noqa: F401, F403 diff --git a/userausfall/rest_api/tests/auth.py b/userausfall/rest_api/tests/auth.py index 233f23e..b1ca2a6 100644 --- a/userausfall/rest_api/tests/auth.py +++ b/userausfall/rest_api/tests/auth.py @@ -2,32 +2,10 @@ from rest_framework import status from userausfall.models import User from userausfall.rest_api.tests.userausfall import UserausfallAPITestCase +from userausfall.rest_api.tests.users import UserMixin -class UserMixin: - user: User - password: str - username: str - - def create_user(self): - self.username = f"test{User.objects.count()}" - self.password = "test12345" - self.user = User.objects.create_user(self.username, self.password) - return self.user - - def ensure_user_exists(self): - if not hasattr(self, "user"): - self.create_user() - - def authenticate_user(self): - self.ensure_user_exists() - if hasattr(self.client, "force_authentication"): - self.client.force_authenticate(user=self.user) - else: - self.client.force_login(user=self.user) - - -class UserTestCase(UserMixin, UserausfallAPITestCase): +class AuthenticationTestCase(UserMixin, UserausfallAPITestCase): base_url = "/api/auth" def test_signup(self): diff --git a/userausfall/rest_api/tests/trust_bridges.py b/userausfall/rest_api/tests/trust_bridges.py index 8bb4678..0f78599 100644 --- a/userausfall/rest_api/tests/trust_bridges.py +++ b/userausfall/rest_api/tests/trust_bridges.py @@ -2,7 +2,8 @@ from django.core import mail from rest_framework import status from userausfall.models import TrustBridge -from userausfall.rest_api.tests import UserausfallAPITestCase, UserMixin +from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase +from userausfall.rest_api.tests.users import UserMixin class TrustBridgeTestCase(UserMixin, UserausfallAPITestCase): @@ -37,13 +38,17 @@ class TrustBridgeTestCase(UserMixin, UserausfallAPITestCase): self.authenticate_user() response = self.client.get(self.get_api_url(url, pk=self.trust_bridge.pk)) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual( + self.assertDictEqual( response.data, { + "id": self.trust_bridge.id, "is_trusted": False, "trust_giver": { + "id": self.trust_giver.id, "username": self.trust_giver.username, + "url": get_url(response, "user", self.trust_giver), }, + "url": get_url(response, "trustbridge", self.trust_bridge), }, ) diff --git a/userausfall/rest_api/tests/userausfall.py b/userausfall/rest_api/tests/userausfall.py index 30fd6ea..32e335c 100644 --- a/userausfall/rest_api/tests/userausfall.py +++ b/userausfall/rest_api/tests/userausfall.py @@ -1,6 +1,11 @@ +from rest_framework.reverse import reverse from rest_framework.test import APITestCase +def get_url(response, basename, instance): + return reverse(f"{basename}-detail", [instance.pk], request=response.wsgi_request) + + class UserausfallAPITestCase(APITestCase): base_url = "/api" diff --git a/userausfall/rest_api/tests/users.py b/userausfall/rest_api/tests/users.py new file mode 100644 index 0000000..6df5a5e --- /dev/null +++ b/userausfall/rest_api/tests/users.py @@ -0,0 +1,44 @@ +from rest_framework import status + +from userausfall.models import User +from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase + + +class UserMixin: + user: User + password: str + username: str + + def create_user(self): + self.username = f"test{User.objects.count()}" + self.password = "test12345" + self.user = User.objects.create_user(self.username, self.password) + return self.user + + def ensure_user_exists(self): + if not hasattr(self, "user"): + self.create_user() + + def authenticate_user(self): + self.ensure_user_exists() + if hasattr(self.client, "force_authentication"): + self.client.force_authenticate(user=self.user) + else: + self.client.force_login(user=self.user) + + +class UserTestCase(UserMixin, UserausfallAPITestCase): + def test_retrieve_user(self): + """Retrieve the details of the current user.""" + url = "/users/{pk}/" + self.authenticate_user() + response = self.client.get(self.get_api_url(url, pk=self.user.pk)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual( + response.data, + { + "id": self.user.id, + "trust_bridge": None, + "url": get_url(response, "user", self.user), + }, + ) diff --git a/userausfall/rest_api/urls.py b/userausfall/rest_api/urls.py index 7545a24..5128234 100644 --- a/userausfall/rest_api/urls.py +++ b/userausfall/rest_api/urls.py @@ -2,10 +2,11 @@ from django.urls import include, path from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView from rest_framework import routers -from userausfall.rest_api.views import TrustBridgeViewSet +from userausfall.rest_api.views import TrustBridgeViewSet, UserViewSet router = routers.SimpleRouter() -router.register(r"trust-bridges", TrustBridgeViewSet, "trust-bridge") +router.register(r"trust-bridges", TrustBridgeViewSet, "trustbridge") +router.register(r"users", UserViewSet, "user") urlpatterns = [ path("", include(router.urls)), diff --git a/userausfall/rest_api/views.py b/userausfall/rest_api/views.py index f078fa6..7ddb5c4 100644 --- a/userausfall/rest_api/views.py +++ b/userausfall/rest_api/views.py @@ -4,7 +4,7 @@ from rest_framework.decorators import action from rest_framework.response import Response from userausfall.models import MissingUserAttribute, PasswordMismatch, TrustBridge, User -from userausfall.rest_api.serializers import TrustBridgeSerializer +from userausfall.rest_api.serializers import TrustBridgeSerializer, UserSerializer from userausfall.views import get_authenticated_user @@ -18,7 +18,12 @@ class TrustBridgeViewSet( return self.queryset.filter(trust_taker=get_authenticated_user(self.request)) -class UserViewSet(viewsets.GenericViewSet): +class UserViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet): + serializer_class = UserSerializer + + def get_queryset(self): + return User.objects.filter(pk=get_authenticated_user(self.request).pk) + @action(detail=True, methods=["post"]) def activate(self, request, pk=None): """Create the corresponding LDAP account.""" From bba1d7c8aa1a2de8645ada6764e2ef4821d49d92 Mon Sep 17 00:00:00 2001 From: Robert Date: Tue, 26 Oct 2021 11:11:24 +0200 Subject: [PATCH 2/3] test: Add ldap tests --- userausfall/ldap.py | 78 ++++++++++++++++++++++++-------------------- userausfall/tests.py | 27 +++++++++++++++ 2 files changed, 69 insertions(+), 36 deletions(-) create mode 100644 userausfall/tests.py diff --git a/userausfall/ldap.py b/userausfall/ldap.py index 7e2e0b9..e5cc5be 100644 --- a/userausfall/ldap.py +++ b/userausfall/ldap.py @@ -1,44 +1,50 @@ from django.conf import settings -from ldap3 import Connection, Server, SYNC +from ldap3 import Connection, MOCK_SYNC, SAFE_SYNC, Server -def create_account(username, raw_password): - connection = _get_connection() - is_success = connection.add( - f"cn={username},dc=local", - ["simpleSecurityObject", "organizationalRole"], - {"userPassword": raw_password}, - ) - return is_success +class LDAPManager: + def __init__(self): + if not getattr(settings, "USERAUSFALL_LDAP_IS_TEST", False): + self.connection = self._get_connection() + else: + self.connection = self._get_test_connection() + def create_account(self, username, raw_password): + is_success = self.connection.add( + f"cn={username},dc=local", + ["simpleSecurityObject", "organizationalRole"], + {"userPassword": raw_password}, + ) + return is_success -def account_exists(username): - connection = _get_connection() - exists = connection.search(f"cn={username},dc=local", "(objectclass=simpleSecurityObject)") - return exists + def has_account(self, username): + exists = self.connection.search(f"cn={username},dc=local", "(objectclass=simpleSecurityObject)") + return exists + def is_valid_account_data(self, username, raw_password): + is_valid = self.connection.search( + f"cn={username},dc=local", + "(objectclass=simpleSecurityObject)", + attributes=["userPassword"], + ) + if is_valid: + is_valid = self.connection.entries[0]["userPassword"].value == raw_password + return is_valid -def is_valid_account_data(username, raw_password): - connection = _get_connection() - is_valid = connection.search( - f"cn={username},dc=local", - "(objectclass=simpleSecurityObject)", - attributes=["userPassword"], - ) - if is_valid: - is_valid = connection.entries[0]["userPassword"].value == raw_password - return is_valid + def _get_connection(self): + server = Server("localhost") + connection = Connection( + server, + settings.USERAUSFALL_LDAP["ADMIN_USER_DN"], + settings.USERAUSFALL_LDAP["ADMIN_USER_PASSWORD"], + client_strategy=SAFE_SYNC, + auto_bind=True, + ) + return connection - -def _get_connection(): - server = Server("localhost") - # The SAFE_SYNC client strategy doesn't seem to be present in Buster version of ldap3. We might want to use it as - # soon as it is available (multithreading). - connection = Connection( - server, - settings.USERAUSFALL_LDAP["ADMIN_USER_DN"], - settings.USERAUSFALL_LDAP["ADMIN_USER_PASSWORD"], - client_strategy=SYNC, - auto_bind=True, - ) - return connection + def _get_test_connection(self): + server = Server("testserver") + connection = Connection(server, user="cn=admin,dc=local", password="admin_secret", client_strategy=MOCK_SYNC) + connection.strategy.add_entry("cn=admin,dc=local", {"userPassword": "admin_secret"}) + connection.bind() + return connection diff --git a/userausfall/tests.py b/userausfall/tests.py new file mode 100644 index 0000000..269c892 --- /dev/null +++ b/userausfall/tests.py @@ -0,0 +1,27 @@ +from django.test import override_settings, TestCase + +from userausfall.ldap import LDAPManager + + +@override_settings(USERAUSFALL_LDAP_IS_TEST=True) +class LDAPTestCase(TestCase): + def setUp(self) -> None: + self.username = "test" + self.password = "test12345" + self.ldap = LDAPManager() + + def test_create_has_account(self): + exists = self.ldap.has_account(self.username) + self.assertFalse(exists) + is_created = self.ldap.create_account(self.username, self.password) + self.assertTrue(is_created) + exists = self.ldap.has_account(self.username) + self.assertTrue(exists) + + def test_create_account_data(self): + is_valid = self.ldap.is_valid_account_data(self.username, self.password) + self.assertFalse(is_valid) + is_created = self.ldap.create_account(self.username, self.password) + self.assertTrue(is_created) + is_valid = self.ldap.is_valid_account_data(self.username, self.password) + self.assertTrue(is_valid) From 8ca0f667775b4537e5f53962b2f5c0cffb87a0d3 Mon Sep 17 00:00:00 2001 From: Robert Date: Tue, 26 Oct 2021 12:35:22 +0200 Subject: [PATCH 3/3] feat: Add endpoint to activate users --- userausfall/ldap.py | 22 ++++-- userausfall/models.py | 31 ++++----- userausfall/rest_api/serializers.py | 16 +++++ userausfall/rest_api/tests/auth.py | 24 ++++++- userausfall/rest_api/tests/trust_bridges.py | 17 +++-- userausfall/rest_api/tests/users.py | 74 ++++++++++++++------- userausfall/rest_api/views.py | 40 +++++------ userausfall/tests.py | 3 + 8 files changed, 159 insertions(+), 68 deletions(-) diff --git a/userausfall/ldap.py b/userausfall/ldap.py index e5cc5be..4460466 100644 --- a/userausfall/ldap.py +++ b/userausfall/ldap.py @@ -1,6 +1,8 @@ from django.conf import settings from ldap3 import Connection, MOCK_SYNC, SAFE_SYNC, Server +_test_connection = None + class LDAPManager: def __init__(self): @@ -31,6 +33,12 @@ class LDAPManager: is_valid = self.connection.entries[0]["userPassword"].value == raw_password return is_valid + def drop_test_connection(self): + global _test_connection + self.connection.unbind() + self.connection = None + _test_connection = None + def _get_connection(self): server = Server("localhost") connection = Connection( @@ -43,8 +51,12 @@ class LDAPManager: return connection def _get_test_connection(self): - server = Server("testserver") - connection = Connection(server, user="cn=admin,dc=local", password="admin_secret", client_strategy=MOCK_SYNC) - connection.strategy.add_entry("cn=admin,dc=local", {"userPassword": "admin_secret"}) - connection.bind() - return connection + global _test_connection + if _test_connection is None: + server = Server("testserver") + _test_connection = Connection( + server, user="cn=admin,dc=local", password="admin_secret", client_strategy=MOCK_SYNC + ) + _test_connection.strategy.add_entry("cn=admin,dc=local", {"userPassword": "admin_secret"}) + _test_connection.bind() + return _test_connection diff --git a/userausfall/models.py b/userausfall/models.py index 42ebdd3..e8725db 100644 --- a/userausfall/models.py +++ b/userausfall/models.py @@ -5,18 +5,13 @@ from django.contrib.auth.validators import UnicodeUsernameValidator from django.core.mail import send_mail from django.db import models from django.utils import timezone +from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ from djeveric.fields import ConfirmationField from djeveric.models import ConfirmableModelMixin -from userausfall import ldap from userausfall.emails import TrustBridgeConfirmationEmail - - -class MissingUserAttribute(Exception): - """The user object is missing a required attribute.""" - - pass +from userausfall.ldap import LDAPManager class PasswordMismatch(Exception): @@ -89,22 +84,28 @@ class User(PermissionsMixin, AbstractBaseUser): super().clean() self.email = self.__class__.objects.normalize_email(self.email) + def create_ldap_account(self, raw_password): + """Create the LDAP account which corresponds to this user.""" + if not self.check_password(raw_password): + raise PasswordMismatch("The given password does not match the user's password.") + return self._ldap.create_account(self.username, raw_password) + def email_user(self, subject, message, from_email=None, **kwargs): """Send an email to this user.""" send_mail(subject, message, from_email, [self.email], **kwargs) - def create_ldap_account(self, raw_password): - """Create the LDAP account which corresponds to this user.""" - if not self.username: - raise MissingUserAttribute("User is missing a username.") - if not self.check_password(raw_password): - raise PasswordMismatch("The given password does not match the user's password.") - return ldap.create_account(self.username, raw_password) - def get_primary_email(self): """Returns the primary email address for this user.""" return f"{self.username}@{settings.USERAUSFALL['PRIMARY_EMAIL_DOMAIN']}" + def has_ldap_account(self): + """Returns True if an ldap account exists for the user's username.""" + return self._ldap.has_account(self.username) + + @cached_property + def _ldap(self): + return LDAPManager() + class TrustBridge(ConfirmableModelMixin, models.Model): is_trusted = ConfirmationField(email_class=TrustBridgeConfirmationEmail) diff --git a/userausfall/rest_api/serializers.py b/userausfall/rest_api/serializers.py index 5392b10..6f15bfe 100644 --- a/userausfall/rest_api/serializers.py +++ b/userausfall/rest_api/serializers.py @@ -10,6 +10,22 @@ class UserSerializer(serializers.HyperlinkedModelSerializer): fields = ["id", "trust_bridge", "url"] +class ActivateUserSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = User + fields = ["id", "password", "url"] + + def validate(self, data): + if not hasattr(self.instance, "trust_bridge") or not self.instance.trust_bridge.is_trusted: + raise serializers.ValidationError("User has no trusted trust bridge") + return data + + def validate_password(self, value): + if not self.instance.check_password(value): + raise serializers.ValidationError("Password does not match the user's password") + return value + + class TrustBridgeUserSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = User diff --git a/userausfall/rest_api/tests/auth.py b/userausfall/rest_api/tests/auth.py index b1ca2a6..9ac7b38 100644 --- a/userausfall/rest_api/tests/auth.py +++ b/userausfall/rest_api/tests/auth.py @@ -2,7 +2,29 @@ from rest_framework import status from userausfall.models import User from userausfall.rest_api.tests.userausfall import UserausfallAPITestCase -from userausfall.rest_api.tests.users import UserMixin + + +class UserMixin: + user: User + password: str + username: str + + def create_user(self): + self.username = f"test{User.objects.count()}" + self.password = "test12345" + self.user = User.objects.create_user(self.username, self.password) + return self.user + + def ensure_user_exists(self): + if not hasattr(self, "user"): + self.create_user() + + def authenticate_user(self): + self.ensure_user_exists() + if hasattr(self.client, "force_authentication"): + self.client.force_authenticate(user=self.user) + else: + self.client.force_login(user=self.user) class AuthenticationTestCase(UserMixin, UserausfallAPITestCase): diff --git a/userausfall/rest_api/tests/trust_bridges.py b/userausfall/rest_api/tests/trust_bridges.py index 0f78599..71d3a42 100644 --- a/userausfall/rest_api/tests/trust_bridges.py +++ b/userausfall/rest_api/tests/trust_bridges.py @@ -1,18 +1,25 @@ from django.core import mail from rest_framework import status -from userausfall.models import TrustBridge +from userausfall.models import TrustBridge, User +from userausfall.rest_api.tests.auth import UserMixin from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase -from userausfall.rest_api.tests.users import UserMixin -class TrustBridgeTestCase(UserMixin, UserausfallAPITestCase): - def create_trust_bridge(self): +class TrustBridgeMixin(UserMixin): + trust_bridge: TrustBridge + trust_giver: User + + def create_trust_bridge(self, is_trusted=False): self.trust_giver = self.create_user() self.create_user() - self.trust_bridge = TrustBridge.objects.create(trust_taker=self.user, trust_giver=self.trust_giver) + self.trust_bridge = TrustBridge.objects.create( + trust_taker=self.user, trust_giver=self.trust_giver, is_trusted=is_trusted + ) return self.trust_bridge + +class TrustBridgeTestCase(TrustBridgeMixin, UserausfallAPITestCase): def test_create_trust_bridge(self): """Create a trust bridge for the current user.""" url = "/trust-bridges/" diff --git a/userausfall/rest_api/tests/users.py b/userausfall/rest_api/tests/users.py index 6df5a5e..8f1b9ee 100644 --- a/userausfall/rest_api/tests/users.py +++ b/userausfall/rest_api/tests/users.py @@ -1,33 +1,20 @@ +from django.test import override_settings from rest_framework import status +from rest_framework.exceptions import ErrorDetail -from userausfall.models import User +from userausfall.ldap import LDAPManager +from userausfall.rest_api.tests.trust_bridges import TrustBridgeMixin from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase -class UserMixin: - user: User - password: str - username: str +@override_settings(USERAUSFALL_LDAP_IS_TEST=True) +class UserTestCase(TrustBridgeMixin, UserausfallAPITestCase): + def setUp(self) -> None: + self.ldap = LDAPManager() - def create_user(self): - self.username = f"test{User.objects.count()}" - self.password = "test12345" - self.user = User.objects.create_user(self.username, self.password) - return self.user + def tearDown(self) -> None: + self.ldap.drop_test_connection() - def ensure_user_exists(self): - if not hasattr(self, "user"): - self.create_user() - - def authenticate_user(self): - self.ensure_user_exists() - if hasattr(self.client, "force_authentication"): - self.client.force_authenticate(user=self.user) - else: - self.client.force_login(user=self.user) - - -class UserTestCase(UserMixin, UserausfallAPITestCase): def test_retrieve_user(self): """Retrieve the details of the current user.""" url = "/users/{pk}/" @@ -42,3 +29,44 @@ class UserTestCase(UserMixin, UserausfallAPITestCase): "url": get_url(response, "user", self.user), }, ) + + def test_activate_user(self): + """Create the ldap account for the current user.""" + url = "/users/{pk}/activate/" + self.create_trust_bridge(is_trusted=True) + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password}) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertTrue(self.user.has_ldap_account()) + + def test_activate_user_with_invalid_password(self): + """Create the ldap account for the current user with an invalid password.""" + url = "/users/{pk}/activate/" + self.create_trust_bridge(is_trusted=True) + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": "invalid"}) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.data, {"password": [ErrorDetail("Password does not match the user's password", code="invalid")]} + ) + + def test_activate_user_without_trust_bridge(self): + """Create the ldap account for the current user without a trust bridge.""" + url = "/users/{pk}/activate/" + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password}) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.data, {"non_field_errors": [ErrorDetail("User has no trusted trust bridge", code="invalid")]} + ) + + def test_activate_user_with_untrusted_trust_bridge(self): + """Create the ldap account for the current user with an untrusted trust bridge.""" + url = "/users/{pk}/activate/" + self.create_trust_bridge(is_trusted=False) + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password}) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.data, {"non_field_errors": [ErrorDetail("User has no trusted trust bridge", code="invalid")]} + ) diff --git a/userausfall/rest_api/views.py b/userausfall/rest_api/views.py index 7ddb5c4..e2ec582 100644 --- a/userausfall/rest_api/views.py +++ b/userausfall/rest_api/views.py @@ -3,8 +3,8 @@ from rest_framework import mixins, status, viewsets from rest_framework.decorators import action from rest_framework.response import Response -from userausfall.models import MissingUserAttribute, PasswordMismatch, TrustBridge, User -from userausfall.rest_api.serializers import TrustBridgeSerializer, UserSerializer +from userausfall.models import TrustBridge, User +from userausfall.rest_api.serializers import ActivateUserSerializer, TrustBridgeSerializer, UserSerializer from userausfall.views import get_authenticated_user @@ -18,26 +18,28 @@ class TrustBridgeViewSet( return self.queryset.filter(trust_taker=get_authenticated_user(self.request)) -class UserViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet): +class ActivateUserMixin: + @action(detail=True, methods=["post"]) + def activate(self, request, pk=None): + """Create the corresponding LDAP account.""" + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_activate(instance, serializer) + return Response(status=status.HTTP_204_NO_CONTENT) + + def perform_activate(self, instance: User, serializer): + instance.create_ldap_account(serializer.validated_data["password"]) + + +class UserViewSet(ActivateUserMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet): serializer_class = UserSerializer def get_queryset(self): return User.objects.filter(pk=get_authenticated_user(self.request).pk) - @action(detail=True, methods=["post"]) - def activate(self, request, pk=None): - """Create the corresponding LDAP account.""" - user: User = self.get_object() - serializer = self.get_serializer(data=request.data) - if serializer.is_valid(): - try: - # We prevent untrusted user accounts from being activated via API. - # They might be activated via Admin or programmatically. - if not user.trust_bridge.is_trusted: - raise MissingUserAttribute("User has no trusted trust bridge.") - user.create_ldap_account(serializer.validated_data["password"]) - except (MissingUserAttribute, PasswordMismatch) as e: - return Response({"message": str(e)}, status=status.HTTP_400_BAD_REQUEST) - return Response(status=status.HTTP_204_NO_CONTENT) + def get_serializer_class(self): + if self.action == "activate": + return ActivateUserSerializer else: - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return super().get_serializer_class() diff --git a/userausfall/tests.py b/userausfall/tests.py index 269c892..1ca7521 100644 --- a/userausfall/tests.py +++ b/userausfall/tests.py @@ -10,6 +10,9 @@ class LDAPTestCase(TestCase): self.password = "test12345" self.ldap = LDAPManager() + def tearDown(self) -> None: + self.ldap.drop_test_connection() + def test_create_has_account(self): exists = self.ldap.has_account(self.username) self.assertFalse(exists)