diff --git a/tests/test_update.py b/tests/test_update.py index 106f6c1b4..ca9e62791 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -97,6 +97,24 @@ async def test_bulk_update_pk_uuid(db): assert (await UUIDFields.get(pk=objs[1].pk)).data == objs[1].data +@pytest.mark.asyncio +async def test_bulk_update_foreign_key(db): + tournament1 = await Tournament.create(name="t1") + tournament2 = await Tournament.create(name="t2") + events = [ + await Event.create(name="e1", tournament=tournament1), + await Event.create(name="e2", tournament=tournament1), + ] + events[0].tournament = tournament2 + events[1].tournament = tournament2 + rows_affected = await Event.bulk_update(events, fields=["tournament"]) + assert rows_affected == 2 + e1 = await Event.get(pk=events[0].pk).select_related("tournament") + e2 = await Event.get(pk=events[1].pk).select_related("tournament") + assert e1.tournament.pk == tournament2.pk + assert e2.tournament.pk == tournament2.pk + + @pytest.mark.asyncio async def test_bulk_renamed_pk_source_field(db): objs = [ diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 0cfd2fc52..ae050dce3 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1959,25 +1959,39 @@ def _make_queries(self) -> list[tuple[str, list[Any]]]: for field in self.fields: case = Case() pk_list = [] + field_obj = self.model._meta.fields_map[field] + is_fk = isinstance(field_obj, (ForeignKeyFieldInstance, OneToOneFieldInstance)) + if is_fk: + fk_field = field_obj.source_field + underlying_field_obj = self.model._meta.fields_map[fk_field] + db_column = underlying_field_obj.source_field + else: + underlying_field_obj = field_obj + db_column = self.model._meta.fields_db_projection[field] for obj in objects_item: pk_value = self.model._meta.fields_map[pk_attr].to_db_value(obj.pk, None) - field_obj = obj._meta.fields_map[field] - field_value = field_obj.to_db_value(getattr(obj, field), obj) - case.when( - pk == pk_value, - ( - Cast( - self.query._wrapper_cls(field_value), - field_obj.get_for_dialect( - self._db.schema_generator.DIALECT, "SQL_TYPE" - ), - ) - if self._db.schema_generator.DIALECT == "postgres" - else self.query._wrapper_cls(field_value) - ), - ) + if is_fk: + related_obj = getattr(obj, field) + self.model._validate_relation_type(field, related_obj) + field_value = underlying_field_obj.to_db_value( + getattr(related_obj, field_obj.to_field_instance.model_field_name), + None, + ) + else: + field_value = underlying_field_obj.to_db_value(getattr(obj, field), obj) + value_expr: Term + if self._db.schema_generator.DIALECT == "postgres": + value_expr = Cast( + self.query._wrapper_cls(field_value), + underlying_field_obj.get_for_dialect( + self._db.schema_generator.DIALECT, "SQL_TYPE" + ), + ) + else: + value_expr = self.query._wrapper_cls(field_value) + case.when(pk == pk_value, value_expr) pk_list.append(pk_value) - query = query.set(field, case) + query = query.set(db_column, case) query = query.where(pk.isin(pk_list)) self._queries.append(query) return [query.get_parameterized_sql() for query in self._queries]