Conversation
Pull request overview
This PR adds end-to-end support for perplexity (cross-entropy) computation across the Turbomind and PyTorch backends by introducing an output_ppl/compute_ppl pathway that returns scalar loss/count instead of materializing full logits on CPU.
Changes:
- Adds `compute_ppl` to `TurbomindGenerationConfig` and wires it through request/serialization, output processing, and Python bindings.
- Implements a logits callback path (Turbomind) and a packed-logits loss computation path (PyTorch) to produce `ppl_loss`/`ppl_count`.
- Exposes `async_get_ppl()` in the async engine and updates `pipeline.get_ppl()` to use it.
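The scalar reduction this pathway performs can be sketched as follows. This is a hypothetical host-side illustration only: the function name `ppl_loss_count`, the nested-vector logits layout, and float32 math are assumptions for clarity; the real engines reduce packed GPU logits (fp16/bf16) and shift targets by one token.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Sketch: given a [count, vocab] slab of logits and the target token ids,
// return the summed cross-entropy loss and the token count, instead of
// materializing the logits themselves. Perplexity is exp(loss / count).
std::pair<double, int> ppl_loss_count(const std::vector<std::vector<float>>& logits,
                                      const std::vector<int>&                targets)
{
    double loss  = 0.0;
    int    count = 0;
    for (std::size_t i = 0; i < targets.size(); ++i) {
        const auto& row = logits[i];
        // numerically stable log-softmax: logit[t] - max - log(sum exp(logit - max))
        float mx = row[0];
        for (float v : row) {
            mx = std::max(mx, v);
        }
        double denom = 0.0;
        for (float v : row) {
            denom += std::exp(static_cast<double>(v) - mx);
        }
        loss += -(static_cast<double>(row[targets[i]]) - mx - std::log(denom));
        ++count;
    }
    return {loss, count};
}
```

With uniform logits over a vocabulary of size V, the per-token loss is log(V), which is a convenient sanity check for any backend implementation.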
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| src/turbomind/python/bind.cpp | Adds compute_ppl binding and a new optional Python logits callback passed into Turbomind requests. |
| src/turbomind/models/output_processor.cc | Triggers logits generation for compute_ppl and introduces a callback-vs-copy split for logits handling. |
| src/turbomind/engine/request.h | Adds compute_ppl and a per-request logits_cb callback field; serializes compute_ppl. |
| src/turbomind/engine/model_request.h | Plumbs logits_cb through ModelRequest::InputParam. |
| src/turbomind/engine/model_request.cc | Avoids allocating CPU logits output tensor when compute_ppl is enabled; stores callback in request. |
| lmdeploy/turbomind/turbomind.py | Implements Python-side CE accumulation via callback and sets Turbomind compute_ppl. |
| lmdeploy/serve/core/async_engine.py | Adds ppl_loss/ppl_count to outputs and introduces async_get_ppl(). |
| lmdeploy/pytorch/strategies/ar/sampling.py | Adds compute_ppl aggregation into SamplingInputs. |
| lmdeploy/pytorch/messages.py | Adds compute_ppl to SamplingParam and maps from GenerationConfig.output_ppl. |
| lmdeploy/pytorch/engine/model_agent/agent.py | Adds packed-logits CE computation helper and threads ppl outputs through postprocess. |
| lmdeploy/pytorch/engine/logits_process.py | Adds compute_ppl to SamplingInputs. |
| lmdeploy/pytorch/engine/engine_loop.py | Propagates ppl loss/count through engine responses and accumulates across steps. |
| lmdeploy/pytorch/engine/engine_instance.py | Reads and forwards ppl loss/count from engine loop messages. |
| lmdeploy/pytorch/engine/engine.py | Extends InferOutput with ppl loss/count fields. |
| lmdeploy/pipeline.py | Replaces previous logits-materializing PPL logic with async_get_ppl() usage. |
| lmdeploy/messages.py | Adds output_ppl to config and ppl loss/count fields to response/output types. |
```cpp
py::capsule cap(dlmt, kDlTensorCapsuleName, [](PyObject* obj) {
    DLManagedTensor* p =
        static_cast<DLManagedTensor*>(PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
    if (p->deleter) {
        p->deleter(p);
    }
});
```
The DLPack capsule destructor is unsafe: after torch.from_dlpack(cap) consumes the capsule, PyTorch renames it to "used_dltensor", so PyCapsule_GetPointer(obj, "dltensor") can return null and set a Python error. The current code then dereferences p->deleter, which can segfault. Please follow the existing Tensor::__dlpack__ destructor pattern in this file (handle null from PyCapsule_GetPointer and clear the error, and/or accept the used_dltensor name) to avoid crashes/double-free.
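A minimal model of the guarded-destructor pattern the comment asks for. The struct below is a local stand-in for `dlpack.h`'s `DLManagedTensor` (only the fields the destructor touches), and `safe_capsule_destructor` is a hypothetical name; in the actual binding the pointer comes from `PyCapsule_GetPointer`, and a null result should also be followed by `PyErr_Clear()`.

```cpp
#include <cassert>

// Local stand-in for DLPack's DLManagedTensor; the real definition lives in
// dlpack.h. Only the fields the destructor dereferences are modeled here.
struct DLManagedTensor {
    void* manager_ctx;
    void (*deleter)(DLManagedTensor*);
};

// Once torch.from_dlpack() consumes the capsule, it is renamed to
// "used_dltensor", so a lookup under the original "dltensor" name returns
// null. Guard before dereferencing instead of crashing or double-freeing.
void safe_capsule_destructor(DLManagedTensor* p)
{
    if (p == nullptr) {
        // Consumer took ownership (capsule was renamed): nothing to free.
        // In the real binding, call PyErr_Clear() here as well.
        return;
    }
    if (p->deleter) {
        p->deleter(p);
    }
}
```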
src/turbomind/python/bind.cpp (outdated)
```cpp
int64_t shape[2] = {count, vocab_size};
DLManagedTensor* dlmt = new DLManagedTensor{};
dlmt->dl_tensor.data = data;
dlmt->dl_tensor.ndim = 2;
dlmt->dl_tensor.shape = shape;
dlmt->dl_tensor.strides = nullptr;
dlmt->dl_tensor.byte_offset = 0;
```
dlmt->dl_tensor.shape is set to a stack array (int64_t shape[2]), but the DLPack consumer may keep the DLManagedTensor alive beyond this callback scope. This leaves shape dangling and can cause memory corruption. Allocate shape in heap-owned storage (e.g., via manager_ctx) and free it in the DLManagedTensor::deleter.
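One way to implement the suggested fix, sketched with local stand-ins for the DLPack structs (real definitions in `dlpack.h`; `make_managed` is a hypothetical name): the shape lives in heap storage owned by the `DLManagedTensor` via `manager_ctx`, so it stays valid as long as the consumer holds the tensor, and the deleter frees it.

```cpp
#include <cassert>
#include <cstdint>

// Local stand-ins for the DLPack structs; only the fields used here.
struct DLTensor {
    void*    data;
    int      ndim;
    int64_t* shape;
};
struct DLManagedTensor {
    DLTensor dl_tensor;
    void*    manager_ctx;
    void (*deleter)(DLManagedTensor*);
};

// The shape array is heap-allocated and owned by the DLManagedTensor
// itself, so it cannot dangle after the creating scope returns; the
// deleter releases both the shape and the managed tensor.
DLManagedTensor* make_managed(void* data, int64_t rows, int64_t cols)
{
    auto* shape = new int64_t[2]{rows, cols};
    auto* dlmt  = new DLManagedTensor{};
    dlmt->dl_tensor.data  = data;
    dlmt->dl_tensor.ndim  = 2;
    dlmt->dl_tensor.shape = shape;  // heap-owned, not a stack array
    dlmt->manager_ctx     = shape;  // remember the allocation for the deleter
    dlmt->deleter = [](DLManagedTensor* self) {
        delete[] static_cast<int64_t*>(self->manager_ctx);
        delete self;
    };
    return dlmt;
}
```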
```cpp
if (logits_cb) {
    auto py_cb = std::move(*logits_cb);
    param.logits_cb = [py_cb = std::move(py_cb)](
                          void* data, int vocab_size, int begin, int count, ft::DataType dtype) {
        py::gil_scoped_acquire gil;
        int64_t shape[2] = {count, vocab_size};
        DLManagedTensor* dlmt = new DLManagedTensor{};
        dlmt->dl_tensor.data = data;
        dlmt->dl_tensor.ndim = 2;
        dlmt->dl_tensor.shape = shape;
        dlmt->dl_tensor.strides = nullptr;
        dlmt->dl_tensor.byte_offset = 0;

        dlmt->dl_tensor.device.device_type = kDLCUDA;
        int device_id = 0;
        cudaGetDevice(&device_id);
        dlmt->dl_tensor.device.device_id = device_id;

        if (dtype == ft::DataType::kFloat16) {
            dlmt->dl_tensor.dtype = {kDLFloat, 16, 1};
        }
        else if (dtype == ft::DataType::kBfloat16) {
            dlmt->dl_tensor.dtype = {kDLBfloat, 16, 1};
        }
        else {
            dlmt->dl_tensor.dtype = {kDLFloat, 32, 1};
        }

        dlmt->manager_ctx = nullptr;
        dlmt->deleter = [](DLManagedTensor* self) { delete self; };

        py::capsule cap(dlmt, kDlTensorCapsuleName, [](PyObject* obj) {
            DLManagedTensor* p =
                static_cast<DLManagedTensor*>(PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
            if (p->deleter) {
                p->deleter(p);
            }
        });

        py::object torch = py::module_::import("torch");
        py::object from_dlpack = torch.attr("from_dlpack");
        py::object tensor = from_dlpack(cap);

        py_cb(tensor, vocab_size, begin, count);
    };
```
The DLPack tensor created here points directly at data owned by the temporary logits tensor in OutputProcessor. If the Python callback stores the returned torch.Tensor (or even uses it asynchronously), it can outlive the underlying buffer and become a use-after-free on GPU memory. To make this safe, either (a) copy logits into an owned buffer and free it from the DLPack deleter, or (b) avoid exposing a borrowing tensor to Python and instead compute the needed reduction (CE loss) on the C++ side.
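Option (a) can be sketched as follows, again with local stand-ins for the DLPack structs and a hypothetical `make_owned_copy` helper. A plain host `memcpy` stands in for what would be a device-to-device copy (e.g. `cudaMemcpyAsync`) in the real binding; the point is that the deleter owns and frees the snapshot, so the Python-side tensor never borrows engine memory.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Local stand-ins for the DLPack structs; only the fields used here.
struct DLTensor {
    void* data;
};
struct DLManagedTensor {
    DLTensor dl_tensor;
    void*    manager_ctx;
    void (*deleter)(DLManagedTensor*);
};

// Snapshot the logits into a buffer owned by the DLManagedTensor, so the
// consumer's tensor stays valid even after the engine reuses its buffer.
DLManagedTensor* make_owned_copy(const float* src, int64_t n)
{
    float* copy = new float[n];
    std::memcpy(copy, src, sizeof(float) * n);  // device copy in the real code
    auto* dlmt = new DLManagedTensor{};
    dlmt->dl_tensor.data = copy;
    dlmt->manager_ctx    = copy;  // the deleter frees this allocation
    dlmt->deleter = [](DLManagedTensor* self) {
        delete[] static_cast<float*>(self->manager_ctx);
        delete self;
    };
    return dlmt;
}
```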
```cpp
break;

// Invoke PPL callbacks
if (tp_rank_ == 0) {
```
cudaStreamSynchronize(stream) is called for every logits chunk whenever tp_rank_ == 0, even when there are no PPL callbacks (i.e., ppl_ranges is empty). This introduces a host-side synchronization in the hot path and is a significant performance regression for normal output_logits requests. Please guard the synchronize+loop with if (!ppl_ranges.empty() && tp_rank_ == 0) (and consider a less blocking sync strategy if possible).
Suggested change:
```diff
- if (tp_rank_ == 0) {
+ if (!ppl_ranges.empty() && tp_rank_ == 0) {
```
```cpp
auto& cb = rs[i]->req->logits_cb;
cb(logits.raw_data(), vocab_size_, chunk.begin(), (int)chunk.size(), logits.dtype());
```
The PPL callback is invoked with the entire logits chunk pointer and begin/count in packed-token coordinates (chunk.begin()/chunk.size()), ignoring each request's [src,dst] intervals in ppl_ranges. With batching, a chunk can span multiple sequences, so this will at best compute incorrect loss and at worst index beyond the request's input_ids in Python. Please intersect each request's src with the current chunk (similar to OutputLogitsImpl), slice the logits rows accordingly, and pass a per-request begin offset/count to the callback.
Suggested change:
```diff
- auto& cb = rs[i]->req->logits_cb;
- cb(logits.raw_data(), vocab_size_, chunk.begin(), (int)chunk.size(), logits.dtype());
+ // Intersect this request's src interval with the current chunk.
+ if (auto s = src & chunk) {
+     // Offset into the current logits chunk where this request's data starts.
+     auto chunk_offset = s.begin() - chunk.begin();
+     // Offset within this request's src interval for the callback.
+     auto req_offset = s.begin() - src.begin();
+     auto count = static_cast<int>(s.size());
+     // Slice logits rows so the callback only sees this request's tokens.
+     auto logits_view = logits.slice(static_cast<int>(chunk_offset), count);
+     auto& cb = rs[i]->req->logits_cb;
+     cb(logits_view.raw_data(), vocab_size_, static_cast<int>(req_offset), count, logits.dtype());
+ }
```
```python
continue
seq_logits = logits[offset:offset + length - 1]
seq_targets = input_ids[offset + 1:offset + length]
print(f'seq_logits: {seq_logits.shape}, seq_targets: {seq_targets.shape}')
```
There is a stray print(...) debug statement in _compute_ppl_from_logits that will spam stdout during inference and can severely impact throughput when called frequently. Please remove it or replace it with a guarded logger.debug if you still need shape diagnostics.
Suggested change:
```diff
- print(f'seq_logits: {seq_logits.shape}, seq_targets: {seq_targets.shape}')
+ logger.debug(
+     'seq_logits: %s, seq_targets: %s',
+     seq_logits.shape,
+     seq_targets.shape,
+ )
```
```python
ppl_loss: float = None
ppl_count: int = None
```
These fields are annotated as non-optional (float/int) but default to None, which is inconsistent with nearby optional fields and can break static type checking. Consider annotating them as float | None and int | None (or provide non-None defaults like 0.0/0).
Suggested change:
```diff
- ppl_loss: float = None
- ppl_count: int = None
+ ppl_loss: float | None = None
+ ppl_count: int | None = None
```
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it receive feedback more easily. If you do not understand some items, don't worry: just open the pull request and seek help from maintainers.
Motivation
Please describe the motivation of this PR and the goal you want to achieve through this PR.
Modification
Please briefly describe what modification is made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist