diff --git a/diskann-providers/src/utils/random.rs b/diskann-providers/src/utils/random.rs index 4aad64df4..092af3e2a 100644 --- a/diskann-providers/src/utils/random.rs +++ b/diskann-providers/src/utils/random.rs @@ -32,7 +32,7 @@ pub fn create_rnd_from_seed_in_tests(seed: u64) -> StandardRng { } /// Creates a randomly seeded random number generator. -#[cfg(not(test))] +#[cfg(not(any(test, feature = "testing")))] #[allow(clippy::disallowed_methods)] pub fn create_rnd() -> StandardRng { rand::rngs::StdRng::from_os_rng() @@ -40,7 +40,7 @@ pub fn create_rnd() -> StandardRng { /// Creates a pseudo-random number generator from a predefined seed to ensure reproducibility /// of tests and benchmarks. -#[cfg(test)] +#[cfg(any(test, feature = "testing"))] pub fn create_rnd() -> StandardRng { create_rnd_from_seed(DEFAULT_SEED_FOR_TESTS) } @@ -64,7 +64,7 @@ pub fn create_rnd_provider_from_seed_in_tests(seed: u64) -> RandomProvider RandomProvider { RandomProvider { seed: None, @@ -73,7 +73,7 @@ pub fn create_rnd_provider() -> RandomProvider { } /// Creates a pseudo-random number generator provider from a predefined seed to ensure reproducibility of tests and benchmarks. -#[cfg(test)] +#[cfg(any(test, feature = "testing"))] pub fn create_rnd_provider() -> RandomProvider { RandomProvider { seed: None, diff --git a/diskann-tools/src/bin/relative_contrast.rs b/diskann-tools/src/bin/relative_contrast.rs index 52c83a070..994b1ff6f 100644 --- a/diskann-tools/src/bin/relative_contrast.rs +++ b/diskann-tools/src/bin/relative_contrast.rs @@ -5,6 +5,7 @@ use clap::Parser; use diskann_providers::storage::FileStorageProvider; +use diskann_providers::utils::random; use diskann_tools::utils::{ init_subscriber, relative_contrast::compute_relative_contrast, CMDResult, DataType, GraphDataF32Vector, GraphDataHalfVector, GraphDataInt8Vector, GraphDataU8Vector, @@ -36,39 +37,44 @@ fn main() -> CMDResult<()> { let args = RelativeContrastArgs::parse(); let storage_provider = FileStorageProvider; + let mut rng = random::create_rnd(); let result = match args.data_type { - DataType::Float => compute_relative_contrast::( + DataType::Float => compute_relative_contrast::( &storage_provider, &args.data_file, &args.query_file, &args.gt_file, args.recall_at, args.search_list, + &mut rng, ), - DataType::Fp16 => compute_relative_contrast::( + DataType::Fp16 => compute_relative_contrast::( &storage_provider, &args.data_file, &args.query_file, &args.gt_file, args.recall_at, args.search_list, + &mut rng, ), - DataType::Uint8 => compute_relative_contrast::( + DataType::Uint8 => compute_relative_contrast::( &storage_provider, &args.data_file, &args.query_file, &args.gt_file, args.recall_at, args.search_list, + &mut rng, ), - DataType::Int8 => compute_relative_contrast::( + DataType::Int8 => compute_relative_contrast::( &storage_provider, &args.data_file, &args.query_file, &args.gt_file, args.recall_at, args.search_list, + &mut rng, ), }; diff --git a/diskann-tools/src/utils/relative_contrast.rs b/diskann-tools/src/utils/relative_contrast.rs index 5ed06498c..d10ed0d43 100644 --- a/diskann-tools/src/utils/relative_contrast.rs +++ b/diskann-tools/src/utils/relative_contrast.rs @@ -5,7 +5,6 @@ use diskann::{utils::VectorRepr, ANNError}; use diskann_providers::storage::StorageReadProvider; -use diskann_providers::utils::random; use diskann_providers::{model::graph::traits::GraphDataType, utils::file_util::load_bin}; use rand::Rng; @@ -29,12 +28,12 @@ fn squared_distance( .sum()) } -fn average_squared_distance( +fn average_squared_distance( query: &[Data::VectorDataType], base: &[Vec], num_random_samples: usize, + rng: &mut R, ) -> CMDResult { - let mut rng = random::create_rnd(); let n = base.len(); let mut sum_dist = 0.0; for _ in 0..num_random_samples { @@ -44,13 +43,18 @@ fn average_squared_distance( Ok(sum_dist / num_random_samples as f32) } -pub fn compute_relative_contrast( +pub fn compute_relative_contrast< + Data: GraphDataType, + StorageProvider: StorageReadProvider, + R: Rng, +>( storage_provider: &StorageProvider, base_file: &str, query_file: &str, gt_file: &str, recall_at: usize, num_random_samples: usize, + rng: &mut R, ) -> CMDResult { // Load base, query, and ground truth data let (base_flat, nb, dim) = load_bin::(storage_provider, base_file, 0)?; @@ -75,7 +79,7 @@ pub fn compute_relative_contrast(q, &base, num_random_samples)?; + let numerator = average_squared_distance::(q, &base, num_random_samples, rng)?; // Compute denominator: average squared distance to ground truth neighbors let mut denominator = 0.0; @@ -106,6 +110,7 @@ pub fn compute_relative_contrast( + let mean_rc = compute_relative_contrast::( &storage_provider, base_file_path, query_file_path, gt_file_path, recall_at, num_random_samples, + &mut rng, ) .unwrap(); println!("Mean relative contrast: {}", mean_rc); @@ -253,13 +259,14 @@ mod relative_contrast_tests { // Run compute_relative_contrast with the generated files let num_random_samples = 3; - let mean_rc = compute_relative_contrast::( + let mean_rc = compute_relative_contrast::( &storage_provider, base_file_path, query_file_path, gt_file_path, recall_at, num_random_samples, + &mut rng, ) .unwrap(); println!("Mean relative contrast: {}", mean_rc);