Commit 90fcafc

bindings-java : disable flash attention by default
This commit disables flash-attention for the Java binding test so that the testFullTranscribe test passes.

Without this change the test was failing because the expected output mismatches after the flash-attention change:
```console
<And so my fellow Americans ask not what your country can do for you ask what you can do for your country.>
but was:
<and so my fellow Americans ask not what your country can do for you ask what you can do for your country>
```

An alternative would also be to update the expected output but it felt better to keep the same expected output and disable flash-attention and not just change the expected output to match the new behavior.
1 parent c610672 commit 90fcafc
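
The mismatch quoted in the commit message comes down to a leading capital and a trailing period. The following is a minimal, self-contained sketch that makes this concrete; it is not part of this commit, and `TranscriptCompare` and `normalize` are hypothetical names introduced only for illustration. It shows that the two strings fail a strict comparison but agree once case and punctuation are ignored. The commit deliberately keeps the strict comparison and disables flash attention in the test instead.

```java
import java.util.Locale;

// Hypothetical illustration only; not code from the whisper.cpp Java bindings.
final class TranscriptCompare {
    // Both strings are copied from the failure output in the commit message.
    static final String EXPECTED =
        "And so my fellow Americans ask not what your country can do for you ask what you can do for your country.";
    static final String ACTUAL =
        "and so my fellow Americans ask not what your country can do for you ask what you can do for your country";

    /** Lower-cases the text, drops ASCII punctuation, and collapses whitespace. */
    static String normalize(String s) {
        return s.toLowerCase(Locale.ROOT)
                .replaceAll("\\p{Punct}", "")
                .replaceAll("\\s+", " ")
                .trim();
    }

    public static void main(String[] args) {
        // Strict equality fails: the strings differ in the first letter and the final period.
        System.out.println(EXPECTED.equals(ACTUAL)); // false
        // After normalisation the transcripts are identical.
        System.out.println(normalize(EXPECTED).equals(normalize(ACTUAL))); // true
    }
}
```

A tolerant comparison like this would hide changes in capitalisation and punctuation, which may be one reason to prefer pinning the test configuration as this commit does.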

2 files changed: +5 -2 lines changed

bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ public WhisperContextParams() {
     /** Use GPU for inference (default = true) */
     public CBool use_gpu;
 
-    /** Use flash attention (default = false) */
+    /** Use flash attention (default = true) */
     public CBool flash_attn;
 
     /** CUDA device to use (default = 0) */

bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java

Lines changed: 4 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import io.github.ggerganov.whispercpp.bean.WhisperSegment;
 import io.github.ggerganov.whispercpp.params.CBool;
+import io.github.ggerganov.whispercpp.params.WhisperContextParams;
 import io.github.ggerganov.whispercpp.params.WhisperFullParams;
 import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
 import org.junit.jupiter.api.BeforeAll;
@@ -25,7 +26,9 @@ static void init() throws FileNotFoundException {
         //String modelName = "../../models/ggml-tiny.bin";
         String modelName = "../../models/ggml-tiny.en.bin";
         try {
-            whisper.initContext(modelName);
+            WhisperContextParams.ByValue contextParams = whisper.getContextDefaultParams();
+            contextParams.useFlashAttn(false); // Disable flash attention
+            whisper.initContext(modelName, contextParams);
             //whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
             //whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
             modelInitialised = true;
