Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions app/alembic/versions/16a9fa2c1f1e_rename_readonly_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001
return

ro_dir.name = READ_ONLY_GROUP_NAME

ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix())
path = ro_dir.parent.path if ro_dir.parent else []
ro_dir.create_path(path, ro_dir.get_dn_prefix())

session.execute(
update(Attribute)
Expand Down Expand Up @@ -92,7 +92,8 @@ def downgrade(container: AsyncContainer) -> None: # noqa: ARG001

ro_dir.name = "readonly domain controllers"

ro_dir.create_path(ro_dir.parent, ro_dir.get_dn_prefix())
path = ro_dir.parent.path if ro_dir.parent else []
ro_dir.create_path(path, ro_dir.get_dn_prefix())

session.execute(
update(Attribute)
Expand Down
9 changes: 7 additions & 2 deletions app/alembic/versions/71e642808369_add_directory_is_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,13 @@ async def _indicate_system_directories(
if not base_dn_list:
return

for base_dn in base_dn_list:
base_dn.is_system = True
await session.execute(
update(Directory)
.where(
qa(Directory.parent_id).is_(None),
)
.values(is_system=True),
)

await session.flush()

Expand Down
81 changes: 81 additions & 0 deletions app/dtos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Module for dtos."""

import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import ClassVar

from adaptix.conversion import get_converter

from entities import Directory, DistinguishedNamePrefix


@dataclass
class DirectoryDTO:
id: int
name: str
is_system: bool
object_sid: str
object_guid: uuid.UUID
parent_id: int | None
entity_type_id: int | None
object_class: str
rdname: str
created_at: datetime | None
updated_at: datetime | None
depth: int
password_policy_id: int | None
path: list[str]

search_fields: ClassVar[dict[str, str]] = {
"name": "name",
"objectguid": "objectGUID",
"objectsid": "objectSid",
}
ro_fields: ClassVar[set[str]] = {
"uid",
"whencreated",
"lastlogon",
"authtimestamp",
"objectguid",
"objectsid",
"entitytypename",
}

def get_dn_prefix(self) -> DistinguishedNamePrefix:
return {
"organizationalUnit": "ou",
"domain": "dc",
"container": "cn",
}.get(
self.object_class,
"cn",
) # type: ignore

def get_dn(self, dn: str = "cn") -> str:
return f"{dn}={self.name}"

@property
def is_domain(self) -> bool:
return not self.parent_id and self.object_class == "domain"

@property
def host_principal(self) -> str:
return f"host/{self.name}"

@property
def path_dn(self) -> str:
return ",".join(reversed(self.path))

@property
def relative_id(self) -> str:
"""Get RID from objectSid.

Relative Identifier (RID) is the last sub-authority value of a SID.
"""
if "-" in self.object_sid:
return self.object_sid.split("-")[-1]
return ""


_directory_sqla_obj_to_dto = get_converter(Directory, DirectoryDTO)
4 changes: 2 additions & 2 deletions app/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,10 @@ def path_dn(self) -> str:

def create_path(
self,
parent: Directory | None = None,
parent_path: list | None = None,
dn: str = "cn",
) -> None:
pre = parent.path if parent else []
pre = parent_path or []
self.path = pre + [self.get_dn(dn)]
self.depth = len(self.path)
self.rdname = dn
Expand Down
12 changes: 8 additions & 4 deletions app/ldap_protocol/auth/setup_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from sqlalchemy import exists, select
from sqlalchemy.ext.asyncio import AsyncSession

from dtos import DirectoryDTO
from entities import Attribute, Directory, Group, NetworkPolicy, User
from ldap_protocol.ldap_schema.attribute_value_validator import (
AttributeValueValidator,
)
from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO
from ldap_protocol.utils.async_cache import base_directories_cache
from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid
from ldap_protocol.utils.queries import get_domain_object_class
from password_utils import PasswordUtils
Expand Down Expand Up @@ -113,6 +115,7 @@ async def setup_enviroment(
domain=domain,
parent=domain,
)
base_directories_cache.clear()

except Exception:
import traceback
Expand All @@ -124,21 +127,22 @@ async def create_dir(
self,
data: dict,
is_system: bool,
domain: Directory,
parent: Directory | None = None,
domain: Directory | DirectoryDTO,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Зачем здесь оставлять Directory ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Тут может приходить Directory (app/ldap_protocol/auth/setup_gateway.py::72)

parent: Directory | DirectoryDTO | None = None,
) -> None:
"""Create data recursively."""
dir_ = Directory(
is_system=is_system,
object_class=data["object_class"],
name=data["name"],
parent=parent,
)
dir_.groups = []
dir_.create_path(parent, dir_.get_dn_prefix())
path = parent.path if parent else []
dir_.create_path(path, dir_.get_dn_prefix())

self._session.add(dir_)
await self._session.flush()
dir_.parent_id = parent.id if parent else None
Comment on lines 140 to +145
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parent_id is set after the flush but before the refresh. If parent is None, this sets parent_id to None. However, the refresh on line 144 only includes "id" in the attribute_names, not "parent_id". If the relationship between parent_id and parent needs to be maintained correctly, consider refreshing both or setting parent_id before the initial flush at line 142.

Suggested change
self._session.add(dir_)
await self._session.flush()
dir_.parent_id = parent.id if parent else None
dir_.parent_id = parent.id if parent else None
self._session.add(dir_)
await self._session.flush()

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Почему то в таком случае parent_id не сохраняется

await self._session.refresh(dir_, ["id"])

self._session.add(
Expand Down
2 changes: 1 addition & 1 deletion app/ldap_protocol/ldap_requests/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ async def handle( # noqa: C901
parent=parent,
)

new_dir.create_path(parent, new_dn)
new_dir.create_path(parent.path, new_dn)
ctx.session.add(new_dir)

await ctx.session.flush()
Expand Down
2 changes: 1 addition & 1 deletion app/ldap_protocol/ldap_requests/modify_dn.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ async def handle(
return

directory.parent = parent_dir
directory.create_path(directory.parent, dn=new_dn)
directory.create_path(parent_dir.path, dn=new_dn)

try:
await ctx.session.flush()
Expand Down
3 changes: 2 additions & 1 deletion app/ldap_protocol/ldap_requests/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
from sqlalchemy.sql.expression import Select

from dtos import DirectoryDTO
from entities import (
Attribute,
AttributeType,
Expand Down Expand Up @@ -367,7 +368,7 @@ def _mutate_query_with_attributes_to_load(

def _build_query(
self,
base_directories: list[Directory],
base_directories: list[DirectoryDTO],
user: UserSchema,
access_manager: AccessManager,
) -> Select[tuple[Directory]]:
Expand Down
3 changes: 2 additions & 1 deletion app/ldap_protocol/roles/role_use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqlalchemy import and_, insert, literal, or_, select

from dtos import DirectoryDTO
from entities import AccessControlEntry, AceType, Directory, Role
from enums import AuthorizationRules, RoleConstants, RoleScope
from ldap_protocol.utils.queries import get_base_directories
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(

async def inherit_parent_aces(
self,
parent_directory: Directory,
parent_directory: Directory | DirectoryDTO,
directory: Directory,
) -> None:
"""Inherit access control entries from the parent directory.
Expand Down
42 changes: 42 additions & 0 deletions app/ldap_protocol/utils/async_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Async cache implementation."""
import time
from functools import wraps
from typing import Callable, Generic, TypeVar

from dtos import DirectoryDTO

T = TypeVar("T")
DEFAULT_CACHE_TIME = 5 * 60 # 5 minutes


class AsyncTTLCache(Generic[T]):
def __init__(self, ttl: int | None = DEFAULT_CACHE_TIME) -> None:
self._ttl = ttl
self._value: T | None = None
self._expires_at: float | None = None

def clear(self) -> None:
self._value = None
self._expires_at = None

def __call__(self, func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args: tuple, **kwargs: dict) -> T:
if self._value is not None:
if not self._expires_at or self._expires_at > time.monotonic():
return self._value
self.clear()

result = await func(*args, **kwargs)

self._value = result
self._expires_at = (
time.monotonic() + self._ttl if self._ttl else None
)

return result

return wrapper


base_directories_cache = AsyncTTLCache[list[DirectoryDTO]]()
42 changes: 38 additions & 4 deletions app/ldap_protocol/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
"""

import asyncio
import functools
import hashlib
import random
Expand All @@ -138,9 +139,10 @@
import time
from calendar import timegm
from datetime import datetime
from functools import wraps
from hashlib import blake2b
from operator import attrgetter
from typing import Callable
from typing import Any, Callable, Generic, TypeVar
from zoneinfo import ZoneInfo

from loguru import logger
Expand All @@ -149,6 +151,7 @@
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable

from dtos import DirectoryDTO
from entities import Directory


Expand Down Expand Up @@ -192,12 +195,18 @@ def validate_attribute(attribute: str) -> bool:
)


def is_dn_in_base_directory(base_directory: Directory, entry: str) -> bool:
def is_dn_in_base_directory(
base_directory: DirectoryDTO,
entry: str,
) -> bool:
"""Check if an entry in a base dn."""
return entry.lower().endswith(base_directory.path_dn.lower())


def dn_is_base_directory(base_directory: Directory, entry: str) -> bool:
def dn_is_base_directory(
base_directory: DirectoryDTO,
entry: str,
) -> bool:
"""Check if an entry is a base dn."""
return base_directory.path_dn.lower() == entry.lower()

Expand Down Expand Up @@ -302,7 +311,7 @@ def string_to_sid(sid_string: str) -> bytes:


def create_object_sid(
domain: Directory,
domain: Directory | DirectoryDTO,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Здесь вроде тоже можно оставить только dto

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Вызывается внутри create_dir, там могут быть оба варианта

rid: int,
reserved: bool = False,
) -> str:
Expand Down Expand Up @@ -402,3 +411,28 @@ async def explain_query(
for row in await session.execute(explain(query, analyze=True))
),
)


# def async_cache(ttl: int | None = DEFAULT_CACHE_TIME) -> Callable:
# """Cache for get_base_directories"""
# cache: list[tuple[list[DirectoryDTO], float | None]] = []

# def decorator(func: Callable) -> Callable:
# @wraps(func)
# async def wrapper(*args: tuple, **kwargs: dict) -> list[DirectoryDTO]:
# if cache:
# value, expires_at = cache[0]
# if not expires_at or expires_at > time.monotonic():
# return value
# else:
# cache.clear()

# result = await func(*args, **kwargs)
# expires_at = time.monotonic() + ttl if ttl else None
# cache.append((result, expires_at))

# return result

# return wrapper

# return decorator
Loading
Loading