Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 72 additions & 29 deletions apps/hash-ai-worker-ts/src/workflows.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
import type {
ActorEntityUuid,
BaseUrl,
DataTypeWithMetadata,
Entity,
EntityId,
EntityTypeWithMetadata,
PropertyTypeWithMetadata,
} from "@blockprotocol/type-system";
import { extractBaseUrl } from "@blockprotocol/type-system";
import { publicUserAccountId } from "@local/hash-backend-utils/public-user-account-id";
import type {
Entity as GraphApiEntity,
EntityQueryCursor,
Filter,
} from "@local/hash-graph-client";
import type { EntityQueryCursor, Filter } from "@local/hash-graph-client";
import type {
CreateEmbeddingsParams,
CreateEmbeddingsReturn,
} from "@local/hash-graph-sdk/embeddings";
import {
deserializeQueryEntitiesResponse,
HashEntity,
} from "@local/hash-graph-sdk/entity";
import { deserializeQueryEntitiesResponse } from "@local/hash-graph-sdk/entity";
import { generateEntityIdFilter } from "@local/hash-isomorphic-utils/graph-queries";
import { systemEntityTypes } from "@local/hash-isomorphic-utils/ontology-type-ids";
import type { ParseTextFromFileParams } from "@local/hash-isomorphic-utils/parse-text-from-file-types";
import {
Expand Down Expand Up @@ -276,9 +273,14 @@ type UpdateEntityEmbeddingsParams = {
authentication: {
actorId: ActorEntityUuid;
};
/**
* Properties to exclude from embedding generation, keyed by entity type base URL.
* Values are arrays of property type base URLs to exclude for that entity type.
*/
embeddingExclusions?: Record<BaseUrl, BaseUrl[]>;
} & (
| {
entities: GraphApiEntity[];
entityIds: EntityId[];
}
| {
filter: Filter;
Expand Down Expand Up @@ -310,26 +312,50 @@ export const updateEntityEmbeddings = async (
total_tokens: 0,
};

// Build filter from entity IDs if provided
let filter: Filter;
if ("entityIds" in params) {
// Early return if no entity IDs provided - avoids ambiguous empty `any` filter
if (params.entityIds.length === 0) {
return usage;
}

// Build a filter matching any of the entity IDs, excluding FlowRun entities
filter = {
all: [
{
any: params.entityIds.map((entityId) =>
generateEntityIdFilter({ entityId, includeArchived: true }),
),
},
{
notEqual: [
{ path: ["type", "versionedUrl"] },
{ parameter: systemEntityTypes.flowRun.entityTypeId },
],
},
],
};
} else {
filter = params.filter;
}

// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
while (true) {
if ("entities" in params) {
entities = params.entities.map((entity) => new HashEntity(entity));
} else {
const serializedResponse = await graphActivities.queryEntities({
authentication: params.authentication,
request: {
filter: params.filter,
temporalAxes,
includeDrafts: true,
includePermissions: false,
cursor,
limit: 100,
},
});
const response = deserializeQueryEntitiesResponse(serializedResponse);
cursor = response.cursor;
entities = response.entities;
}
const serializedResponse = await graphActivities.queryEntities({
authentication: params.authentication,
request: {
filter,
temporalAxes,
includeDrafts: true,
includePermissions: false,
cursor,
limit: 100,
},
});
const response = deserializeQueryEntitiesResponse(serializedResponse);
cursor = response.cursor;
entities = response.entities;

if (entities.length === 0) {
break;
Expand All @@ -340,6 +366,8 @@ export const updateEntityEmbeddings = async (
* Don't try to create embeddings for `FlowRun` entities, due to the size
* of their property values.
*
* This is a safety fallback - the filter should already exclude these.
*
* @todo: consider having a general approach for declaring which entity/property
* types should be skipped when generating embeddings.
*/
Expand Down Expand Up @@ -375,9 +403,24 @@ export const updateEntityEmbeddings = async (
subgraph,
});

// Filter out protected properties before embedding generation based on config.
const filteredProperties = { ...entity.properties };
if (params.embeddingExclusions) {
for (const entityTypeId of entity.metadata.entityTypeIds) {
const entityTypeBaseUrl = extractBaseUrl(entityTypeId);
const excludedProperties =
params.embeddingExclusions[entityTypeBaseUrl];
if (excludedProperties) {
for (const propertyBaseUrl of excludedProperties) {
delete filteredProperties[propertyBaseUrl];
}
}
}
}

const generatedEmbeddings =
await aiActivities.createEntityEmbeddingsActivity({
entityProperties: entity.properties,
entityProperties: filteredProperties,
propertyTypes,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use hash_graph_store::{
error::{CheckPermissionError, InsertionError, QueryError, UpdateError},
filter::{
Filter, FilterExpression, FilterExpressionList, Parameter, ParameterList,
protection::{PropertyProtectionFilterConfig, transform_filter},
protection::transform_filter,
},
query::{QueryResult as _, Read},
subgraph::{
Expand Down Expand Up @@ -104,43 +104,6 @@ use crate::store::{
validation::StoreProvider,
};

/// Produces copies of `entities` with embedding-excluded properties removed.
///
/// The exclusion map is read from `config.embedding_exclusions()` and is keyed by the
/// base URL of an entity type. For every entity, the excluded property base URLs of
/// all of its entity types are combined, and any property whose key is in that
/// combined set is dropped from the copy. The input slice is never modified.
///
/// When the exclusion map is empty, the entities are returned as plain clones.
fn filter_entities_for_embedding(
    entities: &[Entity],
    config: &PropertyProtectionFilterConfig<'_>,
) -> Vec<Entity> {
    let exclusions = config.embedding_exclusions();

    // Nothing is configured to be excluded: hand back unmodified copies.
    if exclusions.is_empty() {
        return entities.to_vec();
    }

    entities
        .iter()
        .map(|original| {
            let mut filtered = original.clone();

            // Union of the excluded property base URLs across all types of this entity.
            let mut excluded: HashSet<&BaseUrl> = HashSet::new();
            for type_id in &filtered.metadata.entity_type_ids {
                if let Some(properties) = exclusions.get(&type_id.base_url) {
                    excluded.extend(properties);
                }
            }

            // Only touch the property object when there is something to remove.
            if !excluded.is_empty() {
                filtered
                    .properties
                    .retain(|key, _| !excluded.contains(key));
            }

            filtered
        })
        .collect()
}

impl<C> PostgresStore<C>
where
C: AsClient,
Expand Down Expand Up @@ -1333,10 +1296,16 @@ where
if !self.settings.skip_embedding_creation
&& let Some(temporal_client) = &self.temporal_client
{
let filtered_entities =
filter_entities_for_embedding(&entities, &self.settings.filter_protection);
let entity_ids: Vec<EntityId> = entities
.iter()
.map(|entity| entity.metadata.record_id.entity_id)
.collect();
temporal_client
.start_update_entity_embeddings_workflow(actor_uuid, &filtered_entities)
.start_update_entity_embeddings_workflow(
actor_uuid,
&entity_ids,
self.settings.filter_protection.embedding_exclusions(),
)
.await
.change_context(InsertionError)?;
}
Expand Down Expand Up @@ -2356,10 +2325,16 @@ where
if !self.settings.skip_embedding_creation
&& let Some(temporal_client) = &self.temporal_client
{
let filtered_entities =
filter_entities_for_embedding(&entities, &self.settings.filter_protection);
let entity_ids: Vec<EntityId> = entities
.iter()
.map(|entity| entity.metadata.record_id.entity_id)
.collect();
temporal_client
.start_update_entity_embeddings_workflow(actor_id, &filtered_entities)
.start_update_entity_embeddings_workflow(
actor_id,
&entity_ids,
self.settings.filter_protection.embedding_exclusions(),
)
.await
.change_context(UpdateError)?;
}
Expand Down
36 changes: 22 additions & 14 deletions libs/@local/temporal-client/src/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ use temporal_sdk_core_protos::{
ENCODING_PAYLOAD_KEY, JSON_ENCODING_VAL, temporal::api::common::v1::Payload,
};
use type_system::{
knowledge::Entity,
ontology::{DataTypeWithMetadata, EntityTypeWithMetadata, PropertyTypeWithMetadata},
knowledge::entity::EntityId,
ontology::{
DataTypeWithMetadata, EntityTypeWithMetadata, PropertyTypeWithMetadata, id::BaseUrl,
},
principal::actor::ActorEntityUuid,
};
use uuid::Uuid;
Expand Down Expand Up @@ -135,42 +137,48 @@ impl TemporalClient {
.await
}

/// Starts a workflow to update the embeddings for the provided entity.
/// Starts a workflow to update the embeddings for the provided entities.
///
/// Returns the run ID of the workflow.
/// The `embedding_exclusions` parameter specifies which properties should be excluded
/// from embedding generation for specific entity types (e.g., email for User entities).
///
/// Returns the run IDs of the workflows.
///
/// # Errors
///
/// Returns an error if the workflow fails to start.
/// Returns an error if any workflow fails to start.
pub async fn start_update_entity_embeddings_workflow(
&self,
actor_id: ActorEntityUuid,
entities: &[Entity],
entity_ids: &[EntityId],
embedding_exclusions: &HashMap<BaseUrl, Vec<BaseUrl>>,
) -> Result<Vec<String>, Report<WorkflowError>> {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct UpdateEntityEmbeddingsParams<'a> {
authentication: AuthenticationContext,
entities: &'a [Entity],
entity_ids: &'a [EntityId],
embedding_exclusions: &'a HashMap<BaseUrl, Vec<BaseUrl>>,
}

// There is an upper limit on how many bytes can be sent in a single workflow invocation so
// we need to split the entities into chunks.
const CHUNK_SIZE: usize = 100;
// EntityIDs are small (~100 bytes each), but we still chunk to avoid hitting
// Temporal's payload size limits when dealing with very large batches.
const CHUNK_SIZE: usize = 10_000;

#[expect(
clippy::integer_division,
clippy::integer_division_remainder_used,
reason = "The devision is only used to calculate vector capacity and is rounded up."
reason = "The division is only used to calculate vector capacity and is rounded up."
)]
let mut workflow_ids = Vec::with_capacity(entities.len() / CHUNK_SIZE + 1);
for partial_entities in entities.chunks(CHUNK_SIZE) {
let mut workflow_ids = Vec::with_capacity(entity_ids.len() / CHUNK_SIZE + 1);
for chunk in entity_ids.chunks(CHUNK_SIZE) {
workflow_ids.push(
self.start_ai_workflow(
"updateEntityEmbeddings",
&UpdateEntityEmbeddingsParams {
authentication: AuthenticationContext { actor_id },
entities: partial_entities,
entity_ids: chunk,
embedding_exclusions,
},
)
.await?,
Expand Down
Loading