Skip to content
9 changes: 5 additions & 4 deletions src/shiftings/accounts/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ def events(self) -> QuerySet[Event]:
@property
def shift_count(self) -> int:
from shiftings.shifts.models import Shift
total = Shift.objects.filter(participants__user=self).count()
for claimed_user in self.claimed_org_dummy_users.all():
total += Shift.objects.filter(participants__user=claimed_user).count()
return total
# Collect all user IDs: self + claimed dummy users
user_ids = [self.pk]
user_ids += list(self.claimed_org_dummy_users.values_list('pk', flat=True))
# Single query for all shifts where any of these users is a participant
return Shift.objects.filter(participants__user__in=user_ids).count()

def get_absolute_url(self):
return reverse('user_profile')
6 changes: 6 additions & 0 deletions src/shiftings/accounts/views/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.urls import reverse
from django.utils.decorators import method_decorator
from django.utils.translation import gettext_lazy as _
from django.utils.http import url_has_allowed_host_and_scheme
from django.views.decorators.cache import never_cache
from django.views.generic import RedirectView

Expand Down Expand Up @@ -81,7 +82,12 @@ def get(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
if user is None:
messages.error(request, _('Error while creating the user instance!'))
return HttpResponseRedirect(settings.LOGIN_URL)

redirect_to = request.GET.get(REDIRECT_FIELD_NAME)
if redirect_to and not url_has_allowed_host_and_scheme(url=redirect_to, allowed_hosts={request.get_host()}):
messages.error(request, _('The redirect url is not safe!'))
return HttpResponseRedirect(settings.LOGIN_URL)

login_user(self.request, user)
messages.success(request, _('Successfully logged in!'))
return HttpResponseRedirect(redirect_to or settings.LOGIN_REDIRECT_URL)
Expand Down
24 changes: 17 additions & 7 deletions src/shiftings/organizations/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,23 @@ def has_perm(self, user_obj: User, perm: str, obj: Optional[Model] = None) -> bo
def _collect_permissions(self, user_obj: User, obj: Organization) -> set[str]:
if user_obj.is_superuser:
perms = Permission.objects.all()
else:
perms = Permission.objects.none()
for member in obj.members.all():
if member.is_member(user_obj):
perms |= member.type.permissions.all()
perms = perms.values_list('content_type__app_label', 'codename').order_by()
return {f'{ct}.{name}' for ct, name in perms}
perms = perms.values_list('content_type__app_label', 'codename').order_by()
return {f'{ct}.{name}' for ct, name in perms}

permissions: set[str] = set()

members = (
obj.members
.select_related('type')
.prefetch_related('type__permissions__content_type')
)

for member in members:
permissions.update(
f"{perm.content_type.app_label}.{perm.codename}"
for perm in member.type.permissions.all()
)
return permissions

def get_all_permissions(self, user_obj: User, obj: Optional[Organization] = None) -> set[str]:
if not user_obj.is_active or user_obj.is_anonymous:
Expand Down
15 changes: 7 additions & 8 deletions src/shiftings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
https://docs.djangoproject.com/en/4.0/ref/settings/
"""
import os
from pathlib import Path

# Build paths inside the project like this: BASE_DIR / 'subdir'.
from django.urls import reverse_lazy

BASE_DIR = Path(__file__).resolve().parent.parent
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.0/howto/deployment/checklist/
Expand Down Expand Up @@ -115,7 +114,7 @@
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': BASE_DIR.parent / 'test_db' / 'db.sqlite3',
'NAME': os.path.join(os.path.dirname(BASE_DIR), 'test_db', 'db.sqlite3'),
}
}

Expand Down Expand Up @@ -175,12 +174,12 @@
['en', 'English'],
]
USE_I18N = True
LOCALE_PATHS = [
BASE_DIR / 'locale',
]
LOCALE_PATHS : tuple[str] = (os.path.join(BASE_DIR, 'locale'),)

TIME_ZONE = 'UTC'
USE_TZ = False
# Time zone settings
# https://docs.djangoproject.com/en/6.0/topics/i18n/timezones/
TIME_ZONE : str = 'UTC'
USE_TZ : bool = False

# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.0/howto/static-files/
Expand Down
2 changes: 1 addition & 1 deletion src/shiftings/shifts/forms/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def clean(self) -> Dict[str, Any]:

## TODO: raise form error if not valid, but first implement proper error display in template
max_length = timedelta(minutes=settings.MAX_SHIFT_LENGTH_MINUTES)
if end - start > max_length:
if start and end and end - start > max_length:
self.add_error('end', ValidationError(_('Shift is too long, can at most be {max} long').format(
max=localize_timedelta(max_length)
)))
Expand Down
2 changes: 1 addition & 1 deletion src/shiftings/shifts/models/recurring.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def create_shifts(self, _date: date) -> None:
if _date in holidays.country_holidays(holiday.get('country'), holiday.get('region')):
if self.holiday_handling is ProblemHandling.Cancel:
return None
elif self.weekend_handling is ProblemHandling.Warn and self.holiday_warning is not None:
elif self.holiday_handling is ProblemHandling.Warn and self.holiday_warning is not None:
holiday_warning = self.holiday_warning

shifts = self.template.get_shift_objs(_date, weekend_warning, holiday_warning)
Expand Down
19 changes: 10 additions & 9 deletions src/shiftings/shifts/models/shift.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import cached_property
from typing import Optional, TYPE_CHECKING, Union

from django.contrib.contenttypes.fields import GenericRelation
Expand Down Expand Up @@ -77,31 +78,31 @@ def time_display(self) -> str:
start=self.start.strftime('%H:%M'),
end_time=self.end.strftime('%H:%M'), )

@cached_property
def participant_count(self) -> int:
return self.participants.all().count()

@property
def is_full(self) -> bool:
return self.max_users != 0 and self.participants.all().count() >= self.max_users
return self.max_users != 0 and self.participant_count >= self.max_users

@property
def participants_missing(self) -> int:
if self.max_users == 0:
return 0
return max(self.max_users - self.participants.all().count(), 0)
return max(self.max_users - self.participant_count, 0)

@property
def has_required(self) -> bool:
return self.participants.all().count() >= self.required_users
return self.participant_count >= self.required_users

@property
def required_participants_missing(self) -> int:
if self.required_users == 0:
return 0
return max(self.required_users - self.participants.all().count(), 0)
return max(self.required_users - self.participant_count, 0)

@property
def confirmed_participants(self) -> Optional[int]:
if not self.organization.confirm_participation_active:
return None
return self.participants.all().count() - self.participants.filter(confirmed=True).count()
return self.participants.filter(confirmed=True).count()

@property
def email(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/shiftings/shifts/utils/time_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _matches_every_nth_weekday(shift: RecurringShift, _date: date) -> bool:
return True
if _date == shift.first_occurrence:
return True
weeks = (_date - shift.first_occurrence).days / 7
weeks = (_date - shift.first_occurrence).days // 7
return weeks % shift.ordinal == 0

@staticmethod
Expand Down
10 changes: 7 additions & 3 deletions src/shiftings/utils/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
from django.db.models import Model
from django.http import Http404, HttpRequest, HttpResponse, HttpResponseNotFound, HttpResponseRedirect
from django.utils.http import urlencode
from django.utils.http import url_has_allowed_host_and_scheme, urlencode
from django.utils.translation import gettext_lazy as _
from django.views.generic.base import ContextMixin

Expand Down Expand Up @@ -93,9 +93,13 @@ def get_success_url(self) -> str:
@property
def success(self) -> HttpResponse:
if 'success_url' in self.request.POST:
return HttpResponseRedirect(str(self.request.POST['success_url']))
success_url : str = self.request.POST['success_url']
if url_has_allowed_host_and_scheme(url=success_url, allowed_hosts={self.request.get_host()}):
return HttpResponseRedirect(success_url)
else:
raise Http403(_('The provided success url is not allowed. This might be a configuration error.'))
return HttpResponseRedirect(self.get_success_url())

def get_fail_url(self) -> Optional[str]:
return self.fail_url

Expand Down