Skip to content

Commit adfaefd

Browse files
authored
Merge pull request #21043 from hvitved/rust/type-inference-trait-bounds-overlap
Rust: Fix candidate receiver type calculation for trait bounds
2 parents 802c465 + eb56cbd commit adfaefd

File tree

6 files changed

+6665
-6551
lines changed

6 files changed

+6665
-6551
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 87 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,20 +1572,18 @@ private module MethodResolution {
15721572
}
15731573

15741574
/**
1575-
* Same as `getACandidateReceiverTypeAt`, but with traits substituted in for types
1576-
* with trait bounds.
1575+
* Same as `getACandidateReceiverTypeAt`, but excludes pseudo types `!` and `unknown`.
15771576
*/
15781577
pragma[nomagic]
1579-
Type getACandidateReceiverTypeAtSubstituteLookupTraits(
1580-
string derefChain, boolean borrow, TypePath path
1581-
) {
1582-
result = substituteLookupTraits(this.getACandidateReceiverTypeAt(derefChain, borrow, path))
1578+
Type getANonPseudoCandidateReceiverTypeAt(string derefChain, boolean borrow, TypePath path) {
1579+
result = this.getACandidateReceiverTypeAt(derefChain, borrow, path) and
1580+
result != TNeverType() and
1581+
result != TUnknownType()
15831582
}
15841583

15851584
pragma[nomagic]
15861585
private Type getComplexStrippedType(string derefChain, boolean borrow, TypePath strippedTypePath) {
1587-
result =
1588-
this.getACandidateReceiverTypeAtSubstituteLookupTraits(derefChain, borrow, strippedTypePath) and
1586+
result = this.getANonPseudoCandidateReceiverTypeAt(derefChain, borrow, strippedTypePath) and
15891587
isComplexRootStripped(strippedTypePath, result)
15901588
}
15911589

@@ -1624,23 +1622,58 @@ private module MethodResolution {
16241622
)
16251623
}
16261624

1625+
// forex using recursion
1626+
pragma[nomagic]
1627+
private predicate hasNoCompatibleTargetNoBorrowToIndex(
1628+
string derefChain, TypePath strippedTypePath, Type strippedType, int n
1629+
) {
1630+
(
1631+
this.supportsAutoDerefAndBorrow()
1632+
or
1633+
// needed for the `hasNoCompatibleTarget` check in
1634+
// `ReceiverSatisfiesBlanketLikeConstraintInput::hasBlanketCandidate`
1635+
derefChain = ""
1636+
) and
1637+
strippedType = this.getComplexStrippedType(derefChain, false, strippedTypePath) and
1638+
n = -1
1639+
or
1640+
this.hasNoCompatibleTargetNoBorrowToIndex(derefChain, strippedTypePath, strippedType, n - 1) and
1641+
exists(Type t | t = getNthLookupType(strippedType, n) |
1642+
this.hasNoCompatibleTargetCheck(derefChain, false, strippedTypePath, t)
1643+
)
1644+
}
1645+
16271646
/**
16281647
* Holds if the candidate receiver type represented by `derefChain` does not
16291648
* have a matching method target.
16301649
*/
16311650
pragma[nomagic]
16321651
predicate hasNoCompatibleTargetNoBorrow(string derefChain) {
1652+
exists(Type strippedType |
1653+
this.hasNoCompatibleTargetNoBorrowToIndex(derefChain, _, strippedType,
1654+
getLastLookupTypeIndex(strippedType))
1655+
)
1656+
}
1657+
1658+
// forex using recursion
1659+
pragma[nomagic]
1660+
private predicate hasNoCompatibleNonBlanketTargetNoBorrowToIndex(
1661+
string derefChain, TypePath strippedTypePath, Type strippedType, int n
1662+
) {
16331663
(
16341664
this.supportsAutoDerefAndBorrow()
16351665
or
16361666
// needed for the `hasNoCompatibleTarget` check in
16371667
// `ReceiverSatisfiesBlanketLikeConstraintInput::hasBlanketCandidate`
16381668
derefChain = ""
16391669
) and
1640-
exists(TypePath strippedTypePath, Type strippedType |
1641-
not derefChain.matches("%.ref") and // no need to try a borrow if the last thing we did was a deref
1642-
strippedType = this.getComplexStrippedType(derefChain, false, strippedTypePath) and
1643-
this.hasNoCompatibleTargetCheck(derefChain, false, strippedTypePath, strippedType)
1670+
strippedType = this.getComplexStrippedType(derefChain, false, strippedTypePath) and
1671+
n = -1
1672+
or
1673+
this.hasNoCompatibleNonBlanketTargetNoBorrowToIndex(derefChain, strippedTypePath,
1674+
strippedType, n - 1) and
1675+
exists(Type t | t = getNthLookupType(strippedType, n) |
1676+
this.hasNoCompatibleNonBlanketTargetCheck(derefChain, false, strippedTypePath, t)
16441677
)
16451678
}
16461679

@@ -1650,17 +1683,24 @@ private module MethodResolution {
16501683
*/
16511684
pragma[nomagic]
16521685
predicate hasNoCompatibleNonBlanketTargetNoBorrow(string derefChain) {
1653-
(
1654-
this.supportsAutoDerefAndBorrow()
1655-
or
1656-
// needed for the `hasNoCompatibleTarget` check in
1657-
// `ReceiverSatisfiesBlanketLikeConstraintInput::hasBlanketCandidate`
1658-
derefChain = ""
1659-
) and
1660-
exists(TypePath strippedTypePath, Type strippedType |
1661-
not derefChain.matches("%.ref") and // no need to try a borrow if the last thing we did was a deref
1662-
strippedType = this.getComplexStrippedType(derefChain, false, strippedTypePath) and
1663-
this.hasNoCompatibleNonBlanketTargetCheck(derefChain, false, strippedTypePath, strippedType)
1686+
exists(Type strippedType |
1687+
this.hasNoCompatibleNonBlanketTargetNoBorrowToIndex(derefChain, _, strippedType,
1688+
getLastLookupTypeIndex(strippedType))
1689+
)
1690+
}
1691+
1692+
// forex using recursion
1693+
pragma[nomagic]
1694+
private predicate hasNoCompatibleTargetBorrowToIndex(
1695+
string derefChain, TypePath strippedTypePath, Type strippedType, int n
1696+
) {
1697+
this.hasNoCompatibleTargetNoBorrow(derefChain) and
1698+
strippedType = this.getComplexStrippedType(derefChain, true, strippedTypePath) and
1699+
n = -1
1700+
or
1701+
this.hasNoCompatibleTargetBorrowToIndex(derefChain, strippedTypePath, strippedType, n - 1) and
1702+
exists(Type t | t = getNthLookupType(strippedType, n) |
1703+
this.hasNoCompatibleNonBlanketLikeTargetCheck(derefChain, true, strippedTypePath, t)
16641704
)
16651705
}
16661706

@@ -1670,11 +1710,25 @@ private module MethodResolution {
16701710
*/
16711711
pragma[nomagic]
16721712
predicate hasNoCompatibleTargetBorrow(string derefChain) {
1673-
exists(TypePath strippedTypePath, Type strippedType |
1674-
this.hasNoCompatibleTargetNoBorrow(derefChain) and
1675-
strippedType = this.getComplexStrippedType(derefChain, true, strippedTypePath) and
1676-
this.hasNoCompatibleNonBlanketLikeTargetCheck(derefChain, true, strippedTypePath,
1677-
strippedType)
1713+
exists(Type strippedType |
1714+
this.hasNoCompatibleTargetBorrowToIndex(derefChain, _, strippedType,
1715+
getLastLookupTypeIndex(strippedType))
1716+
)
1717+
}
1718+
1719+
// forex using recursion
1720+
pragma[nomagic]
1721+
private predicate hasNoCompatibleNonBlanketTargetBorrowToIndex(
1722+
string derefChain, TypePath strippedTypePath, Type strippedType, int n
1723+
) {
1724+
this.hasNoCompatibleTargetNoBorrow(derefChain) and
1725+
strippedType = this.getComplexStrippedType(derefChain, true, strippedTypePath) and
1726+
n = -1
1727+
or
1728+
this.hasNoCompatibleNonBlanketTargetBorrowToIndex(derefChain, strippedTypePath, strippedType,
1729+
n - 1) and
1730+
exists(Type t | t = getNthLookupType(strippedType, n) |
1731+
this.hasNoCompatibleNonBlanketTargetCheck(derefChain, true, strippedTypePath, t)
16781732
)
16791733
}
16801734

@@ -1684,10 +1738,9 @@ private module MethodResolution {
16841738
*/
16851739
pragma[nomagic]
16861740
predicate hasNoCompatibleNonBlanketTargetBorrow(string derefChain) {
1687-
exists(TypePath strippedTypePath, Type strippedType |
1688-
this.hasNoCompatibleTargetNoBorrow(derefChain) and
1689-
strippedType = this.getComplexStrippedType(derefChain, true, strippedTypePath) and
1690-
this.hasNoCompatibleNonBlanketTargetCheck(derefChain, true, strippedTypePath, strippedType)
1741+
exists(Type strippedType |
1742+
this.hasNoCompatibleNonBlanketTargetBorrowToIndex(derefChain, _, strippedType,
1743+
getLastLookupTypeIndex(strippedType))
16911744
)
16921745
}
16931746

@@ -1905,9 +1958,8 @@ private module MethodResolution {
19051958
MethodCall getMethodCall() { result = mc_ }
19061959

19071960
Type getTypeAt(TypePath path) {
1908-
result = mc_.getACandidateReceiverTypeAtSubstituteLookupTraits(derefChain, borrow, path) and
1909-
not result = TNeverType() and
1910-
not result = TUnknownType()
1961+
result =
1962+
substituteLookupTraits(mc_.getANonPseudoCandidateReceiverTypeAt(derefChain, borrow, path))
19111963
}
19121964

19131965
pragma[nomagic]

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ class AssocFunctionType extends MkAssocFunctionType {
194194
Location getLocation() { result = this.getTypeMention().getLocation() }
195195
}
196196

197+
pragma[nomagic]
197198
private Trait getALookupTrait(Type t) {
198199
result = t.(TypeParamTypeParameter).getTypeParam().(TypeParamItemNode).resolveABound()
199200
or
@@ -208,14 +209,38 @@ private Trait getALookupTrait(Type t) {
208209
* Gets the type obtained by substituting in relevant traits in which to do function
209210
* lookup, or `t` itself when no such trait exist.
210211
*/
211-
bindingset[t]
212+
pragma[nomagic]
212213
Type substituteLookupTraits(Type t) {
213214
not exists(getALookupTrait(t)) and
214215
result = t
215216
or
216217
result = TTrait(getALookupTrait(t))
217218
}
218219

220+
/**
221+
* Gets the `n`th `substituteLookupTraits` type for `t`, per some arbitrary order.
222+
*/
223+
pragma[nomagic]
224+
Type getNthLookupType(Type t, int n) {
225+
not exists(getALookupTrait(t)) and
226+
result = t and
227+
n = 0
228+
or
229+
result =
230+
TTrait(rank[n + 1](Trait trait, int i |
231+
trait = getALookupTrait(t) and
232+
i = idOfTypeParameterAstNode(trait)
233+
|
234+
trait order by i
235+
))
236+
}
237+
238+
/**
239+
* Gets the index of the last `substituteLookupTraits` type for `t`.
240+
*/
241+
pragma[nomagic]
242+
int getLastLookupTypeIndex(Type t) { result = max(int n | exists(getNthLookupType(t, n))) }
243+
219244
/**
220245
* A wrapper around `IsInstantiationOf` which ensures to substitute in lookup
221246
* traits when checking whether argument types are instantiations of function

rust/ql/test/library-tests/type-inference/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,23 @@ multipleResolvedTargets
1313
| dyn_type.rs:90:10:90:13 | * ... |
1414
| invalid/main.rs:69:13:69:17 | * ... |
1515
| invalid/main.rs:76:13:76:17 | * ... |
16-
| main.rs:1077:14:1077:18 | * ... |
17-
| main.rs:1159:26:1159:30 | * ... |
18-
| main.rs:1503:14:1503:21 | * ... |
19-
| main.rs:1503:16:1503:20 | * ... |
20-
| main.rs:1508:14:1508:18 | * ... |
21-
| main.rs:1539:27:1539:29 | * ... |
22-
| main.rs:1653:17:1653:24 | * ... |
23-
| main.rs:1653:18:1653:24 | * ... |
24-
| main.rs:1791:17:1791:21 | * ... |
25-
| main.rs:1806:28:1806:32 | * ... |
26-
| main.rs:2439:13:2439:18 | * ... |
27-
| main.rs:2633:13:2633:31 | ...::from(...) |
28-
| main.rs:2634:13:2634:31 | ...::from(...) |
29-
| main.rs:2635:13:2635:31 | ...::from(...) |
30-
| main.rs:2641:13:2641:31 | ...::from(...) |
31-
| main.rs:2642:13:2642:31 | ...::from(...) |
32-
| main.rs:2643:13:2643:31 | ...::from(...) |
33-
| main.rs:3072:13:3072:17 | x.f() |
16+
| main.rs:1092:14:1092:18 | * ... |
17+
| main.rs:1174:26:1174:30 | * ... |
18+
| main.rs:1518:14:1518:21 | * ... |
19+
| main.rs:1518:16:1518:20 | * ... |
20+
| main.rs:1523:14:1523:18 | * ... |
21+
| main.rs:1554:27:1554:29 | * ... |
22+
| main.rs:1668:17:1668:24 | * ... |
23+
| main.rs:1668:18:1668:24 | * ... |
24+
| main.rs:1806:17:1806:21 | * ... |
25+
| main.rs:1821:28:1821:32 | * ... |
26+
| main.rs:2454:13:2454:18 | * ... |
27+
| main.rs:2648:13:2648:31 | ...::from(...) |
28+
| main.rs:2649:13:2649:31 | ...::from(...) |
29+
| main.rs:2650:13:2650:31 | ...::from(...) |
30+
| main.rs:2656:13:2656:31 | ...::from(...) |
31+
| main.rs:2657:13:2657:31 | ...::from(...) |
32+
| main.rs:2658:13:2658:31 | ...::from(...) |
33+
| main.rs:3087:13:3087:17 | x.f() |
3434
| pattern_matching.rs:273:13:273:27 | * ... |
3535
| pattern_matching.rs:273:14:273:27 | * ... |

rust/ql/test/library-tests/type-inference/blanket_impl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ mod blanket_like_impl {
236236
impl MyTrait2 for &&S1 {
237237
// MyTrait2RefRefS1::m2
238238
fn m2(self) {
239-
self.m1() // $ MISSING: target=S1::m1
239+
self.m1() // $ target=S1::m1
240240
}
241241
}
242242

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,21 @@ mod function_trait_bounds {
827827
}
828828
}
829829

830+
trait MyTrait2 {
831+
// MyTrait2::m2
832+
fn m2(self);
833+
}
834+
835+
trait MyTrait3 {
836+
// MyTrait3::m2
837+
fn m2(&self);
838+
}
839+
840+
fn bound_overlap<T: MyTrait2 + MyTrait3>(x: T, y: &T) {
841+
x.m2(); // $ target=MyTrait2::m2
842+
y.m2(); // $ target=MyTrait3::m2
843+
}
844+
830845
pub fn f() {
831846
let x = MyThing { a: S1 };
832847
let y = MyThing { a: S2 };

0 commit comments

Comments
 (0)