Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions diskann-providers/src/utils/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ pub fn create_rnd_from_seed_in_tests(seed: u64) -> StandardRng {
}

/// Creates a randomly seeded random number generator.
#[cfg(not(test))]
#[cfg(not(any(test, feature = "testing")))]
#[allow(clippy::disallowed_methods)]
pub fn create_rnd() -> StandardRng {
rand::rngs::StdRng::from_os_rng()
}

/// Creates a pseudo-random number generator from a predefined seed to ensure reproducibility
/// of tests and benchmarks.
#[cfg(test)]
#[cfg(any(test, feature = "testing"))]
pub fn create_rnd() -> StandardRng {
create_rnd_from_seed(DEFAULT_SEED_FOR_TESTS)
}
Expand All @@ -64,7 +64,7 @@ pub fn create_rnd_provider_from_seed_in_tests(seed: u64) -> RandomProvider<Stand
}

/// Creates a random number generator provider.
#[cfg(not(test))]
#[cfg(not(any(test, feature = "testing")))]
pub fn create_rnd_provider() -> RandomProvider<StandardRng> {
RandomProvider {
seed: None,
Expand All @@ -73,7 +73,7 @@ pub fn create_rnd_provider() -> RandomProvider<StandardRng> {
}

/// Creates a pseudo-random number generator provider from a predefined seed to ensure reproducibility of tests and benchmarks.
#[cfg(test)]
#[cfg(any(test, feature = "testing"))]
pub fn create_rnd_provider() -> RandomProvider<StandardRng> {
RandomProvider {
seed: None,
Expand Down
14 changes: 10 additions & 4 deletions diskann-tools/src/bin/relative_contrast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

use clap::Parser;
use diskann_providers::storage::FileStorageProvider;
use diskann_providers::utils::random;
use diskann_tools::utils::{
init_subscriber, relative_contrast::compute_relative_contrast, CMDResult, DataType,
GraphDataF32Vector, GraphDataHalfVector, GraphDataInt8Vector, GraphDataU8Vector,
Expand Down Expand Up @@ -36,39 +37,44 @@ fn main() -> CMDResult<()> {

let args = RelativeContrastArgs::parse();
let storage_provider = FileStorageProvider;
let mut rng = random::create_rnd();

let result = match args.data_type {
DataType::Float => compute_relative_contrast::<GraphDataF32Vector, _>(
DataType::Float => compute_relative_contrast::<GraphDataF32Vector, _, _>(
&storage_provider,
&args.data_file,
&args.query_file,
&args.gt_file,
args.recall_at,
args.search_list,
&mut rng,
),
DataType::Fp16 => compute_relative_contrast::<GraphDataHalfVector, _>(
DataType::Fp16 => compute_relative_contrast::<GraphDataHalfVector, _, _>(
&storage_provider,
&args.data_file,
&args.query_file,
&args.gt_file,
args.recall_at,
args.search_list,
&mut rng,
),
DataType::Uint8 => compute_relative_contrast::<GraphDataU8Vector, _>(
DataType::Uint8 => compute_relative_contrast::<GraphDataU8Vector, _, _>(
&storage_provider,
&args.data_file,
&args.query_file,
&args.gt_file,
args.recall_at,
args.search_list,
&mut rng,
),
DataType::Int8 => compute_relative_contrast::<GraphDataInt8Vector, _>(
DataType::Int8 => compute_relative_contrast::<GraphDataInt8Vector, _, _>(
&storage_provider,
&args.data_file,
&args.query_file,
&args.gt_file,
args.recall_at,
args.search_list,
&mut rng,
),
};

Expand Down
21 changes: 14 additions & 7 deletions diskann-tools/src/utils/relative_contrast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

use diskann::{utils::VectorRepr, ANNError};
use diskann_providers::storage::StorageReadProvider;
use diskann_providers::utils::random;
use diskann_providers::{model::graph::traits::GraphDataType, utils::file_util::load_bin};
use rand::Rng;

Expand All @@ -29,12 +28,12 @@ fn squared_distance<Data: GraphDataType>(
.sum())
}

fn average_squared_distance<Data: GraphDataType>(
fn average_squared_distance<Data: GraphDataType, R: Rng>(
query: &[Data::VectorDataType],
base: &[Vec<Data::VectorDataType>],
num_random_samples: usize,
rng: &mut R,
) -> CMDResult<f32> {
let mut rng = random::create_rnd();
let n = base.len();
let mut sum_dist = 0.0;
for _ in 0..num_random_samples {
Expand All @@ -44,13 +43,18 @@ fn average_squared_distance<Data: GraphDataType>(
Ok(sum_dist / num_random_samples as f32)
}

pub fn compute_relative_contrast<Data: GraphDataType, StorageProvider: StorageReadProvider>(
pub fn compute_relative_contrast<
Data: GraphDataType,
StorageProvider: StorageReadProvider,
R: Rng,
>(
storage_provider: &StorageProvider,
base_file: &str,
query_file: &str,
gt_file: &str,
recall_at: usize,
num_random_samples: usize,
rng: &mut R,
) -> CMDResult<f32> {
// Load base, query, and ground truth data
let (base_flat, nb, dim) = load_bin::<Data::VectorDataType, _>(storage_provider, base_file, 0)?;
Expand All @@ -75,7 +79,7 @@ pub fn compute_relative_contrast<Data: GraphDataType, StorageProvider: StorageRe

for (i, q) in query.iter().enumerate() {
// Compute numerator: average squared distance to random samples
let numerator = average_squared_distance::<Data>(q, &base, num_random_samples)?;
let numerator = average_squared_distance::<Data, R>(q, &base, num_random_samples, rng)?;

// Compute denominator: average squared distance to ground truth neighbors
let mut denominator = 0.0;
Expand Down Expand Up @@ -106,6 +110,7 @@ pub fn compute_relative_contrast<Data: GraphDataType, StorageProvider: StorageRe
#[cfg(test)]
mod relative_contrast_tests {
use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider};
use diskann_providers::utils::random;
use diskann_providers::utils::write_metadata;
use diskann_vector::distance::Metric;
use half::f16;
Expand Down Expand Up @@ -178,13 +183,14 @@ mod relative_contrast_tests {

// Run compute_relative_contrast with the generated files
let num_random_samples = 5;
let mean_rc = compute_relative_contrast::<GraphDataHalfVector, _>(
let mean_rc = compute_relative_contrast::<GraphDataHalfVector, _, _>(
&storage_provider,
base_file_path,
query_file_path,
gt_file_path,
recall_at,
num_random_samples,
&mut rng,
)
.unwrap();
println!("Mean relative contrast: {}", mean_rc);
Expand Down Expand Up @@ -253,13 +259,14 @@ mod relative_contrast_tests {

// Run compute_relative_contrast with the generated files
let num_random_samples = 3;
let mean_rc = compute_relative_contrast::<GraphDataHalfVector, _>(
let mean_rc = compute_relative_contrast::<GraphDataHalfVector, _, _>(
&storage_provider,
base_file_path,
query_file_path,
gt_file_path,
recall_at,
num_random_samples,
&mut rng,
)
.unwrap();
println!("Mean relative contrast: {}", mean_rc);
Expand Down