feat: Add endpoint to activate users

This commit is contained in:
aldrin 2021-10-26 12:35:22 +02:00
parent bba1d7c8aa
commit 8ca0f66777
8 changed files with 159 additions and 68 deletions

View File

@ -1,6 +1,8 @@
from django.conf import settings from django.conf import settings
from ldap3 import Connection, MOCK_SYNC, SAFE_SYNC, Server from ldap3 import Connection, MOCK_SYNC, SAFE_SYNC, Server
_test_connection = None
class LDAPManager: class LDAPManager:
def __init__(self): def __init__(self):
@ -31,6 +33,12 @@ class LDAPManager:
is_valid = self.connection.entries[0]["userPassword"].value == raw_password is_valid = self.connection.entries[0]["userPassword"].value == raw_password
return is_valid return is_valid
def drop_test_connection(self):
global _test_connection
self.connection.unbind()
self.connection = None
_test_connection = None
def _get_connection(self): def _get_connection(self):
server = Server("localhost") server = Server("localhost")
connection = Connection( connection = Connection(
@ -43,8 +51,12 @@ class LDAPManager:
return connection return connection
def _get_test_connection(self): def _get_test_connection(self):
server = Server("testserver") global _test_connection
connection = Connection(server, user="cn=admin,dc=local", password="admin_secret", client_strategy=MOCK_SYNC) if _test_connection is None:
connection.strategy.add_entry("cn=admin,dc=local", {"userPassword": "admin_secret"}) server = Server("testserver")
connection.bind() _test_connection = Connection(
return connection server, user="cn=admin,dc=local", password="admin_secret", client_strategy=MOCK_SYNC
)
_test_connection.strategy.add_entry("cn=admin,dc=local", {"userPassword": "admin_secret"})
_test_connection.bind()
return _test_connection

View File

@ -5,18 +5,13 @@ from django.contrib.auth.validators import UnicodeUsernameValidator
from django.core.mail import send_mail from django.core.mail import send_mail
from django.db import models from django.db import models
from django.utils import timezone from django.utils import timezone
from django.utils.functional import cached_property
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from djeveric.fields import ConfirmationField from djeveric.fields import ConfirmationField
from djeveric.models import ConfirmableModelMixin from djeveric.models import ConfirmableModelMixin
from userausfall import ldap
from userausfall.emails import TrustBridgeConfirmationEmail from userausfall.emails import TrustBridgeConfirmationEmail
from userausfall.ldap import LDAPManager
class MissingUserAttribute(Exception):
"""The user object is missing a required attribute."""
pass
class PasswordMismatch(Exception): class PasswordMismatch(Exception):
@ -89,22 +84,28 @@ class User(PermissionsMixin, AbstractBaseUser):
super().clean() super().clean()
self.email = self.__class__.objects.normalize_email(self.email) self.email = self.__class__.objects.normalize_email(self.email)
def create_ldap_account(self, raw_password):
"""Create the LDAP account which corresponds to this user."""
if not self.check_password(raw_password):
raise PasswordMismatch("The given password does not match the user's password.")
return self._ldap.create_account(self.username, raw_password)
def email_user(self, subject, message, from_email=None, **kwargs): def email_user(self, subject, message, from_email=None, **kwargs):
"""Send an email to this user.""" """Send an email to this user."""
send_mail(subject, message, from_email, [self.email], **kwargs) send_mail(subject, message, from_email, [self.email], **kwargs)
def create_ldap_account(self, raw_password):
"""Create the LDAP account which corresponds to this user."""
if not self.username:
raise MissingUserAttribute("User is missing a username.")
if not self.check_password(raw_password):
raise PasswordMismatch("The given password does not match the user's password.")
return ldap.create_account(self.username, raw_password)
def get_primary_email(self): def get_primary_email(self):
"""Returns the primary email address for this user.""" """Returns the primary email address for this user."""
return f"{self.username}@{settings.USERAUSFALL['PRIMARY_EMAIL_DOMAIN']}" return f"{self.username}@{settings.USERAUSFALL['PRIMARY_EMAIL_DOMAIN']}"
def has_ldap_account(self):
"""Returns True if an ldap account exists for the user's username."""
return self._ldap.has_account(self.username)
@cached_property
def _ldap(self):
return LDAPManager()
class TrustBridge(ConfirmableModelMixin, models.Model): class TrustBridge(ConfirmableModelMixin, models.Model):
is_trusted = ConfirmationField(email_class=TrustBridgeConfirmationEmail) is_trusted = ConfirmationField(email_class=TrustBridgeConfirmationEmail)

View File

@ -10,6 +10,22 @@ class UserSerializer(serializers.HyperlinkedModelSerializer):
fields = ["id", "trust_bridge", "url"] fields = ["id", "trust_bridge", "url"]
class ActivateUserSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = User
fields = ["id", "password", "url"]
def validate(self, data):
if not hasattr(self.instance, "trust_bridge") or not self.instance.trust_bridge.is_trusted:
raise serializers.ValidationError("User has no trusted trust bridge")
return data
def validate_password(self, value):
if not self.instance.check_password(value):
raise serializers.ValidationError("Password does not match the user's password")
return value
class TrustBridgeUserSerializer(serializers.HyperlinkedModelSerializer): class TrustBridgeUserSerializer(serializers.HyperlinkedModelSerializer):
class Meta: class Meta:
model = User model = User

View File

@ -2,7 +2,29 @@ from rest_framework import status
from userausfall.models import User from userausfall.models import User
from userausfall.rest_api.tests.userausfall import UserausfallAPITestCase from userausfall.rest_api.tests.userausfall import UserausfallAPITestCase
from userausfall.rest_api.tests.users import UserMixin
class UserMixin:
user: User
password: str
username: str
def create_user(self):
self.username = f"test{User.objects.count()}"
self.password = "test12345"
self.user = User.objects.create_user(self.username, self.password)
return self.user
def ensure_user_exists(self):
if not hasattr(self, "user"):
self.create_user()
def authenticate_user(self):
self.ensure_user_exists()
if hasattr(self.client, "force_authentication"):
self.client.force_authenticate(user=self.user)
else:
self.client.force_login(user=self.user)
class AuthenticationTestCase(UserMixin, UserausfallAPITestCase): class AuthenticationTestCase(UserMixin, UserausfallAPITestCase):

View File

@ -1,18 +1,25 @@
from django.core import mail from django.core import mail
from rest_framework import status from rest_framework import status
from userausfall.models import TrustBridge from userausfall.models import TrustBridge, User
from userausfall.rest_api.tests.auth import UserMixin
from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase
from userausfall.rest_api.tests.users import UserMixin
class TrustBridgeTestCase(UserMixin, UserausfallAPITestCase): class TrustBridgeMixin(UserMixin):
def create_trust_bridge(self): trust_bridge: TrustBridge
trust_giver: User
def create_trust_bridge(self, is_trusted=False):
self.trust_giver = self.create_user() self.trust_giver = self.create_user()
self.create_user() self.create_user()
self.trust_bridge = TrustBridge.objects.create(trust_taker=self.user, trust_giver=self.trust_giver) self.trust_bridge = TrustBridge.objects.create(
trust_taker=self.user, trust_giver=self.trust_giver, is_trusted=is_trusted
)
return self.trust_bridge return self.trust_bridge
class TrustBridgeTestCase(TrustBridgeMixin, UserausfallAPITestCase):
def test_create_trust_bridge(self): def test_create_trust_bridge(self):
"""Create a trust bridge for the current user.""" """Create a trust bridge for the current user."""
url = "/trust-bridges/" url = "/trust-bridges/"

View File

@ -1,33 +1,20 @@
from django.test import override_settings
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import ErrorDetail
from userausfall.models import User from userausfall.ldap import LDAPManager
from userausfall.rest_api.tests.trust_bridges import TrustBridgeMixin
from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase from userausfall.rest_api.tests.userausfall import get_url, UserausfallAPITestCase
class UserMixin: @override_settings(USERAUSFALL_LDAP_IS_TEST=True)
user: User class UserTestCase(TrustBridgeMixin, UserausfallAPITestCase):
password: str def setUp(self) -> None:
username: str self.ldap = LDAPManager()
def create_user(self): def tearDown(self) -> None:
self.username = f"test{User.objects.count()}" self.ldap.drop_test_connection()
self.password = "test12345"
self.user = User.objects.create_user(self.username, self.password)
return self.user
def ensure_user_exists(self):
if not hasattr(self, "user"):
self.create_user()
def authenticate_user(self):
self.ensure_user_exists()
if hasattr(self.client, "force_authentication"):
self.client.force_authenticate(user=self.user)
else:
self.client.force_login(user=self.user)
class UserTestCase(UserMixin, UserausfallAPITestCase):
def test_retrieve_user(self): def test_retrieve_user(self):
"""Retrieve the details of the current user.""" """Retrieve the details of the current user."""
url = "/users/{pk}/" url = "/users/{pk}/"
@ -42,3 +29,44 @@ class UserTestCase(UserMixin, UserausfallAPITestCase):
"url": get_url(response, "user", self.user), "url": get_url(response, "user", self.user),
}, },
) )
def test_activate_user(self):
"""Create the ldap account for the current user."""
url = "/users/{pk}/activate/"
self.create_trust_bridge(is_trusted=True)
self.authenticate_user()
response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password})
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertTrue(self.user.has_ldap_account())
def test_activate_user_with_invalid_password(self):
"""Create the ldap account for the current user with an invalid password."""
url = "/users/{pk}/activate/"
self.create_trust_bridge(is_trusted=True)
self.authenticate_user()
response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": "invalid"})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(
response.data, {"password": [ErrorDetail("Password does not match the user's password", code="invalid")]}
)
def test_activate_user_without_trust_bridge(self):
"""Create the ldap account for the current user without a trust bridge."""
url = "/users/{pk}/activate/"
self.authenticate_user()
response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(
response.data, {"non_field_errors": [ErrorDetail("User has no trusted trust bridge", code="invalid")]}
)
def test_activate_user_with_untrusted_trust_bridge(self):
"""Create the ldap account for the current user with an untrusted trust bridge."""
url = "/users/{pk}/activate/"
self.create_trust_bridge(is_trusted=False)
self.authenticate_user()
response = self.client.post(self.get_api_url(url, pk=self.user.pk), {"password": self.password})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(
response.data, {"non_field_errors": [ErrorDetail("User has no trusted trust bridge", code="invalid")]}
)

View File

@ -3,8 +3,8 @@ from rest_framework import mixins, status, viewsets
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from userausfall.models import MissingUserAttribute, PasswordMismatch, TrustBridge, User from userausfall.models import TrustBridge, User
from userausfall.rest_api.serializers import TrustBridgeSerializer, UserSerializer from userausfall.rest_api.serializers import ActivateUserSerializer, TrustBridgeSerializer, UserSerializer
from userausfall.views import get_authenticated_user from userausfall.views import get_authenticated_user
@ -18,26 +18,28 @@ class TrustBridgeViewSet(
return self.queryset.filter(trust_taker=get_authenticated_user(self.request)) return self.queryset.filter(trust_taker=get_authenticated_user(self.request))
class UserViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet): class ActivateUserMixin:
@action(detail=True, methods=["post"])
def activate(self, request, pk=None):
"""Create the corresponding LDAP account."""
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_activate(instance, serializer)
return Response(status=status.HTTP_204_NO_CONTENT)
def perform_activate(self, instance: User, serializer):
instance.create_ldap_account(serializer.validated_data["password"])
class UserViewSet(ActivateUserMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet):
serializer_class = UserSerializer serializer_class = UserSerializer
def get_queryset(self): def get_queryset(self):
return User.objects.filter(pk=get_authenticated_user(self.request).pk) return User.objects.filter(pk=get_authenticated_user(self.request).pk)
@action(detail=True, methods=["post"]) def get_serializer_class(self):
def activate(self, request, pk=None): if self.action == "activate":
"""Create the corresponding LDAP account.""" return ActivateUserSerializer
user: User = self.get_object()
serializer = self.get_serializer(data=request.data)
if serializer.is_valid():
try:
# We prevent untrusted user accounts from being activated via API.
# They might be activated via Admin or programmatically.
if not user.trust_bridge.is_trusted:
raise MissingUserAttribute("User has no trusted trust bridge.")
user.create_ldap_account(serializer.validated_data["password"])
except (MissingUserAttribute, PasswordMismatch) as e:
return Response({"message": str(e)}, status=status.HTTP_400_BAD_REQUEST)
return Response(status=status.HTTP_204_NO_CONTENT)
else: else:
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return super().get_serializer_class()

View File

@ -10,6 +10,9 @@ class LDAPTestCase(TestCase):
self.password = "test12345" self.password = "test12345"
self.ldap = LDAPManager() self.ldap = LDAPManager()
def tearDown(self) -> None:
self.ldap.drop_test_connection()
def test_create_has_account(self): def test_create_has_account(self):
exists = self.ldap.has_account(self.username) exists = self.ldap.has_account(self.username)
self.assertFalse(exists) self.assertFalse(exists)