diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index f09acb3..9a8420a 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from .authentication import Mode +from .authentication import ExternalAuthenticator, InternalAuthenticator from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse from .database.core import purge_expired @@ -179,12 +179,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server for spec in authentication["providers"]: provider = spec["provider"] authenticator = spec["authenticator"] - mode = authenticator.mode - if mode == Mode.password: + if isinstance(authenticator, InternalAuthenticator): authentication_router.post(f"/provider/{provider}/token")( build_handle_credentials_route(authenticator, provider) ) - elif mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) @@ -192,7 +191,7 @@ def build_app(authentication=None, api_access=None, resource_access=None, server build_auth_code_route(authenticator, provider) ) else: - raise ValueError(f"unknown authentication mode {mode}") + raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): authentication_router.include_router(custom_router, prefix=f"/provider/{provider}") diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/authentication.py index 673ca29..c7ff0a7 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/authentication.py @@ -31,6 +31,11 @@ from pydantic_settings import BaseSettings from . import schemas +from .authentication.authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm @@ -54,12 +59,6 @@ def utcnow(): "UTC now with second resolution" return datetime.utcnow().replace(microsecond=0) - -class Mode(enum.Enum): - password = "password" - external = "external" - - class Token(BaseModel): access_token: str token_type: str @@ -421,7 +420,8 @@ async def auth_code( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(request) + user_session_state = await authenticator.authenticate(request) + username = user_session_state.user_name if user_session_state else None if username and api_access_manager.is_user_known(username): scopes = api_access_manager.get_user_scopes(username) @@ -450,7 +450,8 @@ async def handle_credentials( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(username=form_data.username, password=form_data.password) + user_session_state = await authenticator.authenticate(username=form_data.username, password=form_data.password) + username = user_session_state.user_name if user_session_state else None err_msg = None if not username: diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py new file mode 100644 index 0000000..58c758f --- /dev/null +++ b/bluesky_httpserver/authentication/__init__.py @@ -0,0 +1,11 @@ +from .authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) + +__all__ = [ + "ExternalAuthenticator", + "InternalAuthenticator", + "UserSessionState", +] diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py new file mode 100644 index 0000000..7a2cff3 --- /dev/null +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -0,0 +1,39 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Optional + +from fastapi import Request + + +@dataclass +class UserSessionState: + """Data transfer class to communicate custom session state information.""" + + user_name: str + state: dict = None + + +class InternalAuthenticator(ABC): + """ + Base class for authenticators that use username/password credentials. + + Subclasses must implement the authenticate method which takes a username + and password and returns a UserSessionState on success or None on failure. + """ + + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: + raise NotImplementedError + + +class ExternalAuthenticator(ABC): + """ + Base class for authenticators that use external identity providers. + + Subclasses must implement the authenticate method which takes a FastAPI + Request object and returns a UserSessionState on success or None on failure. + """ + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 61c2da4..3b439f4 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -1,21 +1,32 @@ import asyncio +import base64 import functools import logging import re import secrets from collections.abc import Iterable +from datetime import timedelta +from typing import Any, List, Mapping, Optional, cast +import httpx +from cachetools import TTLCache, cached from fastapi import APIRouter, Request -from jose import JWTError, jwk, jwt +from fastapi.security import OAuth2, OAuth2AuthorizationCodeBearer +from jose import JWTError, jwt +from pydantic import Secret from starlette.responses import RedirectResponse -from .authentication import Mode -from .utils import modules_available +from .authentication import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) +from .utils import get_root_url, modules_available logger = logging.getLogger(__name__) -class DummyAuthenticator: +class DummyAuthenticator(InternalAuthenticator): """ For test and demo purposes only! @@ -23,26 +34,20 @@ class DummyAuthenticator: """ - mode = Mode.password + def __init__(self, confirmation_message: str = ""): + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): - return username + async def authenticate(self, username: str, password: str) -> UserSessionState: + return UserSessionState(username, {}) -class DictionaryAuthenticator: +class DictionaryAuthenticator(InternalAuthenticator): """ For test and demo purposes only! Check passwords from a dictionary of usernames mapped to passwords. - - Parameters - ---------- - - users_to_passwords: dict(str, str) - Mapping of usernames to passwords. """ - mode = Mode.password configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -50,25 +55,32 @@ class DictionaryAuthenticator: properties: users_to_password: type: object - description: | - Mapping usernames to password. Environment variable expansion should be - used to avoid placing passwords directly in configuration. + description: | + Mapping usernames to password. Environment variable expansion should be + used to avoid placing passwords directly in configuration. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, users_to_passwords): + def __init__( + self, users_to_passwords: Mapping[str, str], confirmation_message: str = "" + ): self._users_to_passwords = users_to_passwords + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: true_password = self._users_to_passwords.get(username) if not true_password: # Username is not valid. - return + return None if secrets.compare_digest(true_password, password): - return username + return UserSessionState(username, {}) -class PAMAuthenticator: - mode = Mode.password +class PAMAuthenticator(InternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -77,90 +89,149 @@ class PAMAuthenticator: service: type: string description: PAM service. Default is 'login'. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, service="login"): + def __init__(self, service: str = "login", confirmation_message: str = ""): if not modules_available("pamela"): - raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") + raise ModuleNotFoundError( + "This PAMAuthenticator requires the module 'pamela' to be installed." + ) self.service = service + self.confirmation_message = confirmation_message # TODO Try to open a PAM session. - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: import pamela try: pamela.authenticate(username, password, service=self.service) + return UserSessionState(username, {}) except pamela.PAMError: # Authentication failed. - return - else: - return username + return None -class OIDCAuthenticator: - mode = Mode.external +class OIDCAuthenticator(ExternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object additionalProperties: false properties: + audience: + type: string client_id: type: string client_secret: type: string - redirect_uri: + well_known_uri: type: string - token_uri: + confirmation_message: type: string - authorization_endpoint: + redirect_on_success: + type: string + redirect_on_failure: type: string - public_keys: - type: array - item: - type: object - properties: - - alg: - type: string - - e - type: string - - kid - type: string - - kty - type: string - - n - type: string - - use - type: string - required: - - alg - - e - - kid - - kty - - n - - use """ def __init__( self, - client_id, - client_secret, - redirect_uri, - public_keys, - token_uri, - authorization_endpoint, - confirmation_message, + audience: str, + client_id: str, + client_secret: str, + well_known_uri: str, + confirmation_message: str = "", + redirect_on_success: Optional[str] = None, + redirect_on_failure: Optional[str] = None, ): - self.client_id = client_id - self.client_secret = client_secret + self._audience = audience + self._client_id = client_id + self._client_secret = Secret(client_secret) + self._well_known_url = well_known_uri self.confirmation_message = confirmation_message - self.redirect_uri = redirect_uri - self.public_keys = public_keys - self.token_uri = token_uri - self.authorization_endpoint = authorization_endpoint.format(client_id=client_id, redirect_uri=redirect_uri) - - async def authenticate(self, request): - code = request.query_params["code"] - response = await exchange_code(self.token_uri, code, self.client_id, self.client_secret, self.redirect_uri) + self.redirect_on_success = redirect_on_success + self.redirect_on_failure = redirect_on_failure + + @functools.cached_property + def _config_from_oidc_url(self) -> dict[str, Any]: + response: httpx.Response = httpx.get(self._well_known_url) + response.raise_for_status() + return response.json() + + @functools.cached_property + def client_id(self) -> str: + return self._client_id + + @functools.cached_property + def id_token_signing_alg_values_supported(self) -> list[str]: + return cast( + list[str], + self._config_from_oidc_url.get("id_token_signing_alg_values_supported"), + ) + + @functools.cached_property + def issuer(self) -> str: + return cast(str, self._config_from_oidc_url.get("issuer")) + + @functools.cached_property + def jwks_uri(self) -> str: + return cast(str, self._config_from_oidc_url.get("jwks_uri")) + + @functools.cached_property + def token_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("token_endpoint")) + + @functools.cached_property + def authorization_endpoint(self) -> httpx.URL: + return httpx.URL( + cast(str, self._config_from_oidc_url.get("authorization_endpoint")) + ) + + @functools.cached_property + def device_authorization_endpoint(self) -> str: + return cast( + str, self._config_from_oidc_url.get("device_authorization_endpoint") + ) + + @functools.cached_property + def end_session_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) + + @cached(TTLCache(maxsize=1, ttl=timedelta(days=7).total_seconds())) + def keys(self) -> List[str]: + return httpx.get(self.jwks_uri).raise_for_status().json().get("keys", []) + + def decode_token(self, token: str) -> dict[str, Any]: + return jwt.decode( + token, + key=self.keys(), + algorithms=self.id_token_signing_alg_values_supported, + audience=self._audience, + issuer=self.issuer, + ) + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + code = request.query_params.get("code") + if not code: + logger.warning( + "Authentication failed: No authorization code parameter provided." + ) + return None + # A proxy in the middle may make the request into something like + # 'http://localhost:8000/...' so we fix the first part but keep + # the original URI path. + redirect_uri = f"{get_root_url(request)}{request.url.path}" + response = await exchange_code( + self.token_endpoint, + code, + self._client_id, + self._client_secret.get_secret_value(), + redirect_uri, + ) response_body = response.json() if response.is_error: logger.error("Authentication error: %r", response_body) @@ -168,63 +239,84 @@ async def authenticate(self, request): response_body = response.json() id_token = response_body["id_token"] access_token = response_body["access_token"] - # Match the kid in id_token to a key in the list of public_keys. - key = find_key(id_token, self.public_keys) try: - verified_body = jwt.decode(id_token, key, access_token=access_token, audience=self.client_id) + verified_body = self.decode_token(access_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", jwt.get_unverified_claims(id_token), ) return None - return verified_body["sub"] + return UserSessionState(verified_body["sub"], {}) -class KeyNotFoundError(Exception): - pass - - -def find_key(token, keys): - """ - Find a key from the configured keys based on the kid claim of the token - - Parameters - ---------- - token : token to search for the kid from - keys: list of keys - - Raises - ------ - KeyNotFoundError: - returned if the token does not have a kid claim - - Returns - ------ - key: found key object - """ +class ProxiedOIDCAuthenticator(OIDCAuthenticator): + configuration_schema = """ +$schema": http://json-schema.org/draft-07/schema# +type: object +additionalProperties: false +properties: + audience: + type: string + client_id: + type: string + well_known_uri: + type: string + scopes: + type: array + items: + type: string + description: | + Optional list of OAuth2 scopes to request. If provided, authorization + should be enforced by an external policy agent (for example ExternalPolicyDecisionPoint) + rather than by this authenticator. + device_flow_client_id: + type: string + confirmation_message: + type: string +""" - unverified = jwt.get_unverified_header(token) - kid = unverified.get("kid") - if not kid: - raise KeyNotFoundError("No 'kid' in token") + def __init__( + self, + audience: str, + client_id: str, + well_known_uri: str, + device_flow_client_id: str, + scopes: Optional[List[str]] = None, + confirmation_message: str = "", + ): + super().__init__( + audience=audience, + client_id=client_id, + client_secret="", + well_known_uri=well_known_uri, + confirmation_message=confirmation_message, + ) + self.scopes = scopes + self.device_flow_client_id = device_flow_client_id + self._oidc_bearer = OAuth2AuthorizationCodeBearer( + authorizationUrl=str(self.authorization_endpoint), + tokenUrl=self.token_endpoint, + ) - for key in keys: - if key["kid"] == kid: - return jwk.construct(key) - return KeyNotFoundError(f"Token specifies {kid} but we have {[k['kid'] for k in keys]}") + @property + def oauth2_schema(self) -> OAuth2: + return self._oidc_bearer -async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri): +async def exchange_code( + token_uri: str, + auth_code: str, + client_id: str, + client_secret: str, + redirect_uri: str, +) -> httpx.Response: """Method that talks to an IdP to exchange a code for an access_token and/or id_token Args: token_url ([type]): [description] auth_code ([type]): [description] """ - if not modules_available("httpx"): - raise ModuleNotFoundError("This authenticator requires 'httpx'. (pip install httpx)") - import httpx - + auth_value = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() response = httpx.post( url=token_uri, data={ @@ -234,18 +326,18 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect "code": auth_code, "client_secret": client_secret, }, + headers={"Authorization": f"Basic {auth_value}"}, ) return response -class SAMLAuthenticator: - mode = Mode.external +class SAMLAuthenticator(ExternalAuthenticator): def __init__( self, saml_settings, # See EXAMPLE_SAML_SETTINGS below. - attribute_name, # which SAML attribute to use as 'id' for Idenity - confirmation_message=None, + attribute_name: str, # which SAML attribute to use as 'id' for Identity + confirmation_message: str = "", ): self.saml_settings = saml_settings self.attribute_name = attribute_name @@ -258,30 +350,26 @@ def __init__( # The PyPI package name is 'python3-saml' # but it imports as 'onelogin'. # https://github.com/onelogin/python3-saml - raise ModuleNotFoundError("This SAMLAuthenticator requires 'python3-saml' to be installed.") + raise ModuleNotFoundError( + "This SAMLAuthenticator requires 'python3-saml' to be installed." + ) from onelogin.saml2.auth import OneLogin_Saml2_Auth @router.get("/login") - async def saml_login(request: Request): + async def saml_login(request: Request) -> RedirectResponse: req = await prepare_saml_from_fastapi_request(request) auth = OneLogin_Saml2_Auth(req, self.saml_settings) - # saml_settings = auth.get_settings() - # metadata = saml_settings.get_sp_metadata() - # errors = saml_settings.validate_metadata(metadata) - # if len(errors) == 0: - # print(metadata) - # else: - # print("Error found on Metadata: %s" % (', '.join(errors))) callback_url = auth.login() - response = RedirectResponse(url=callback_url) - return response + return RedirectResponse(url=callback_url) self.include_routers = [router] - async def authenticate(self, request): + async def authenticate(self, request: Request) -> Optional[UserSessionState]: if not modules_available("onelogin"): - raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") + raise ModuleNotFoundError( + "This SAMLAuthenticator requires the module 'oneline' to be installed." + ) from onelogin.saml2.auth import OneLogin_Saml2_Auth req = await prepare_saml_from_fastapi_request(request, True) @@ -290,26 +378,27 @@ async def authenticate(self, request): errors = auth.get_errors() # This method receives an array with the errors if errors: raise Exception( - "Error when processing SAML Response: %s %s" % (", ".join(errors), auth.get_last_error_reason()) + "Error when processing SAML Response: %s %s" + % (", ".join(errors), auth.get_last_error_reason()) ) if auth.is_authenticated(): # Return a string that the Identity can use as id. attribute_as_list = auth.get_attributes()[self.attribute_name] # Confused in what situation this would have more than one item.... assert len(attribute_as_list) == 1 - return attribute_as_list[0] + return UserSessionState(attribute_as_list[0], {}) else: return None -async def prepare_saml_from_fastapi_request(request, debug=False): +async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, str]: form_data = await request.form() rv = { "http_host": request.client.host, "server_port": request.url.port, "script_name": request.url.path, "post_data": {}, - "get_data": {}, + "get_data": {} # Advanced request options # "https": "", # "request_uri": "", @@ -328,7 +417,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False): return rv -class LDAPAuthenticator: +class LDAPAuthenticator(InternalAuthenticator): """ LDAP authenticator. The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator @@ -472,6 +561,8 @@ class LDAPAuthenticator: This can be useful in an heterogeneous environment, when supplying a UNIX username to authenticate against AD. + confirmation_message: str + May be displayed by client after successful login. Examples -------- @@ -510,8 +601,6 @@ class LDAPAuthenticator: id: user02 """ - mode = Mode.password - def __init__( self, server_address, @@ -536,6 +625,7 @@ def __init__( attributes=None, auth_state_attributes=None, use_lookup_dn_username=True, + confirmation_message="", ): self.use_ssl = use_ssl self.use_tls = use_tls @@ -554,7 +644,9 @@ def __init__( self.escape_userdn = escape_userdn self.search_filter = search_filter self.attributes = attributes if attributes else [] - self.auth_state_attributes = auth_state_attributes if auth_state_attributes else [] + self.auth_state_attributes = ( + auth_state_attributes if auth_state_attributes else [] + ) self.use_lookup_dn_username = use_lookup_dn_username if isinstance(server_address, str): @@ -567,10 +659,15 @@ def __init__( f"type(server_address)={type(server_address)}" ) if not server_address_list: - raise ValueError("No servers are specified: 'server_address' is an empty list") + raise ValueError( + "No servers are specified: 'server_address' is an empty list" + ) self.server_address_list = server_address_list - self.server_port = server_port if server_port is not None else self._server_port_default() + self.server_port = ( + server_port if server_port is not None else self._server_port_default() + ) + self.confirmation_message = confirmation_message def _server_port_default(self): if self.use_ssl: @@ -623,8 +720,15 @@ async def resolve_username(self, username_supplied_by_user): response = conn.response if len(response) == 0 or "attributes" not in response[0].keys(): - msg = "No entry found for user '{username}' " "when looking up attribute '{attribute}'" - logger.warning(msg.format(username=username_supplied_by_user, attribute=self.user_attribute)) + msg = ( + "No entry found for user '{username}' " + "when looking up attribute '{attribute}'" + ) + logger.warning( + msg.format( + username=username_supplied_by_user, attribute=self.user_attribute + ) + ) return (None, None) user_dn = response[0]["attributes"][self.lookup_dn_user_dn_attribute] @@ -655,7 +759,7 @@ async def resolve_username(self, username_supplied_by_user): def get_connection(self, userdn, password): import ldap3 - # NOTE: setting 'acitve=False' essentially disables exclusion of inactive servers from the pool. + # NOTE: setting 'active=False' essentially disables exclusion of inactive servers from the pool. # It probably does not matter if the pool contains only one server, but it could have implications # when there are multiple servers in the pool. It is not clear what those implications are. # But using the default 'activate=True' results in the thread being blocked indefinitely @@ -675,14 +779,23 @@ def get_connection(self, userdn, password): server_port = self.server_port server = ldap3.Server( - server_addr, port=server_port, use_ssl=self.use_ssl, connect_timeout=self.connect_timeout + server_addr, + port=server_port, + use_ssl=self.use_ssl, + connect_timeout=self.connect_timeout, ) server_pool.add(server) - auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS + auto_bind_no_ssl = ( + ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS + ) auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl conn = ldap3.Connection( - server_pool, user=userdn, password=password, auto_bind=auto_bind, receive_timeout=self.receive_timeout + server_pool, + user=userdn, + password=password, + auto_bind=auto_bind, + receive_timeout=self.receive_timeout, ) return conn @@ -690,14 +803,19 @@ async def get_user_attributes(self, conn, userdn): attrs = {} if self.auth_state_attributes: search_func = functools.partial( - conn.search, userdn, "(objectClass=*)", attributes=self.auth_state_attributes + conn.search, + userdn, + "(objectClass=*)", + attributes=self.auth_state_attributes, ) found = await asyncio.get_running_loop().run_in_executor(None, search_func) if found: attrs = conn.entries[0].entry_attributes_as_dict return attrs - async def authenticate(self, username: str, password: str): + async def authenticate( + self, username: str, password: str + ) -> Optional[UserSessionState]: import ldap3 username_saved = username # Save the user name passed as a parameter @@ -723,7 +841,9 @@ async def authenticate(self, username: str, password: str): # sanity check if not self.lookup_dn and not bind_dn_template: - logger.warning("Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'.") + logger.warning( + "Login not allowed, please configure 'lookup_dn' or 'bind_dn_template'." + ) return None if self.lookup_dn: @@ -761,7 +881,9 @@ async def authenticate(self, username: str, password: str): if conn.bound: is_bound = True else: - is_bound = await asyncio.get_running_loop().run_in_executor(None, conn.bind) + is_bound = await asyncio.get_running_loop().run_in_executor( + None, conn.bind + ) msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) logger.debug(msg) @@ -774,7 +896,9 @@ async def authenticate(self, username: str, password: str): return None if self.search_filter: - search_filter = self.search_filter.format(userattr=self.user_attribute, username=username) + search_filter = self.search_filter.format( + userattr=self.user_attribute, username=username + ) search_func = functools.partial( conn.search, @@ -788,18 +912,33 @@ async def authenticate(self, username: str, password: str): n_users = len(conn.response) if n_users == 0: msg = "User with '{userattr}={username}' not found in directory" - logger.warning(msg.format(userattr=self.user_attribute, username=username)) + logger.warning( + msg.format(userattr=self.user_attribute, username=username) + ) return None if n_users > 1: - msg = "Duplicate users found! " "{n_users} users found with '{userattr}={username}'" - logger.warning(msg.format(userattr=self.user_attribute, username=username, n_users=n_users)) + msg = ( + "Duplicate users found! " + "{n_users} users found with '{userattr}={username}'" + ) + logger.warning( + msg.format( + userattr=self.user_attribute, username=username, n_users=n_users + ) + ) return None if self.allowed_groups: logger.debug("username:%s Using dn %s", username, userdn) found = False for group in self.allowed_groups: - group_filter = "(|" "(member={userdn})" "(uniqueMember={userdn})" "(memberUid={uid})" ")" + group_filter = ( + "(|" + "(member={userdn})" + "(uniqueMember={userdn})" + "(memberUid={uid})" + ")" + ) group_filter = group_filter.format(userdn=userdn, uid=username) group_attributes = ["member", "uniqueMember", "memberUid"] @@ -810,7 +949,9 @@ async def authenticate(self, username: str, password: str): search_filter=group_filter, attributes=group_attributes, ) - found = await asyncio.get_running_loop().run_in_executor(None, search_func) + found = await asyncio.get_running_loop().run_in_executor( + None, search_func + ) if found: break @@ -826,5 +967,6 @@ async def authenticate(self, username: str, password: str): user_info = await self.get_user_attributes(conn, userdn) if user_info: logger.debug("username:%s attributes:%s", username, user_info) - return {"name": username, "auth_state": user_info} - return username + # this path might never have been worked out...is it ever hit? + return UserSessionState(username, user_info) + return UserSessionState(username, {})