144 changes: 106 additions & 38 deletions src/pq_flash_index.cpp
@@ -127,7 +127,8 @@ void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visi
{
#pragma omp critical
{
SSDThreadData<T> *data = new SSDThreadData<T>(this->_aligned_dim, visited_reserve, max_degree, max_filters_per_query);
SSDThreadData<T> *data =
new SSDThreadData<T>(this->_aligned_dim, visited_reserve, max_degree, max_filters_per_query);
this->reader->register_thread();
data->ctx = this->reader->get_ctx();
this->_thread_data.push(data);
@@ -571,7 +572,8 @@ void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::load_label_map(std::basic_istream<char> &map_reader, std::unordered_map<std::string, LabelT>& string_to_int_map)
void PQFlashIndex<T, LabelT>::load_label_map(std::basic_istream<char> &map_reader,
std::unordered_map<std::string, LabelT> &string_to_int_map)
{
std::string line, token;
LabelT token_as_num;
@@ -687,7 +689,6 @@ bool PQFlashIndex<T, LabelT>::point_has_any_label(uint32_t point_id, const std::
return ret_val;
}


template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::parse_label_file(std::basic_istream<char> &infile, size_t &num_points_labels)
{
@@ -778,7 +779,8 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_univers
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::load_label_medoid_map(const std::string& labels_to_medoids_filepath, std::istream& medoid_stream)
void PQFlashIndex<T, LabelT>::load_label_medoid_map(const std::string &labels_to_medoids_filepath,
std::istream &medoid_stream)
{
std::string line, token;

@@ -840,11 +842,10 @@ void PQFlashIndex<T, LabelT>::load_dummy_map(const std::string &dummy_map_filepa
}
catch (std::system_error &e)
{
throw FileException (dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__);
throw FileException(dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__);
}
}


template <typename T, typename LabelT>
#ifdef EXEC_ENV_OLS
bool PQFlashIndex<T, LabelT>::use_filter_support(MemoryMappedFiles &files)
@@ -980,7 +981,6 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_labels
ss << "Note: Filter support is enabled but " << dummy_map_file << " file cannot be opened" << std::endl;
diskann::cerr << ss.str();
}

}
else
{
@@ -1107,11 +1107,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
// bytes are needed to store the header and read in that many using our
// 'standard' aligned file reader approach.
reader->open(_disk_index_file);
this->setup_thread_data(
num_threads,
defaults::VISITED_RESERVE,
defaults::MAX_GRAPH_DEGREE,
(use_filter_support(files)? defaults::MAX_FILTERS_PER_QUERY : 0));
this->setup_thread_data(num_threads, defaults::VISITED_RESERVE, defaults::MAX_GRAPH_DEGREE,
(use_filter_support(files) ? defaults::MAX_FILTERS_PER_QUERY : 0));
this->_max_nthreads = num_threads;

char *bytes = getHeaderBytes();
@@ -1145,6 +1142,77 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
READ_U64(index_metadata, _nnodes_per_sector);
_max_degree = ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1;

// Early validation: read first node from disk and validate neighbor data.
// If data type is wrong, _disk_bytes_per_point will be incorrect, causing
// us to read garbage for neighbor count and neighbor IDs.
// Disk layout: [sector 0: metadata] [sector 1+: node data]
// Node layout: [vector data: _disk_bytes_per_point bytes] [neighbor count: 4 bytes] [neighbor IDs: 4 bytes each]
if (!_use_disk_index_pq && disk_nnodes > 0)
{
#ifdef EXEC_ENV_OLS
// In OLS environment, index_metadata is a ContentBuf with only the header (sector 0).
// We must use the reader (which is already open) to read sector 1 containing the first node.
uint64_t num_sectors_for_node = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
char *first_sector_buf = nullptr;
alloc_aligned((void **)&first_sector_buf, num_sectors_for_node * defaults::SECTOR_LEN, defaults::SECTOR_LEN);

std::vector<AlignedRead> read_reqs;
read_reqs.emplace_back(defaults::SECTOR_LEN, num_sectors_for_node * defaults::SECTOR_LEN, first_sector_buf);

// We need a temporary IOContext for this read
IOContext tmp_ctx = reader->get_ctx();
reader->read(read_reqs, tmp_ctx);

// First node starts at offset 0 within sector 1
char *first_node_buf = first_sector_buf;
#else
// In non-OLS environment, we can seek and read directly from the file stream
uint64_t num_sectors_for_node = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN);
std::vector<char> first_node_buf_vec(num_sectors_for_node * defaults::SECTOR_LEN);
index_metadata.seekg(defaults::SECTOR_LEN, std::ios::beg);
index_metadata.read(first_node_buf_vec.data(), num_sectors_for_node * defaults::SECTOR_LEN);
char *first_node_buf = first_node_buf_vec.data();
#endif

// Get neighbor count (located after vector data)
uint32_t *nhood_ptr = reinterpret_cast<uint32_t *>(first_node_buf + _disk_bytes_per_point);
uint32_t num_neighbors = *nhood_ptr;

// Validate neighbor count is reasonable
if (num_neighbors > _max_degree)
{
#ifdef EXEC_ENV_OLS
aligned_free(first_sector_buf);
#endif
std::stringstream stream;
stream << "Data type mismatch detected: first node has neighbor count " << num_neighbors
<< " which exceeds max_degree " << _max_degree << ". "
<< "Please ensure --data_type matches the type used when building the index.";
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

// Validate each neighbor ID is within valid range [0, disk_nnodes)
uint32_t *neighbors = nhood_ptr + 1;
for (uint32_t i = 0; i < num_neighbors; i++)
{
if (neighbors[i] >= disk_nnodes)
{
#ifdef EXEC_ENV_OLS
aligned_free(first_sector_buf);
#endif
std::stringstream stream;
stream << "Data type mismatch detected: first node has invalid neighbor ID " << neighbors[i]
<< " (max valid ID is " << (disk_nnodes - 1) << "). "
<< "Please ensure --data_type matches the type used when building the index.";
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
}

#ifdef EXEC_ENV_OLS
aligned_free(first_sector_buf);
#endif
}

if (_max_degree > defaults::MAX_GRAPH_DEGREE)
{
std::stringstream stream;
@@ -1180,11 +1248,11 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
READ_U64(index_metadata, this->_nvecs_per_sector);
}

#ifdef EXEC_ENV_OLS
load_labels(files, _disk_index_file);
#else
load_labels(_disk_index_file);
#endif
#ifdef EXEC_ENV_OLS
load_labels(files, _disk_index_file);
#else
load_labels(_disk_index_file);
#endif

diskann::cout << "Disk-Index File Meta-data: ";
diskann::cout << "# nodes per sector: " << _nnodes_per_sector;
@@ -1201,11 +1269,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
// open AlignedFileReader handle to index_file
std::string index_fname(_disk_index_file);
reader->open(index_fname);
this->setup_thread_data(
num_threads,
defaults::VISITED_RESERVE,
defaults::MAX_GRAPH_DEGREE,
(use_filter_support()? defaults::MAX_FILTERS_PER_QUERY : 0));
this->setup_thread_data(num_threads, defaults::VISITED_RESERVE, defaults::MAX_GRAPH_DEGREE,
(use_filter_support() ? defaults::MAX_FILTERS_PER_QUERY : 0));
this->_max_nthreads = num_threads;

#endif
@@ -1466,16 +1531,18 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
NeighborPriorityQueue &retset = query_scratch->retset;
std::vector<Neighbor> &full_retset = query_scratch->full_retset;
tsl::robin_set<location_t> full_retset_ids;
if (use_filters) {
if (use_filters)
{
uint64_t size_to_reserve = std::max(l_search, (std::min((uint64_t)filter_label_count, this->_max_degree) + 1));
retset.reserve(size_to_reserve);
full_retset.reserve(4096);
full_retset.reserve(4096);
full_retset_ids.reserve(4096);
} else {
}
else
{
retset.reserve(l_search + 1);
}


uint32_t best_medoid = 0;
uint32_t cur_list_size = 0;
float best_dist = (std::numeric_limits<float>::max)();
@@ -1495,7 +1562,9 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
retset.insert(Neighbor(best_medoid, dist_scratch[0]));
visited.insert(best_medoid);
cur_list_size = 1;
} else {
}
else
{
std::vector<location_t> filter_specific_medoids;
filter_specific_medoids.reserve(filter_label_count);
location_t ctr = 0;
@@ -1513,12 +1582,12 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
for (ctr = 0; ctr < filter_specific_medoids.size(); ctr++)
{
retset.insert(Neighbor(filter_specific_medoids[ctr], dist_scratch[ctr]));
//retset[ctr].id = filter_specific_medoids[ctr];
//retset[ctr].distance = dist_scratch[ctr];
//retset[ctr].expanded = false;
// retset[ctr].id = filter_specific_medoids[ctr];
// retset[ctr].distance = dist_scratch[ctr];
// retset[ctr].expanded = false;
visited.insert(filter_specific_medoids[ctr]);
}
cur_list_size = (uint32_t) filter_specific_medoids.size();
cur_list_size = (uint32_t)filter_specific_medoids.size();
}

uint32_t cmps = 0;
Expand All @@ -1535,10 +1604,10 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
std::vector<std::pair<uint32_t, std::pair<uint32_t, uint32_t *>>> cached_nhoods;
cached_nhoods.reserve(2 * beam_width);

//if we are doing multi-filter search we don't want to restrict the number of IOs
//at present. Must revisit this decision later.
// if we are doing multi-filter search we don't want to restrict the number of IOs
// at present. Must revisit this decision later.
uint32_t max_ios_for_query = use_filters || (io_limit == 0) ? std::numeric_limits<uint32_t>::max() : io_limit;
const std::vector<LabelT>& label_ids = filter_labels; //avoid renaming.
const std::vector<LabelT> &label_ids = filter_labels; // avoid renaming.
std::vector<LabelT> lbl_vec;

retset.sort();
Expand All @@ -1554,7 +1623,6 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
// find new beam
uint32_t num_seen = 0;


for (const auto &lbl : label_ids)
{ // assuming that number of OR labels is
// less than max frontier size allowed
@@ -1583,7 +1651,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
retset[lbl_marker].expanded = true;
if (this->_count_visited_nodes)
{
reinterpret_cast<std::atomic<uint32_t> &>(this->_node_visit_counter[retset[lbl_marker].id].second)
reinterpret_cast<std::atomic<uint32_t> &>(
this->_node_visit_counter[retset[lbl_marker].id].second)
.fetch_add(1);
}
break;
@@ -1686,7 +1755,6 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
full_retset.push_back(Neighbor((unsigned)cached_nhood.first, cur_expanded_dist));
}


uint64_t nnbrs = cached_nhood.second.first;
uint32_t *node_nbrs = cached_nhood.second.second;

@@ -1768,7 +1836,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
{
full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist));
}

uint32_t *node_nbrs = (node_buf + 1);
// compute node_nbrs <-> query dist in PQ space
cpu_timer.reset();