Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/models/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,13 @@ namespace ctranslate2 {

std::vector<std::vector<std::pair<dim_t, dim_t>>> alignments;

if (variable_num_frames) {
if (std::all_of(num_frames.begin(), num_frames.end(), [](size_t size) { return size == 0; })) {
// A window shorter than the encoder stride (num_frames < 2) has no
// frames left to align against: running the attention post-processing
// on zero-size tensors is at best undefined. Return empty alignments.
alignments.resize(batch_size);

} else if (variable_num_frames) {
const StorageView frame_sizes({batch_size},
std::vector<int32_t>(num_frames.begin(), num_frames.end()),
device);
Expand All @@ -524,6 +530,11 @@ namespace ctranslate2 {
alignments.reserve(batch_size);

for (dim_t b = 0; b < batch_size; ++b) {
if (num_frames[b] == 0) {
alignments.emplace_back();
continue;
}

// Retrieve attention probs for batch and remove padding.
StorageView batch_id({1}, int32_t(b), device);
StorageView attention_probs(dtype, device);
Expand Down
19 changes: 13 additions & 6 deletions src/ops/median_filter_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@ namespace ctranslate2 {
void MedianFilter::compute(const StorageView& input,
const dim_t axis_size,
StorageView& output) const {
const auto* src = input.data<T>();
auto* dst = output.data<T>();


const dim_t depth = axis_size;
const dim_t batch_size = input.size() / depth;
const dim_t rank = _width / 2;

if (depth <= rank)
// Guard before dividing by depth: a zero-size axis (e.g. Whisper align()
// with num_frames < 2, halved to 0 by the encoder stride) would make
// input.size() / depth an integer division by zero — a native crash,
// not a catchable exception. Pass the input through like the GPU path.
if (depth <= rank) {
if (&output != &input)
output.copy_from(input);
return;
}

const auto* src = input.data<T>();
auto* dst = output.data<T>();

const dim_t batch_size = input.size() / depth;

cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
StorageView window_storage({_width}, DataType::FLOAT32);
Expand Down
9 changes: 7 additions & 2 deletions src/ops/median_filter_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,14 @@ namespace ctranslate2 {
const dim_t axis_size,
StorageView& output) const {
const int depth = static_cast<int>(axis_size);
const int rows = static_cast<int>(input.size() / depth);
const int width = static_cast<int>(_width);
const int rank = width / 2;

// Host-side guards and fallbacks.
// Host-side guards and fallbacks. The depth guard must run before
// input.size() / depth below: a zero-size axis (e.g. Whisper align()
// with num_frames < 2, halved to 0 by the encoder stride) would make it
// an integer division by zero — a native crash, not a catchable
// exception (0xC0000094 on Windows, SIGFPE on Linux).
if (width <= 1) {
if (&output != &input)
output.copy_from(input);
Expand All @@ -71,6 +74,8 @@ namespace ctranslate2 {
return;
}

const int rows = static_cast<int>(input.size() / depth);

// Grid configuration
const int total = rows * depth;
int blocks = (total + num_threads - 1) / num_threads;
Expand Down
21 changes: 21 additions & 0 deletions tests/ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,27 @@ TEST_P(OpDeviceFPTest, MedianFilter) {
expect_storage_eq(y.to_float32(), expected, error);
}

TEST_P(OpDeviceFPTest, MedianFilterShortAxis) {
// An axis shorter than the filter rank (including a zero-size axis, as
// produced by Whisper align() when num_frames < 2) must pass the input
// through instead of crashing on an integer division by zero.
const Device device = GetParam().device;
const DataType dtype = GetParam().dtype;
const float error = GetParam().error;
{
StorageView x({2, 0}, std::vector<float>{}, device);
StorageView y(dtype, device);
ops::MedianFilter(5)(x.to(dtype), y);
EXPECT_EQ(y.size(), 0);
}
{
StorageView x({2, 2}, std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f}, device);
StorageView y(dtype, device);
ops::MedianFilter(5)(x.to(dtype), y);
expect_storage_eq(y.to_float32(), x, error);
}
}

TEST_P(OpDeviceTest, Add) {
Device device = GetParam();
StorageView a({4}, std::vector<float>{1, 2, 3, 4}, device);
Expand Down
Loading