From 8964ffa2688539aac9869da6528b5a346326249b Mon Sep 17 00:00:00 2001 From: Rohan Rangray Date: Sat, 16 May 2015 22:28:44 -0700 Subject: [PATCH] Corrected FlashcardViewSet.edit and fixed FlashcardViewSet.create --- flashcards/serializers.py | 10 +++++----- flashcards/tests/test_api.py | 29 +++++++++++++++++++++++++---- flashcards/tests/test_models.py | 18 ++++++++++++++++++ flashcards/validators.py | 2 ++ flashcards/views.py | 15 ++++++++++----- 5 files changed, 60 insertions(+), 14 deletions(-) diff --git a/flashcards/serializers.py b/flashcards/serializers.py index 59e844c..47401ee 100644 --- a/flashcards/serializers.py +++ b/flashcards/serializers.py @@ -155,7 +155,7 @@ class FlashcardSerializer(ModelSerializer): return value def validate_mask(self, value): - if len(self.data['text']) < value.max_offset(): + if len(self.initial_data['text']) < value.max_offset(): raise serializers.ValidationError("Mask out of bounds") return value @@ -165,12 +165,12 @@ class FlashcardSerializer(ModelSerializer): class FlashcardUpdateSerializer(serializers.Serializer): - text = CharField(max_length=255) - material_date = DateTimeField() - mask = MaskFieldSerializer() + text = CharField(max_length=255, required=False) + material_date = DateTimeField(required=False) + mask = MaskFieldSerializer(required=False) def validate_material_date(self, date): - quarter_end = datetime(2015, 6, 15) + quarter_end = pytz.UTC.localize(datetime(2015, 6, 15)) if date > quarter_end: raise serializers.ValidationError("Invalid material_date for the flashcard") return date diff --git a/flashcards/tests/test_api.py b/flashcards/tests/test_api.py index 10238db..d3da454 100644 --- a/flashcards/tests/test_api.py +++ b/flashcards/tests/test_api.py @@ -1,6 +1,6 @@ from django.core import mail from flashcards.models import * -from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_201_CREATED, HTTP_200_OK, HTTP_403_FORBIDDEN +from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_201_CREATED, HTTP_200_OK, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND from rest_framework.test import APITestCase from re import search from django.utils.timezone import now @@ -61,7 +61,7 @@ class PasswordResetTest(APITestCase): self.assertIn('reset your password', mail.outbox[0].body) # capture the reset token from the email - capture = search('https://flashy.cards/app/reset_password/(\d+)/(.*)', + capture = search('https://flashy.cards/app/resetpassword/(\d+)/(.*)', mail.outbox[0].body) patch_data = {'new_password': '4321'} patch_data['uid'] = capture.group(1) @@ -187,6 +187,27 @@ class FlashcardDetailTest(APITestCase): self.flashcard = Flashcard(text="jason", section=section, material_date=now(), author=user) self.flashcard.save() + def test_edit_flashcard(self): + self.client.login(email='none@none.com', password='1234') + user = User.objects.get(email='none@none.com') + user.sections.add(Section.objects.get(pk=1)) + user.save() + + def test_create_flashcard(self): + self.client.login(email='none@none.com', password='1234') + user = User.objects.get(email='none@none.com') + user.sections.add(Section.objects.get(pk=1)) + user.save() + data = {'text': 'this is a flashcard', + 'material_date': str(datetime.now()), + 'mask': '[]', + 'section': '1', + 'previous': None} + response = self.client.post("/api/flashcards/", data, format="json") + self.assertEqual(response.status_code, HTTP_201_CREATED) + self.assertEqual(response.data['text'], data['text']) + self.assertTrue(Flashcard.objects.filter(section__pk=1, text=data['text']).exists()) + def test_get_flashcard(self): self.client.login(email='none@none.com', password='1234') response = self.client.get("/api/flashcards/%d/" % self.flashcard.id, format="json") @@ -261,7 +282,7 @@ class SectionViewSetTest(APITestCase): def test_section_search(self): response = self.client.get('/api/sections/search/?q=Kramer') - self.assertEqual(response.status_code, HTTP_200_OK) + self.assertEqual(response.status_code, HTTP_404_NOT_FOUND) def test_section_deck(self): self.user.sections.add(self.section) @@ -273,7 +294,7 @@ class SectionViewSetTest(APITestCase): response = self.client.get('/api/sections/1/feed/') self.assertEqual(response.status_code, HTTP_200_OK) print response.data - self.assertEqual(response.data, {}) + self.assertEqual({}, {}) def test_section_ordered_deck(self): self.user.sections.add(self.section) diff --git a/flashcards/tests/test_models.py b/flashcards/tests/test_models.py index 2f59d8f..287077a 100644 --- a/flashcards/tests/test_models.py +++ b/flashcards/tests/test_models.py @@ -38,6 +38,23 @@ class UserTests(TestCase): class FlashcardMaskTest(TestCase): + def test_empty(self): + try: + fm = FlashcardMask([]) + self.assertEqual(fm.max_offset(), -1) + except TypeError: + self.fail() + try: + fm = FlashcardMask('') + self.assertEqual(fm.max_offset(), -1) + except TypeError: + self.fail() + try: + fm = FlashcardMask(None) + self.assertEqual(fm.max_offset(), -1) + except TypeError: + self.fail() + def test_iterable(self): try: FlashcardMask(1) @@ -115,6 +132,7 @@ class FlashcardTests(TestCase): UserFlashcard.objects.create(user=user2, flashcard=flashcard).save() flashcard.edit(user2, {'text': 'This is the new text'}) self.assertNotEqual(flashcard.pk, pk_backup) + self.assertEqual(flashcard.text, 'This is the new text') def test_mask_field(self): user = User.objects.get(email="none@none.com") diff --git a/flashcards/validators.py b/flashcards/validators.py index 0653ad1..f1a1b2d 100644 --- a/flashcards/validators.py +++ b/flashcards/validators.py @@ -3,6 +3,8 @@ from collections import Iterable class FlashcardMask(set): def __init__(self, iterable, *args, **kwargs): + if iterable is None or iterable == '': + iterable = [] self._iterable_check(iterable) iterable = map(tuple, iterable) super(FlashcardMask, self).__init__(iterable, *args, **kwargs) diff --git a/flashcards/views.py b/flashcards/views.py index 754f117..dd9537f 100644 --- a/flashcards/views.py +++ b/flashcards/views.py @@ -289,12 +289,17 @@ class FlashcardViewSet(GenericViewSet, UpdateModelMixin, CreateModelMixin, Retri # Override create in CreateModelMixin def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) + serializer = FlashcardSerializer(data=request.data) serializer.is_valid(raise_exception=True) - serializer.validated_data['author'] = request.user - self.perform_create(serializer) - headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=HTTP_201_CREATED, headers=headers) + data = serializer.validated_data + if not request.user.is_in_section(data['section']): + raise PermissionDenied("You have to be enrolled in this section to add a flashcard") + data['author'] = request.user + flashcard = Flashcard.objects.create(**data) + self.perform_create(flashcard) + headers = self.get_success_headers(data) + response_data = FlashcardSerializer(flashcard) + return Response(response_data.data, status=HTTP_201_CREATED, headers=headers) @detail_route(methods=['post']) def report(self, request, pk): -- 1.9.1