From 33f8a47a8b87874c0ce2e7080fdb43079d2e147b Mon Sep 17 00:00:00 2001 From: Andrew Buss Date: Fri, 29 May 2015 00:09:49 -0700 Subject: [PATCH] enforce 24 hour grace period on email verification --- flashcards/api.py | 17 ++++++++++++++++- flashcards/views.py | 13 +++++++------ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/flashcards/api.py b/flashcards/api.py index c32f45a..0ff5fa9 100644 --- a/flashcards/api.py +++ b/flashcards/api.py @@ -1,9 +1,11 @@ +from django.utils.timezone import now from flashcards.models import Flashcard, UserFlashcardQuiz +from rest_framework.exceptions import PermissionDenied from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import BasePermission +mock_no_params = lambda x: None -mock_no_params = lambda x:None class StandardResultsSetPagination(PageNumberPagination): page_size = 40 @@ -36,3 +38,16 @@ class IsFlashcardReviewer(BasePermission): return True assert type(obj) is UserFlashcardQuiz return request.user == obj.user_flashcard.user + + +class IsAuthenticatedAndConfirmed(BasePermission): + """ + Allows access only to authenticated users who have verified their email address, with a 24 hour grace period + """ + + def has_permission(self, request, view): + if not (request.user and request.user.is_authenticated()): return False + if request.user.confirmed_email: return True + if (now() - request.user.date_joined).days > 0: + raise PermissionDenied('Please verify your email before continuing') + return True diff --git a/flashcards/views.py b/flashcards/views.py index da4cbb6..1eaa83e 100644 --- a/flashcards/views.py +++ b/flashcards/views.py @@ -2,7 +2,8 @@ import django from django.contrib import auth from django.db import IntegrityError from django.shortcuts import get_object_or_404 -from flashcards.api import StandardResultsSetPagination, IsEnrolledInAssociatedSection, IsFlashcardReviewer +from flashcards.api import StandardResultsSetPagination, IsEnrolledInAssociatedSection, IsFlashcardReviewer, \ + IsAuthenticatedAndConfirmed from flashcards.models import Section, User, Flashcard, FlashcardHide, UserFlashcard, UserFlashcardQuiz from flashcards.notifications import notify_new_card from flashcards.serializers import SectionSerializer, UserUpdateSerializer, RegistrationSerializer, UserSerializer, \ @@ -27,7 +28,7 @@ class SectionViewSet(ReadOnlyModelViewSet): queryset = Section.objects.all() serializer_class = DeepSectionSerializer pagination_class = StandardResultsSetPagination - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticatedAndConfirmed] @detail_route(methods=['GET']) def flashcards(self, request, pk): @@ -116,7 +117,7 @@ class SectionViewSet(ReadOnlyModelViewSet): class UserSectionListView(ListAPIView): serializer_class = DeepSectionSerializer - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticatedAndConfirmed] def get_queryset(self): return self.request.user.sections.all() @@ -126,7 +127,7 @@ class UserSectionListView(ListAPIView): class UserDetail(GenericAPIView): serializer_class = UserSerializer - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticatedAndConfirmed] def patch(self, request, format=None): """ @@ -258,7 +259,7 @@ def reset_password(request, format=None): class FlashcardViewSet(GenericViewSet, CreateModelMixin, RetrieveModelMixin): queryset = Flashcard.objects.all() serializer_class = FlashcardSerializer - permission_classes = [IsAuthenticated, IsEnrolledInAssociatedSection] + permission_classes = [IsAuthenticatedAndConfirmed, IsEnrolledInAssociatedSection] # Override create in CreateModelMixin def create(self, request, *args, **kwargs): serializer = FlashcardSerializer(data=request.data) @@ -340,7 +341,7 @@ class FlashcardViewSet(GenericViewSet, CreateModelMixin, RetrieveModelMixin): class UserFlashcardQuizViewSet(GenericViewSet, CreateModelMixin, UpdateModelMixin): - permission_classes = [IsAuthenticated, IsFlashcardReviewer] + permission_classes = [IsAuthenticatedAndConfirmed, IsFlashcardReviewer] queryset = UserFlashcardQuiz.objects.all() def get_serializer_class(self): -- 1.9.1