diff --git a/src/shiftings/accounts/models/user.py b/src/shiftings/accounts/models/user.py index 86d149c..4f610d7 100644 --- a/src/shiftings/accounts/models/user.py +++ b/src/shiftings/accounts/models/user.py @@ -60,10 +60,11 @@ def events(self) -> QuerySet[Event]: @property def shift_count(self) -> int: from shiftings.shifts.models import Shift - total = Shift.objects.filter(participants__user=self).count() - for claimed_user in self.claimed_org_dummy_users.all(): - total += Shift.objects.filter(participants__user=claimed_user).count() - return total + # Collect all user IDs: self + claimed dummy users + user_ids = [self.pk] + user_ids += list(self.claimed_org_dummy_users.values_list('pk', flat=True)) + # Single query for all shifts where any of these users is a participant + return Shift.objects.filter(participants__user__in=user_ids).count() def get_absolute_url(self): return reverse('user_profile') diff --git a/src/shiftings/accounts/views/auth.py b/src/shiftings/accounts/views/auth.py index 10d9ba2..daf2d98 100644 --- a/src/shiftings/accounts/views/auth.py +++ b/src/shiftings/accounts/views/auth.py @@ -15,6 +15,7 @@ from django.urls import reverse from django.utils.decorators import method_decorator from django.utils.translation import gettext_lazy as _ +from django.utils.http import url_has_allowed_host_and_scheme from django.views.decorators.cache import never_cache from django.views.generic import RedirectView @@ -81,7 +82,12 @@ def get(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: if user is None: messages.error(request, _('Error while creating the user instance!')) return HttpResponseRedirect(settings.LOGIN_URL) + redirect_to = request.GET.get(REDIRECT_FIELD_NAME) + if redirect_to and not url_has_allowed_host_and_scheme(url=redirect_to, allowed_hosts={request.get_host()}): + messages.error(request, _('The redirect url is not safe!')) + return HttpResponseRedirect(settings.LOGIN_URL) + login_user(self.request, user) messages.success(request, _('Successfully logged in!')) return HttpResponseRedirect(redirect_to or settings.LOGIN_REDIRECT_URL) diff --git a/src/shiftings/organizations/backends.py b/src/shiftings/organizations/backends.py index 4c533e7..5cefc67 100644 --- a/src/shiftings/organizations/backends.py +++ b/src/shiftings/organizations/backends.py @@ -21,13 +21,23 @@ def has_perm(self, user_obj: User, perm: str, obj: Optional[Model] = None) -> bo def _collect_permissions(self, user_obj: User, obj: Organization) -> set[str]: if user_obj.is_superuser: perms = Permission.objects.all() - else: - perms = Permission.objects.none() - for member in obj.members.all(): - if member.is_member(user_obj): - perms |= member.type.permissions.all() - perms = perms.values_list('content_type__app_label', 'codename').order_by() - return {f'{ct}.{name}' for ct, name in perms} + perms = perms.values_list('content_type__app_label', 'codename').order_by() + return {f'{ct}.{name}' for ct, name in perms} + + permissions: set[str] = set() + + members = ( + obj.members + .select_related('type') + .prefetch_related('type__permissions__content_type') + ) + + for member in members: + permissions.update( + f"{perm.content_type.app_label}.{perm.codename}" + for perm in member.type.permissions.all() + ) + return permissions def get_all_permissions(self, user_obj: User, obj: Optional[Organization] = None) -> set[str]: if not user_obj.is_active or user_obj.is_anonymous: diff --git a/src/shiftings/settings.py b/src/shiftings/settings.py index 65a21e3..14ff043 100644 --- a/src/shiftings/settings.py +++ b/src/shiftings/settings.py @@ -10,12 +10,11 @@ https://docs.djangoproject.com/en/4.0/ref/settings/ """ import os -from pathlib import Path # Build paths inside the project like this: BASE_DIR / 'subdir'. from django.urls import reverse_lazy -BASE_DIR = Path(__file__).resolve().parent.parent +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/4.0/howto/deployment/checklist/ @@ -115,7 +114,7 @@ DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': BASE_DIR.parent / 'test_db' / 'db.sqlite3', + 'NAME': os.path.join(os.path.dirname(BASE_DIR), 'test_db', 'db.sqlite3'), } } @@ -175,12 +174,12 @@ ['en', 'English'], ] USE_I18N = True -LOCALE_PATHS = [ - BASE_DIR / 'locale', -] +LOCALE_PATHS : tuple[str] = (os.path.join(BASE_DIR, 'locale'),) -TIME_ZONE = 'UTC' -USE_TZ = False +# Time zone settings +# https://docs.djangoproject.com/en/6.0/topics/i18n/timezones/ +TIME_ZONE : str = 'UTC' +USE_TZ : bool = False # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.0/howto/static-files/ diff --git a/src/shiftings/shifts/forms/shift.py b/src/shiftings/shifts/forms/shift.py index 147e851..c0df930 100644 --- a/src/shiftings/shifts/forms/shift.py +++ b/src/shiftings/shifts/forms/shift.py @@ -40,7 +40,7 @@ def clean(self) -> Dict[str, Any]: ## TODO: raise form error if not valid, but first implement proper error display in template max_length = timedelta(minutes=settings.MAX_SHIFT_LENGTH_MINUTES) - if end - start > max_length: + if start and end and end - start > max_length: self.add_error('end', ValidationError(_('Shift is too long, can at most be {max} long').format( max=localize_timedelta(max_length) ))) diff --git a/src/shiftings/shifts/models/recurring.py b/src/shiftings/shifts/models/recurring.py index df62e63..c4d1384 100644 --- a/src/shiftings/shifts/models/recurring.py +++ b/src/shiftings/shifts/models/recurring.py @@ -151,7 +151,7 @@ def create_shifts(self, _date: date) -> None: if _date in holidays.country_holidays(holiday.get('country'), holiday.get('region')): if self.holiday_handling is ProblemHandling.Cancel: return None - elif self.weekend_handling is ProblemHandling.Warn and self.holiday_warning is not None: + elif self.holiday_handling is ProblemHandling.Warn and self.holiday_warning is not None: holiday_warning = self.holiday_warning shifts = self.template.get_shift_objs(_date, weekend_warning, holiday_warning) diff --git a/src/shiftings/shifts/models/shift.py b/src/shiftings/shifts/models/shift.py index d9aabd3..1a55097 100644 --- a/src/shiftings/shifts/models/shift.py +++ b/src/shiftings/shifts/models/shift.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import cached_property from typing import Optional, TYPE_CHECKING, Union from django.contrib.contenttypes.fields import GenericRelation @@ -77,31 +78,31 @@ def time_display(self) -> str: start=self.start.strftime('%H:%M'), end_time=self.end.strftime('%H:%M'), ) + @cached_property + def participant_count(self) -> int: + return self.participants.all().count() + @property def is_full(self) -> bool: - return self.max_users != 0 and self.participants.all().count() >= self.max_users + return self.max_users != 0 and self.participant_count >= self.max_users @property def participants_missing(self) -> int: - if self.max_users == 0: - return 0 - return max(self.max_users - self.participants.all().count(), 0) + return max(self.max_users - self.participant_count, 0) @property def has_required(self) -> bool: - return self.participants.all().count() >= self.required_users + return self.participant_count >= self.required_users @property def required_participants_missing(self) -> int: - if self.required_users == 0: - return 0 - return max(self.required_users - self.participants.all().count(), 0) + return max(self.required_users - self.participant_count, 0) @property def confirmed_participants(self) -> Optional[int]: if not self.organization.confirm_participation_active: return None - return self.participants.all().count() - self.participants.filter(confirmed=True).count() + return self.participants.filter(confirmed=True).count() @property def email(self) -> str: diff --git a/src/shiftings/shifts/utils/time_frame.py b/src/shiftings/shifts/utils/time_frame.py index 7696844..a273df6 100644 --- a/src/shiftings/shifts/utils/time_frame.py +++ b/src/shiftings/shifts/utils/time_frame.py @@ -69,7 +69,7 @@ def _matches_every_nth_weekday(shift: RecurringShift, _date: date) -> bool: return True if _date == shift.first_occurrence: return True - weeks = (_date - shift.first_occurrence).days / 7 + weeks = (_date - shift.first_occurrence).days // 7 return weeks % shift.ordinal == 0 @staticmethod diff --git a/src/shiftings/utils/views/base.py b/src/shiftings/utils/views/base.py index e32ff0b..dfe8068 100644 --- a/src/shiftings/utils/views/base.py +++ b/src/shiftings/utils/views/base.py @@ -6,7 +6,7 @@ from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.db.models import Model from django.http import Http404, HttpRequest, HttpResponse, HttpResponseNotFound, HttpResponseRedirect -from django.utils.http import urlencode +from django.utils.http import url_has_allowed_host_and_scheme, urlencode from django.utils.translation import gettext_lazy as _ from django.views.generic.base import ContextMixin @@ -93,9 +93,13 @@ def get_success_url(self) -> str: @property def success(self) -> HttpResponse: if 'success_url' in self.request.POST: - return HttpResponseRedirect(str(self.request.POST['success_url'])) + success_url : str = self.request.POST['success_url'] + if url_has_allowed_host_and_scheme(url=success_url, allowed_hosts={self.request.get_host()}): + return HttpResponseRedirect(success_url) + else: + raise Http403(_('The provided success url is not allowed. This might be a configuration error.')) return HttpResponseRedirect(self.get_success_url()) - + def get_fail_url(self) -> Optional[str]: return self.fail_url