diff --git a/flashcards/fields.py b/flashcards/fields.py index 4569e74..2fe841b 100644 --- a/flashcards/fields.py +++ b/flashcards/fields.py @@ -1,4 +1,5 @@ from django.db import models +from validators import FlashcardMask, OverlapIntervalException class MaskField(models.Field): @@ -37,16 +38,14 @@ class MaskField(models.Field): return ','.join(['-'.join(map(str, i)) for i in value]) def to_python(self, value): - if value is None or isinstance(value, set): + if value is None: return value - return MaskField._parse_mask(value) + return sorted(list(FlashcardMask(value))) def get_prep_value(self, value): if value is None: return value - if not isinstance(value, set) or not all([isinstance(interval, tuple) for interval in value]): - raise ValueError("Invalid value for MaskField attribute") - return self.__class__._parse_mask(sorted(value)) + return sorted(list(FlashcardMask(value))) def get_prep_lookup(self, lookup_type, value): raise TypeError("Lookup not supported for MaskField") @@ -65,15 +64,9 @@ class MaskField(models.Field): @classmethod def _varchar_parse_mask(cls, value): - intervals = [] - ranges = value.split(',') - for interval in ranges: - _range = interval.split('-') - if len(_range) != 2 or not all(map(unicode.isdigit, _range)): - raise ValueError("Invalid range format.") - intervals.append(tuple(_range)) - return set([tuple(i) for i in cls._parse_mask(sorted(intervals))]) + mask = [tuple(map(int, i.split('-'))) for i in value.split(',')] + return FlashcardMask(mask) @classmethod def _psql_parse_mask(cls, value): - return set([tuple(i) for i in cls._parse_mask(sorted(value))]) + return FlashcardMask(value) diff --git a/flashcards/tests/test_models.py b/flashcards/tests/test_models.py index 6a64b06..5028cc9 100644 --- a/flashcards/tests/test_models.py +++ b/flashcards/tests/test_models.py @@ -2,6 +2,7 @@ from datetime import datetime from django.test import TestCase from flashcards.models import User, Section, Flashcard +from flashcards.validators import OverlapIntervalException class RegistrationTests(TestCase): @@ -68,5 +69,5 @@ class FlashcardTests(TestCase): previous=None, mask={(10,34), (0, 14)}) self.fail() - except ValueError: + except OverlapIntervalException: self.assertTrue(True) diff --git a/flashcards/validators.py b/flashcards/validators.py new file mode 100644 index 0000000..81d8417 --- /dev/null +++ b/flashcards/validators.py @@ -0,0 +1,36 @@ +__author__ = 'rray' + +from collections import Iterable + + +class FlashcardMask(set): + def __init__(self, *args, **kwargs): + super(FlashcardMask, self).__init__(*args, **kwargs) + self._iterable_check() + self._interval_check() + self._overlap_check() + + def _iterable_check(self): + if not all([isinstance(i, Iterable) for i in self]): + raise TypeError("Interval not a valid iterable") + + def _interval_check(self): + if not all([len(i) == 2 for i in self]): + raise TypeError("Intervals must have exactly 2 elements, begin and end") + + def _overlap_check(self): + p_beg, p_end = -1, -1 + for interval in sorted(self): + beg, end = map(int, interval) + if not (0 <= beg <= 255) or not (0 <= end <= 255) or not (beg <= end) or not (beg > p_end): + raise OverlapIntervalException((beg, end), "Invalid interval offsets in the mask") + p_beg, p_end = beg, end + + +class OverlapIntervalException(Exception): + def __init__(self, interval, reason): + self.interval = interval + self.reason = reason + + def __str__(self): + return repr(self.reason) + ': ' + repr(self.interval) diff --git a/flashcards/views.py b/flashcards/views.py index 355f3d0..b7c89f6 100644 --- a/flashcards/views.py +++ b/flashcards/views.py @@ -273,7 +273,7 @@ class FlashcardViewSet(GenericViewSet, CreateModelMixin, RetrieveModelMixin): :return: A 204 response upon success. """ user = request.user - flashcard = Flashcard.objects.get(pk=pk) + flashcard = self.get_object() user_card, created = UserFlashcard.objects.get_or_create(user=user, flashcard=flashcard) user_card.save() return Response(status=HTTP_204_NO_CONTENT)