diff --git a/userausfall/ldap.py b/userausfall/ldap.py index 7e2e0b9..4460466 100644 --- a/userausfall/ldap.py +++ b/userausfall/ldap.py @@ -1,44 +1,62 @@ from django.conf import settings -from ldap3 import Connection, Server, SYNC +from ldap3 import Connection, MOCK_SYNC, SAFE_SYNC, Server + +_test_connection = None -def create_account(username, raw_password): - connection = _get_connection() - is_success = connection.add( - f"cn={username},dc=local", - ["simpleSecurityObject", "organizationalRole"], - {"userPassword": raw_password}, - ) - return is_success +class LDAPManager: + def __init__(self): + if not getattr(settings, "USERAUSFALL_LDAP_IS_TEST", False): + self.connection = self._get_connection() + else: + self.connection = self._get_test_connection() + def create_account(self, username, raw_password): + is_success = self.connection.add( + f"cn={username},dc=local", + ["simpleSecurityObject", "organizationalRole"], + {"userPassword": raw_password}, + ) + return is_success -def account_exists(username): - connection = _get_connection() - exists = connection.search(f"cn={username},dc=local", "(objectclass=simpleSecurityObject)") - return exists + def has_account(self, username): + exists = self.connection.search(f"cn={username},dc=local", "(objectclass=simpleSecurityObject)") + return exists + def is_valid_account_data(self, username, raw_password): + is_valid = self.connection.search( + f"cn={username},dc=local", + "(objectclass=simpleSecurityObject)", + attributes=["userPassword"], + ) + if is_valid: + is_valid = self.connection.entries[0]["userPassword"].value == raw_password + return is_valid -def is_valid_account_data(username, raw_password): - connection = _get_connection() - is_valid = connection.search( - f"cn={username},dc=local", - "(objectclass=simpleSecurityObject)", - attributes=["userPassword"], - ) - if is_valid: - is_valid = connection.entries[0]["userPassword"].value == raw_password - return is_valid + def drop_test_connection(self): + global _test_connection + self.connection.unbind() + self.connection = None + _test_connection = None + def _get_connection(self): + server = Server("localhost") + connection = Connection( + server, + settings.USERAUSFALL_LDAP["ADMIN_USER_DN"], + settings.USERAUSFALL_LDAP["ADMIN_USER_PASSWORD"], + client_strategy=SAFE_SYNC, + auto_bind=True, + ) + return connection -def _get_connection(): - server = Server("localhost") - # The SAFE_SYNC client strategy doesn't seem to be present in Buster version of ldap3. We might want to use it as - # soon as it is available (multithreading). - connection = Connection( - server, - settings.USERAUSFALL_LDAP["ADMIN_USER_DN"], - settings.USERAUSFALL_LDAP["ADMIN_USER_PASSWORD"], - client_strategy=SYNC, - auto_bind=True, - ) - return connection + def _get_test_connection(self): + global _test_connection + if _test_connection is None: + server = Server("testserver") + _test_connection = Connection( + server, user="cn=admin,dc=local", password="admin_secret", client_strategy=MOCK_SYNC + ) + _test_connection.strategy.add_entry("cn=admin,dc=local", {"userPassword": "admin_secret"}) + _test_connection.bind() + return _test_connection diff --git a/userausfall/models.py b/userausfall/models.py index 42ebdd3..e8725db 100644 --- a/userausfall/models.py +++ b/userausfall/models.py @@ -5,18 +5,13 @@ from django.contrib.auth.validators import UnicodeUsernameValidator from django.core.mail import send_mail from django.db import models from django.utils import timezone +from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ from djeveric.fields import ConfirmationField from djeveric.models import ConfirmableModelMixin -from userausfall import ldap from userausfall.emails import TrustBridgeConfirmationEmail - - -class MissingUserAttribute(Exception): - """The user object is missing a required attribute.""" - - pass +from userausfall.ldap import LDAPManager class PasswordMismatch(Exception): @@ -89,22 +84,28 @@ class User(PermissionsMixin, AbstractBaseUser): super().clean() self.email = self.__class__.objects.normalize_email(self.email) + def create_ldap_account(self, raw_password): + """Create the LDAP account which corresponds to this user.""" + if not self.check_password(raw_password): + raise PasswordMismatch("The given password does not match the user's password.") + return self._ldap.create_account(self.username, raw_password) + def email_user(self, subject, message, from_email=None, **kwargs): """Send an email to this user.""" send_mail(subject, message, from_email, [self.email], **kwargs) - def create_ldap_account(self, raw_password): - """Create the LDAP account which corresponds to this user.""" - if not self.username: - raise MissingUserAttribute("User is missing a username.") - if not self.check_password(raw_password): - raise PasswordMismatch("The given password does not match the user's password.") - return ldap.create_account(self.username, raw_password) - def get_primary_email(self): """Returns the primary email address for this user.""" return f"{self.username}@{settings.USERAUSFALL['PRIMARY_EMAIL_DOMAIN']}" + def has_ldap_account(self): + """Returns True if an ldap account exists for the user's username.""" + return self._ldap.has_account(self.username) + + @cached_property + def _ldap(self): + return LDAPManager() + class TrustBridge(ConfirmableModelMixin, models.Model): is_trusted = ConfirmationField(email_class=TrustBridgeConfirmationEmail) diff --git a/userausfall/rest_api/serializers.py b/userausfall/rest_api/serializers.py index 94225d1..6f15bfe 100644 --- a/userausfall/rest_api/serializers.py +++ b/userausfall/rest_api/serializers.py @@ -7,17 +7,39 @@ from userausfall.views import get_authenticated_user class UserSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = User - fields = ["username"] + fields = ["id", "trust_bridge", "url"] + + +class ActivateUserSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = User + fields = ["id", "password", "url"] + + def validate(self, data): + if not hasattr(self.instance, "trust_bridge") or not self.instance.trust_bridge.is_trusted: + raise serializers.ValidationError("User has no trusted trust bridge") + return data + + def validate_password(self, value): + if not self.instance.check_password(value): + raise serializers.ValidationError("Password does not match the user's password") + return value + + +class TrustBridgeUserSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = User + fields = ["id", "url", "username"] # the UniqueValidator for username prevents successful validation for existing users extra_kwargs = {"username": {"validators": []}} class TrustBridgeSerializer(serializers.HyperlinkedModelSerializer): - trust_giver = UserSerializer() + trust_giver = TrustBridgeUserSerializer() class Meta: model = TrustBridge - fields = ["is_trusted", "trust_giver"] + fields = ["id", "is_trusted", "trust_giver", "url"] read_only_fields = ["is_trusted"] def create(self, validated_data): diff --git a/userausfall/rest_api/tests/__init__.py b/userausfall/rest_api/tests/__init__.py index f7cf7fe..abbee82 100644 --- a/userausfall/rest_api/tests/__init__.py +++ b/userausfall/rest_api/tests/__init__.py @@ -1,2 +1,3 @@ -from .auth import * # noqa: F401, F403 -from .trust_bridges import * # noqa: F401, F403 +from .auth import AuthenticationTestCase # noqa: F401, F403 +from .trust_bridges import TrustBridgeTestCase # noqa: F401, F403 +from .users import UserTestCase # noqa: F401, F403 diff --git a/userausfall/rest_api/tests/auth.py b/userausfall/rest_api/tests/auth.py index 233f23e..9ac7b38 100644 --- a/userausfall/rest_api/tests/auth.py +++ b/userausfall/rest_api/tests/auth.py @@ -27,7 +27,7 @@ class UserMixin: self.client.force_login(user=self.user) -class UserTestCase(UserMixin, UserausfallAPITestCase): +class AuthenticationTestCase(UserMixin, UserausfallAPITestCase): base_url = "/api/auth" def test_signup(self): diff --git a/userausfall/rest_api/tests/trust_bridges.py b/userausfall/rest_api/tests/trust_bridges.py index 8bb4678..71d3a42 100644 --- a/userausfall/rest_api/tests/trust_bridges.py +++ b/userausfall/rest_api/tests/trust_bridges.py @@ -1,17 +1,25 @@ from django.core import mail from rest_framework import status -from userausfall.models import TrustBridge -from userausfall.rest_api.tests import UserausfallAPITestCase, UserMixin +from userausfall.models import TrustBridge, User +from userausfall.rest_api.tests.auth import UserMixin +from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase -class TrustBridgeTestCase(UserMixin, UserausfallAPITestCase): - def create_trust_bridge(self): +class TrustBridgeMixin(UserMixin): + trust_bridge: TrustBridge + trust_giver: User + + def create_trust_bridge(self, is_trusted=False): self.trust_giver = self.create_user() self.create_user() - self.trust_bridge = TrustBridge.objects.create(trust_taker=self.user, trust_giver=self.trust_giver) + self.trust_bridge = TrustBridge.objects.create( + trust_taker=self.user, trust_giver=self.trust_giver, is_trusted=is_trusted + ) return self.trust_bridge + +class TrustBridgeTestCase(TrustBridgeMixin, UserausfallAPITestCase): def test_create_trust_bridge(self): """Create a trust bridge for the current user.""" url = "/trust-bridges/" @@ -37,13 +45,17 @@ class TrustBridgeTestCase(UserMixin, UserausfallAPITestCase): self.authenticate_user() response = self.client.get(self.get_api_url(url, pk=self.trust_bridge.pk)) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual( + self.assertDictEqual( response.data, { + "id": self.trust_bridge.id, "is_trusted": False, "trust_giver": { + "id": self.trust_giver.id, "username": self.trust_giver.username, + "url": get_url(response, "user", self.trust_giver), }, + "url": get_url(response, "trustbridge", self.trust_bridge), }, ) diff --git a/userausfall/rest_api/tests/userausfall.py b/userausfall/rest_api/tests/userausfall.py index 30fd6ea..32e335c 100644 --- a/userausfall/rest_api/tests/userausfall.py +++ b/userausfall/rest_api/tests/userausfall.py @@ -1,6 +1,11 @@ +from rest_framework.reverse import reverse from rest_framework.test import APITestCase +def get_url(response, basename, instance): + return reverse(f"{basename}-detail", [instance.pk], request=response.wsgi_request) + + class UserausfallAPITestCase(APITestCase): base_url = "/api" diff --git a/userausfall/rest_api/tests/users.py b/userausfall/rest_api/tests/users.py new file mode 100644 index 0000000..8f1b9ee --- /dev/null +++ b/userausfall/rest_api/tests/users.py @@ -0,0 +1,72 @@ +from django.test import override_settings +from rest_framework import status +from rest_framework.exceptions import ErrorDetail + +from userausfall.ldap import LDAPManager +from userausfall.rest_api.tests.trust_bridges import TrustBridgeMixin +from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase + + +@override_settings(USERAUSFALL_LDAP_IS_TEST=True) +class UserTestCase(TrustBridgeMixin, UserausfallAPITestCase): + def setUp(self) -> None: + self.ldap = LDAPManager() + + def tearDown(self) -> None: + self.ldap.drop_test_connection() + + def test_retrieve_user(self): + """Retrieve the details of the current user.""" + url = "/users/{pk}/" + self.authenticate_user() + response = self.client.get(self.get_api_url(url, pk=self.user.pk)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual( + response.data, + { + "id": self.user.id, + "trust_bridge": None, + "url": get_url(response, "user", self.user), + }, + ) + + def test_activate_user(self): + """Create the ldap account for the current user.""" + url = "/users/{pk}/activate/" + self.create_trust_bridge(is_trusted=True) + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password}) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertTrue(self.user.has_ldap_account()) + + def test_activate_user_with_invalid_password(self): + """Create the ldap account for the current user with an invalid password.""" + url = "/users/{pk}/activate/" + self.create_trust_bridge(is_trusted=True) + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": "invalid"}) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.data, {"password": [ErrorDetail("Password does not match the user's password", code="invalid")]} + ) + + def test_activate_user_without_trust_bridge(self): + """Create the ldap account for the current user without a trust bridge.""" + url = "/users/{pk}/activate/" + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password}) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.data, {"non_field_errors": [ErrorDetail("User has no trusted trust bridge", code="invalid")]} + ) + + def test_activate_user_with_untrusted_trust_bridge(self): + """Create the ldap account for the current user with an untrusted trust bridge.""" + url = "/users/{pk}/activate/" + self.create_trust_bridge(is_trusted=False) + self.authenticate_user() + response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password}) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.data, {"non_field_errors": [ErrorDetail("User has no trusted trust bridge", code="invalid")]} + ) diff --git a/userausfall/rest_api/urls.py b/userausfall/rest_api/urls.py index 7545a24..5128234 100644 --- a/userausfall/rest_api/urls.py +++ b/userausfall/rest_api/urls.py @@ -2,10 +2,11 @@ from django.urls import include, path from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView from rest_framework import routers -from userausfall.rest_api.views import TrustBridgeViewSet +from userausfall.rest_api.views import TrustBridgeViewSet, UserViewSet router = routers.SimpleRouter() -router.register(r"trust-bridges", TrustBridgeViewSet, "trust-bridge") +router.register(r"trust-bridges", TrustBridgeViewSet, "trustbridge") +router.register(r"users", UserViewSet, "user") urlpatterns = [ path("", include(router.urls)), diff --git a/userausfall/rest_api/views.py b/userausfall/rest_api/views.py index f078fa6..e2ec582 100644 --- a/userausfall/rest_api/views.py +++ b/userausfall/rest_api/views.py @@ -3,8 +3,8 @@ from rest_framework import mixins, status, viewsets from rest_framework.decorators import action from rest_framework.response import Response -from userausfall.models import MissingUserAttribute, PasswordMismatch, TrustBridge, User -from userausfall.rest_api.serializers import TrustBridgeSerializer +from userausfall.models import TrustBridge, User +from userausfall.rest_api.serializers import ActivateUserSerializer, TrustBridgeSerializer, UserSerializer from userausfall.views import get_authenticated_user @@ -18,21 +18,28 @@ class TrustBridgeViewSet( return self.queryset.filter(trust_taker=get_authenticated_user(self.request)) -class UserViewSet(viewsets.GenericViewSet): +class ActivateUserMixin: @action(detail=True, methods=["post"]) def activate(self, request, pk=None): """Create the corresponding LDAP account.""" - user: User = self.get_object() - serializer = self.get_serializer(data=request.data) - if serializer.is_valid(): - try: - # We prevent untrusted user accounts from being activated via API. - # They might be activated via Admin or programmatically. - if not user.trust_bridge.is_trusted: - raise MissingUserAttribute("User has no trusted trust bridge.") - user.create_ldap_account(serializer.validated_data["password"]) - except (MissingUserAttribute, PasswordMismatch) as e: - return Response({"message": str(e)}, status=status.HTTP_400_BAD_REQUEST) - return Response(status=status.HTTP_204_NO_CONTENT) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_activate(instance, serializer) + return Response(status=status.HTTP_204_NO_CONTENT) + + def perform_activate(self, instance: User, serializer): + instance.create_ldap_account(serializer.validated_data["password"]) + + +class UserViewSet(ActivateUserMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet): + serializer_class = UserSerializer + + def get_queryset(self): + return User.objects.filter(pk=get_authenticated_user(self.request).pk) + + def get_serializer_class(self): + if self.action == "activate": + return ActivateUserSerializer else: - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return super().get_serializer_class() diff --git a/userausfall/tests.py b/userausfall/tests.py new file mode 100644 index 0000000..1ca7521 --- /dev/null +++ b/userausfall/tests.py @@ -0,0 +1,30 @@ +from django.test import override_settings, TestCase + +from userausfall.ldap import LDAPManager + + +@override_settings(USERAUSFALL_LDAP_IS_TEST=True) +class LDAPTestCase(TestCase): + def setUp(self) -> None: + self.username = "test" + self.password = "test12345" + self.ldap = LDAPManager() + + def tearDown(self) -> None: + self.ldap.drop_test_connection() + + def test_create_has_account(self): + exists = self.ldap.has_account(self.username) + self.assertFalse(exists) + is_created = self.ldap.create_account(self.username, self.password) + self.assertTrue(is_created) + exists = self.ldap.has_account(self.username) + self.assertTrue(exists) + + def test_create_account_data(self): + is_valid = self.ldap.is_valid_account_data(self.username, self.password) + self.assertFalse(is_valid) + is_created = self.ldap.create_account(self.username, self.password) + self.assertTrue(is_created) + is_valid = self.ldap.is_valid_account_data(self.username, self.password) + self.assertTrue(is_valid)