From f98318d5981cc6b3edbb34c4a8af90fc1a4027ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=81=A5=E4=BB=99?=
Date: Tue, 27 Jan 2026 14:23:57 +0800
Subject: [PATCH] (feat:model/cosyvoice):use new protocol to build websocket connection

---
 .../protocol/AudioWebsocketCallback.java      |  28 +
 .../audio/protocol/AudioWebsocketRequest.java | 190 ++++++
 .../audio/ttsv2/SpeechSynthesizerV2.java      | 601 ++++++++++++++++++
 .../TestTtsV2SpeechSynthesizerV2.java         | 161 +++++
 4 files changed, 980 insertions(+)
 create mode 100644 src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketCallback.java
 create mode 100644 src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketRequest.java
 create mode 100644 src/main/java/com/alibaba/dashscope/audio/ttsv2/SpeechSynthesizerV2.java
 create mode 100644 src/test/java/com/alibaba/dashscope/TestTtsV2SpeechSynthesizerV2.java

diff --git a/src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketCallback.java b/src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketCallback.java
new file mode 100644
index 0000000..93dcce2
--- /dev/null
+++ b/src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketCallback.java
@@ -0,0 +1,28 @@
+package com.alibaba.dashscope.audio.protocol;
+
+import java.nio.ByteBuffer;
+import okhttp3.WebSocket;
+
+/**
+ * Receives the connection and message events that {@link AudioWebsocketRequest} forwards from its
+ * OkHttp {@code WebSocketListener} callbacks.
+ *
+ * @author songsong.shao
+ */
+public interface AudioWebsocketCallback {
+
+  /** Called once the websocket handshake has completed and the connection is open. */
+  void onOpen();
+
+  /** Called for every text frame received on {@code webSocket}. */
+  void onMessage(WebSocket webSocket, String text);
+
+  /** Called for every binary frame received on {@code webSocket}. */
+  void onMessage(WebSocket webSocket, ByteBuffer buffer);
+
+  /** Called when the connection fails; {@code t} is the cause reported by OkHttp. */
+  void onError(WebSocket webSocket, Throwable t);
+
+  /** Called after the connection has been closed with the given close code and reason. */
+  void onClose(int code, String reason);
+}
diff --git a/src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketRequest.java b/src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketRequest.java
new file mode 100644
index 0000000..ec76f5f
--- /dev/null
+++ b/src/main/java/com/alibaba/dashscope/audio/protocol/AudioWebsocketRequest.java
@@ -0,0 +1,190 @@
+package com.alibaba.dashscope.audio.protocol;
+
+import com.alibaba.dashscope.exception.NoApiKeyException;
+import com.alibaba.dashscope.protocol.DashScopeHeaders;
+import com.alibaba.dashscope.protocol.okhttp.OkHttpClientFactory;
+import com.alibaba.dashscope.utils.ApiKey;
+import com.alibaba.dashscope.utils.Constants;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.*;
+import okio.ByteString;
+
+/**
+ * Thin wrapper around an OkHttp websocket: {@link #connect} blocks until the handshake finishes
+ * and every listener event is forwarded to an {@link AudioWebsocketCallback}.
+ *
+ * @author songsong.shao
+ */
+@Slf4j
+public class AudioWebsocketRequest extends WebSocketListener {
+
+  private OkHttpClient client;
+  private WebSocket websocketClient;
+  private AtomicBoolean isOpen = new AtomicBoolean(false);
+  private AtomicReference<CountDownLatch> connectLatch = new AtomicReference<>(null);
+  private AtomicReference<Throwable> connectFailure = new AtomicReference<>(null);
+  private AtomicBoolean isClosed = new AtomicBoolean(false);
+  private AudioWebsocketCallback callback;
+  private Integer connectTimeout = 5000;
+
+  /** Returns true from a successful handshake until the connection is closed or fails. */
+  public boolean isOpen() {
+    return isOpen.get();
+  }
+
+  /** Returns true once {@link #close} was called or the closing handshake has completed. */
+  public boolean isClosed() {
+    return isClosed.get();
+  }
+
+  /** Throws a {@link RuntimeException} if this connection has already been closed. */
+  public void checkStatus() {
+    if (this.isClosed.get()) {
+      throw new RuntimeException("Websocket is already closed!");
+    }
+  }
+
+  /**
+   * Opens the websocket and waits up to {@code connectTimeout} ms for the handshake to finish.
+   *
+   * @param apiKey API key, resolved through {@code ApiKey.getApiKey}
+   * @param workspace workspace passed into the handshake headers
+   * @param customHeaders extra handshake headers
+   * @param baseWebSocketUrl URL to connect to; null selects {@code Constants.baseWebsocketApiUrl}
+   * @param callback receiver of all events of this connection
+   * @throws RuntimeException if the handshake times out or the connection fails
+   */
+  public void connect(
+      String apiKey,
+      String workspace,
+      Map<String, String> customHeaders,
+      String baseWebSocketUrl,
+      AudioWebsocketCallback callback)
+      throws NoApiKeyException, InterruptedException, RuntimeException {
+    Request request =
+        buildConnectionRequest(
+            ApiKey.getApiKey(apiKey), false, workspace, customHeaders, baseWebSocketUrl);
+    this.callback = callback;
+    client = OkHttpClientFactory.getOkHttpClient();
+    // The latch must exist before the socket is created: listener callbacks run on OkHttp
+    // threads and may fire before newWebSocket returns.
+    CountDownLatch latch = new CountDownLatch(1);
+    connectLatch.set(latch);
+    websocketClient = client.newWebSocket(request, this);
+    if (!latch.await(connectTimeout, TimeUnit.MILLISECONDS)) {
+      throw new RuntimeException(
+          "TimeoutError: waiting for websocket connect more than " + connectTimeout + " ms.");
+    }
+    if (!isOpen.get()) {
+      // onFailure and onClosed release the latch too, so a released latch is not a success.
+      throw new RuntimeException("Websocket connect failed.", connectFailure.get());
+    }
+  }
+
+  /** Builds the handshake request with the DashScope websocket headers and target URL. */
+  private Request buildConnectionRequest(
+      String apiKey,
+      boolean isSecurityCheck,
+      String workspace,
+      Map<String, String> customHeaders,
+      String baseWebSocketUrl)
+      throws NoApiKeyException {
+    // build the request builder.
+    Request.Builder bd = new Request.Builder();
+    bd.headers(
+        Headers.of(
+            DashScopeHeaders.buildWebSocketHeaders(
+                apiKey, isSecurityCheck, workspace, customHeaders)));
+    String url = Constants.baseWebsocketApiUrl;
+    if (baseWebSocketUrl != null) {
+      url = baseWebSocketUrl;
+    }
+    Request request = bd.url(url).build();
+    return request;
+  }
+
+  private void sendMessage(String message, boolean enableLog) {
+    checkStatus();
+    if (enableLog) {
+      log.debug("send message: " + message);
+    }
+    if (!websocketClient.send(message)) {
+      log.warn("Failed to enqueue websocket text message for sending.");
+    }
+  }
+
+  /** Closes the connection with code 1000 and reason "bye". */
+  public void close() {
+    this.close(1000, "bye");
+  }
+
+  /** Starts the closing handshake; throws if the connection is already closed. */
+  public void close(int code, String reason) {
+    checkStatus();
+    websocketClient.close(code, reason);
+    isClosed.set(true);
+  }
+
+  /** Enqueues a text frame; a rejected frame is only logged. */
+  public void sendTextMessage(String message) {
+    checkStatus();
+    this.sendMessage(message, true);
+  }
+
+  /** Enqueues a binary frame; a rejected frame is only logged. */
+  public void sendBinaryMessage(ByteString rawData) {
+    checkStatus();
+    if (!websocketClient.send(rawData)) {
+      log.warn("Failed to enqueue websocket binary message for sending.");
+    }
+  }
+
+  @Override
+  public void onOpen(WebSocket webSocket, Response response) {
+    isOpen.set(true);
+    connectLatch.get().countDown();
+    log.debug("WebSocket opened");
+    callback.onOpen();
+  }
+
+  @Override
+  public void onMessage(WebSocket webSocket, String text) {
+    callback.onMessage(webSocket, text);
+  }
+
+  @Override
+  public void onMessage(WebSocket webSocket, ByteString bytes) {
+    log.debug("Received binary message");
+    callback.onMessage(webSocket, bytes.asByteBuffer());
+  }
+
+  @Override
+  public void onClosed(WebSocket webSocket, int code, String reason) {
+    isOpen.set(false);
+    isClosed.set(true);
+    connectLatch.get().countDown();
+    log.debug("WebSocket closed");
+    callback.onClose(code, reason);
+  }
+
+  @Override
+  public void onFailure(WebSocket webSocket, Throwable t, Response response) {
+    log.error("WebSocket failed: " + t.getMessage());
+    // A failed socket delivers no further events: record the cause and mark it not open.
+    connectFailure.set(t);
+    isOpen.set(false);
+    if (connectLatch.get() != null) {
+      connectLatch.get().countDown();
+    }
+    if (callback != null) {
+      callback.onError(webSocket, t);
+    } else {
+      throw new RuntimeException(t);
+    }
+  }
+}
diff --git a/src/main/java/com/alibaba/dashscope/audio/ttsv2/SpeechSynthesizerV2.java b/src/main/java/com/alibaba/dashscope/audio/ttsv2/SpeechSynthesizerV2.java
new file mode 100644
index 0000000..335a9ec
--- /dev/null
+++ b/src/main/java/com/alibaba/dashscope/audio/ttsv2/SpeechSynthesizerV2.java
@@ -0,0 +1,601 @@
+// Copyright (c) Alibaba, Inc. and its affiliates.
+
+package com.alibaba.dashscope.audio.ttsv2;
+
+import com.alibaba.dashscope.audio.protocol.AudioWebsocketCallback;
+import com.alibaba.dashscope.audio.protocol.AudioWebsocketRequest;
+import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
+import com.alibaba.dashscope.audio.tts.SpeechSynthesisUsage;
+import com.alibaba.dashscope.audio.tts.timestamp.Sentence;
+import com.alibaba.dashscope.common.*;
+import com.alibaba.dashscope.exception.ApiException;
+import com.alibaba.dashscope.exception.InputRequiredException;
+import com.alibaba.dashscope.exception.NoApiKeyException;
+import com.alibaba.dashscope.protocol.*;
+import com.alibaba.dashscope.utils.Constants;
+import com.alibaba.dashscope.utils.JsonUtils;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import java.io.ByteArrayOutputStream;
+import java.nio.ByteBuffer;
+import java.util.*;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.WebSocket;
+
+/**
+ * CosyVoice speech synthesizer that speaks the run-task / continue-task / finish-task protocol
+ * over an {@link AudioWebsocketRequest}.
+ *
+ * <p>With a callback, audio frames and events are delivered to it; without one, audio is buffered
+ * and returned by {@link #call(String, long)}.
+ *
+ * @author songsong.sss
+ */
+@Slf4j
+public final class SpeechSynthesizerV2 implements AudioWebsocketCallback {
+  private SpeechSynthesisState state = SpeechSynthesisState.IDLE;
+  private ResultCallback<SpeechSynthesisResult> callback;
+
+  private final AtomicReference<CountDownLatch> stopLatch = new AtomicReference<>(null);
+
+  private SpeechSynthesisParam parameters;
+
+  private String preRequestId = null;
+  private boolean isFirst = true;
+  private AtomicBoolean canceled = new AtomicBoolean(false);
+  private boolean asyncCall = false;
+  private ByteArrayOutputStream audioStream;
+  private long startStreamTimeStamp = -1;
+  private long firstPackageTimeStamp = -1;
+  private double recvAudioLength = 0;
+  @Getter @Setter private long startedTimeout = 5000;
+  @Getter @Setter private long firstAudioTimeout = -1;
+  private final AtomicReference<CountDownLatch> startLatch = new AtomicReference<>(null);
+  private AudioWebsocketRequest websocketRequest;
+  private String websocketUrl = Constants.baseWebsocketApiUrl;
+  private JsonObject bailianHeader = new JsonObject();
+  private static final String HEADER_ACTION = "action";
+  private static final String TASK_ID = "task_id";
+  private static final Integer DEFAULT_COMPLETE_TIMEOUT = 60 * 1000;
+  private String taskId;
+  private boolean enableSsml = false;
+  // Failure of the current task when no callback is set; rethrown by streamingComplete.
+  private volatile ApiException taskFailure;
+
+  /**
+   * Creates a CosyVoice speech synthesizer.
+   *
+   * @param param Configuration for speech synthesis, including voice type, volume, etc.
+   * @param callback In non-streaming output scenarios, this can be set to null
+   * @param baseUrl Websocket URL; null keeps {@code Constants.baseWebsocketApiUrl}
+   * @param connectionOptions Connection options (not read by this class at present)
+   */
+  public SpeechSynthesizerV2(
+      SpeechSynthesisParam param,
+      ResultCallback<SpeechSynthesisResult> callback,
+      String baseUrl,
+      ConnectionOptions connectionOptions) {
+    if (baseUrl != null) {
+      this.websocketUrl = baseUrl;
+    }
+    this.parameters = param;
+    this.callback = callback;
+    this.asyncCall = this.callback != null;
+    this.taskId = UUID.randomUUID().toString();
+  }
+
+  /**
+   * Creates a synthesizer without parameters; set them with updateParamAndCallback before use.
+   *
+   * @param baseUrl Websocket URL; null keeps {@code Constants.baseWebsocketApiUrl}
+   * @param connectionOptions Connection options (not read by this class at present)
+   */
+  public SpeechSynthesizerV2(String baseUrl, ConnectionOptions connectionOptions) {
+    this(null, null, baseUrl, connectionOptions);
+  }
+
+  /** Creates a synthesizer without parameters, callback, or custom URL. */
+  public SpeechSynthesizerV2() {
+    this(null, null, null, null);
+  }
+
+  /** Replaces the parameters and callback and resets the per-task state for reuse. */
+  public void updateParamAndCallback(
+      SpeechSynthesisParam param, ResultCallback<SpeechSynthesisResult> callback) {
+    this.parameters = param;
+    this.callback = callback;
+    this.canceled.set(false);
+
+    // reset inner params
+    this.stopLatch.set(null);
+    this.startLatch.set(null);
+    // NOTE(review): this clears the configurable timeout; firstPackageTimeStamp may be intended.
+    this.firstAudioTimeout = -1;
+    this.isFirst = true;
+    this.audioStream = new ByteArrayOutputStream();
+
+    this.asyncCall = this.callback != null;
+    this.taskId = UUID.randomUUID().toString();
+  }
+
+  /**
+   * Creates a CosyVoice speech synthesizer.
+   *
+   * @param param Configuration for speech synthesis, including voice type, volume, etc.
+   * @param callback In non-streaming output scenarios, this can be set to null
+   * @param baseUrl Websocket URL; null keeps {@code Constants.baseWebsocketApiUrl}
+   */
+  public SpeechSynthesizerV2(
+      SpeechSynthesisParam param, ResultCallback<SpeechSynthesisResult> callback, String baseUrl) {
+    this(param, callback, baseUrl, null);
+  }
+
+  /**
+   * Creates a CosyVoice speech synthesizer that uses the default websocket URL.
+   *
+   * @param param Configuration for speech synthesis, including voice type, volume, etc.
+   * @param callback In non-streaming output scenarios, this can be set to null
+   */
+  public SpeechSynthesizerV2(
+      SpeechSynthesisParam param, ResultCallback<SpeechSynthesisResult> callback) {
+    this(param, callback, null, null);
+  }
+
+  /** Returns the task id taken from the most recent result-generated event, or null. */
+  public String getLastRequestId() {
+    return preRequestId;
+  }
+
+  private void checkConnectStatus() {
+    websocketRequest.checkStatus();
+  }
+
+  /** Closes any open connection and opens a new one with the current parameters. */
+  public void connect() throws NoApiKeyException, InterruptedException {
+    startStreamTimeStamp = System.currentTimeMillis();
+    this.audioStream = new ByteArrayOutputStream();
+    this.canceled.set(false);
+    if (websocketRequest != null && websocketRequest.isOpen()) {
+      websocketRequest.close();
+    }
+
+    websocketRequest = new AudioWebsocketRequest();
+    websocketRequest.connect(
+        parameters.getApiKey(),
+        parameters.getWorkspace(),
+        parameters.getHeaders(),
+        websocketUrl,
+        this);
+  }
+
+  /** Closes the connection if it is open; a failure to close is logged, not thrown. */
+  public void close() {
+    if (websocketRequest != null && websocketRequest.isOpen()) {
+      try {
+        websocketRequest.close();
+      } catch (Exception e) {
+        log.warn("Failed to close websocket connection: " + e.getMessage());
+      }
+    }
+  }
+
+  /** Builds the header/payload envelope for {@code action} and sends it as a text frame. */
+  private void sendTaskMessage(String action, JsonObject input) {
+    JsonObject wsMessage = new JsonObject();
+
+    bailianHeader.addProperty(HEADER_ACTION, action);
+    bailianHeader.addProperty(TASK_ID, taskId);
+
+    JsonObject payload = new JsonObject();
+    if ("run-task".equals(action)) {
+      payload.addProperty("task_group", "audio");
+      payload.addProperty("task", "tts");
+      payload.addProperty("function", "SpeechSynthesizer");
+      payload.addProperty("model", this.parameters.getModel());
+      JsonObject parameters = JsonUtils.toJsonObject(this.parameters.getParameters());
+      if (enableSsml) {
+        parameters.addProperty("enable_ssml", true);
+      }
+      payload.add("parameters", parameters);
+
+      payload.add("input", input != null ? input : new JsonObject());
+    } else {
+      payload.add("input", input != null ? input : new JsonObject());
+    }
+
+    wsMessage.add("header", JsonUtils.toJsonObject(bailianHeader));
+    wsMessage.add("payload", JsonUtils.toJsonObject(payload));
+    log.debug("sendTaskMessage: {}", wsMessage.toString());
+    websocketRequest.sendTextMessage(wsMessage.toString());
+  }
+
+  /** Sends run-task in duplex streaming mode, optionally with {@code enable_ssml} set. */
+  public void startSynthesizer(boolean enableSsml) throws InterruptedException {
+    bailianHeader.addProperty("streaming", "duplex");
+    this.enableSsml = enableSsml;
+    sendTaskMessage("run-task", new JsonObject());
+  }
+
+  /** Sends {@code text} to the running task as a continue-task message. */
+  public void sendText(String text) {
+    JsonObject input = new JsonObject();
+    input.addProperty("text", text);
+    sendTaskMessage("continue-task", input);
+  }
+
+  /** Sends finish-task for the running task. */
+  public void stopSynthesizer() {
+    sendTaskMessage("finish-task", new JsonObject());
+  }
+
+  @Override
+  public void onOpen() {
+    log.info("WebSocket connection opened");
+    if (callback != null) {
+      callback.onOpen(null);
+    }
+  }
+
+  @Override
+  public void onMessage(WebSocket webSocket, String text) {
+    log.debug("Received text message: " + text);
+    try {
+      JsonObject messageObj = JsonParser.parseString(text).getAsJsonObject();
+      if (messageObj.has("header")) {
+        JsonObject header = messageObj.getAsJsonObject("header");
+        if (header.has("event")) {
+          String event = header.get("event").getAsString();
+
+          switch (event) {
+            case "task-started":
+              handleTaskStarted(messageObj);
+              break;
+            case "task-finished":
+              handleTaskFinished(messageObj);
+              break;
+            case "task-failed":
+              handleTaskFailed(messageObj);
+              break;
+            case "result-generated":
+              handleResultGenerated(messageObj);
+              break;
+            default:
+              log.warn("Unknown event: " + event);
+              break;
+          }
+        }
+      }
+    } catch (Exception e) {
+      log.error("Error processing text message: " + e.getMessage(), e);
+    }
+  }
+
+  @Override
+  public void onMessage(WebSocket webSocket, ByteBuffer bytes) {
+    // Capture the size first: copying the frame below drains bytes, leaving remaining() == 0.
+    int frameSize = bytes.remaining();
+    log.debug("Received binary message, size: {}", frameSize);
+    try {
+      if (firstPackageTimeStamp < 0) {
+        // First audio frame of the task; getFirstPackageDelay() is measured against this.
+        firstPackageTimeStamp = System.currentTimeMillis();
+      }
+      ByteBuffer audioFrame = ByteBuffer.allocate(frameSize);
+      audioFrame.put(bytes);
+      audioFrame.flip();
+
+      if (callback != null) {
+        SpeechSynthesisResult result = new SpeechSynthesisResult();
+        result.setAudioFrame(audioFrame);
+        callback.onEvent(result);
+      } else {
+        // No callback: buffer the frame for the blocking call() to return.
+        accumulateAudioData(audioFrame);
+      }
+
+      // Update received audio length
+      recvAudioLength += frameSize;
+
+    } catch (Exception e) {
+      log.error("Error processing binary message", e);
+      if (callback != null) {
+        callback.onError(e);
+      }
+    }
+  }
+
+  /**
+   * Appends the remaining bytes of {@code frame} to audioStream, creating the stream on first
+   * use.
+   */
+  private void accumulateAudioData(ByteBuffer frame) throws Exception {
+    if (audioStream == null) {
+      audioStream = new ByteArrayOutputStream();
+    }
+    byte[] buffer = new byte[frame.remaining()];
+    frame.get(buffer);
+    audioStream.write(buffer, 0, buffer.length);
+  }
+
+  @Override
+  public void onError(WebSocket webSocket, Throwable t) {
+    if (callback != null) {
+      // callback error first
+      callback.onError(new ApiException(t));
+    }
+
+    CountDownLatch startLatch = this.startLatch.get();
+    if (startLatch != null && startLatch.getCount() > 0) {
+      startLatch.countDown();
+    }
+
+    CountDownLatch stopLatch = this.stopLatch.get();
+    if (stopLatch != null && stopLatch.getCount() > 0) {
+      stopLatch.countDown();
+    }
+
+    if (audioStream != null) {
+      audioStream.reset();
+    }
+  }
+
+  @Override
+  public void onClose(int code, String reason) {
+    log.warn("WebSocket connection closed: " + reason + " (" + code + ")");
+  }
+
+  private void handleTaskStarted(JsonObject message) {
+    log.info("Task started");
+    state = SpeechSynthesisState.TTS_STARTED;
+    firstPackageTimeStamp = -1;
+    if (audioStream != null) {
+      audioStream.reset(); // drop audio left over from a previous task on this instance
+    }
+    if (startLatch.get() != null) {
+      startLatch.get().countDown();
+    }
+  }
+
+  private void handleTaskFinished(JsonObject message) {
+    log.info("Task finished");
+    // Reset for reuse before waking the waiter so its next streamingCall starts a new task.
+    isFirst = true;
+    // audioStream is deliberately not reset here: the thread released below reads it in call(),
+    // and a reset on this thread would race with that read. It is reset when a task starts.
+    if (stopLatch.get() != null) {
+      stopLatch.get().countDown();
+    }
+    if (callback != null) {
+      callback.onComplete();
+    }
+  }
+
+  private void handleTaskFailed(JsonObject message) {
+    log.error("Task failed: " + message.toString());
+    String errorMessage = "Unknown error";
+    if (message.has("header") && message.getAsJsonObject("header").has("error_message")) {
+      errorMessage = message.getAsJsonObject("header").get("error_message").getAsString();
+    }
+
+    // Create a Status object for the ApiException
+    com.alibaba.dashscope.common.Status status =
+        com.alibaba.dashscope.common.Status.builder()
+            .statusCode(-1)
+            .code("TASK_FAILED")
+            .message(errorMessage)
+            .build();
+    ApiException failure = new ApiException(status);
+    if (callback != null) {
+      callback.onError(failure);
+    } else {
+      // Nobody to notify: keep the error so that streamingComplete can rethrow it.
+      taskFailure = failure;
+    }
+
+    // Release the waiter with or without a callback; otherwise a blocking call() never returns.
+    if (stopLatch.get() != null) {
+      stopLatch.get().countDown();
+    }
+  }
+
+  private void handleResultGenerated(JsonObject message) {
+    log.debug("Result generated: " + message.toString());
+    if (callback == null) {
+      return;
+    }
+    SpeechSynthesisResult result = new SpeechSynthesisResult();
+    if (message.has("header")) {
+      JsonObject header = message.getAsJsonObject("header");
+      if (header.has("task_id")) {
+        preRequestId = header.get("task_id").getAsString();
+        result.setRequestId(preRequestId);
+      }
+    }
+    if (message.has("payload")) {
+      JsonObject payload = message.getAsJsonObject("payload");
+      if (payload != null && payload.has("output")) {
+        JsonObject output = payload.getAsJsonObject("output");
+        result.setOutput(output);
+        if (output != null && output.has("sentence")) {
+          result.setTimestamp(
+              JsonUtils.fromJsonObject(output.getAsJsonObject("sentence"), Sentence.class));
+        }
+      }
+      if (payload != null && payload.has("usage")) {
+        result.setUsage(
+            JsonUtils.fromJsonObject(payload.getAsJsonObject("usage"), SpeechSynthesisUsage.class));
+      }
+    }
+    callback.onEvent(result);
+  }
+
+  /** Returns the delay in ms between starting the stream and receiving the first audio frame. */
+  public long getFirstPackageDelay() {
+    return this.firstPackageTimeStamp - this.startStreamTimeStamp;
+  }
+
+  private void startStream(boolean enableSsml) throws NoApiKeyException, InterruptedException {
+    if (websocketRequest == null || !websocketRequest.isOpen()) {
+      // if websocket is not open, then connect
+      connect();
+    } else {
+      startStreamTimeStamp = System.currentTimeMillis();
+    }
+
+    checkConnectStatus(); // check websocket connection, if socket is closed.
+    taskFailure = null;
+    startLatch.set(new CountDownLatch(1));
+    startSynthesizer(enableSsml);
+    boolean startResult = startLatch.get().await(startedTimeout, TimeUnit.MILLISECONDS);
+    if (!startResult) {
+      throw new RuntimeException(
+          "TimeoutError: waiting for task started more than " + startedTimeout + " ms.");
+    }
+  }
+
+  private void submitText(String text) {
+    if (text == null || text.isEmpty()) {
+      throw new ApiException(
+          new InputRequiredException("Parameter invalid: text is null or empty"));
+    }
+    synchronized (this) {
+      if (state != SpeechSynthesisState.TTS_STARTED) {
+        throw new ApiException(
+            new InputRequiredException(
+                "State invalid: expect stream input tts state is started but " + state.getValue()));
+      }
+      sendText(text);
+    }
+  }
+
+  private void startStream() throws NoApiKeyException, InterruptedException {
+    startStream(false);
+  }
+
+  /**
+   * Sends finish-task and blocks until the task finishes, fails, or the timeout elapses.
+   *
+   * @param completeTimeoutMillis maximum wait in milliseconds; zero or negative waits indefinitely
+   * @throws ApiException if the task is not started, or if it failed and no callback is set
+   * @throws RuntimeException if the timeout elapses first
+   */
+  public void streamingComplete(long completeTimeoutMillis) {
+    log.debug("streamingComplete with timeout: " + completeTimeoutMillis);
+    synchronized (this) {
+      if (state != SpeechSynthesisState.TTS_STARTED) {
+        throw new ApiException(
+            new RuntimeException(
+                "State invalid: expect stream input tts state is started but " + state.getValue()));
+      }
+    }
+    stopLatch.set(new CountDownLatch(1));
+    stopSynthesizer();
+
+    if (stopLatch.get() != null) {
+      try {
+        if (completeTimeoutMillis > 0) {
+          log.debug("start waiting for stopLatch");
+          if (!stopLatch.get().await(completeTimeoutMillis, TimeUnit.MILLISECONDS)) {
+            throw new RuntimeException("TimeoutError: waiting for streaming complete");
+          }
+        } else {
+          log.debug("start waiting for stopLatch");
+          stopLatch.get().await();
+        }
+        log.debug("stopLatch is done");
+      } catch (InterruptedException ignored) {
+        Thread.currentThread().interrupt(); // restore the interrupt flag for the caller
+        log.error("Interrupted while waiting for streaming complete");
+      }
+    }
+    ApiException failure = taskFailure;
+    if (failure != null) {
+      throw failure;
+    }
+  }
+
+  /** Same as {@link #streamingComplete(long)} with a timeout of 60 seconds. */
+  public void streamingComplete() {
+    streamingComplete(DEFAULT_COMPLETE_TIMEOUT);
+  }
+
+  /** Sends finish-task without waiting; throws {@link ApiException} if no task is started. */
+  public void asyncStreamingComplete() {
+    synchronized (this) {
+      if (state != SpeechSynthesisState.TTS_STARTED) {
+        throw new ApiException(
+            new RuntimeException(
+                "State invalid: expect stream input tts state is started but " + state.getValue()));
+      }
+      stopSynthesizer();
+    }
+  }
+
+  /** Marks the task canceled and sends finish-task if a task is started; otherwise a no-op. */
+  public void streamingCancel() {
+    canceled.set(true);
+    synchronized (this) {
+      if (state != SpeechSynthesisState.TTS_STARTED) {
+        return;
+      }
+      stopSynthesizer();
+    }
+  }
+
+  /** Sends {@code text} to the task, starting the task (without SSML) on the first call. */
+  public void streamingCall(String text) {
+    if (isFirst) {
+      isFirst = false;
+      try {
+        this.startStream(false);
+        this.submitText(text);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt(); // restore the interrupt flag
+        log.error("Interrupted while waiting for streaming complete", e);
+        throw new ApiException(e);
+      } catch (NoApiKeyException e) {
+        throw new ApiException(e);
+      }
+    } else {
+      this.submitText(text);
+    }
+  }
+
+  /**
+   * Synthesizes {@code text} as one task with SSML enabled.
+   *
+   * @param text text to synthesize; must not be null or empty
+   * @param timeoutMillis maximum wait in milliseconds; zero or negative waits indefinitely
+   * @return the buffered audio, or null when a callback is set (audio goes to the callback)
+   */
+  public ByteBuffer call(String text, long timeoutMillis) throws RuntimeException {
+    try {
+      this.startStream(true);
+      this.submitText(text);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt(); // restore the interrupt flag
+      log.error("Interrupted while waiting for streaming complete", e);
+      throw new ApiException(e);
+    } catch (NoApiKeyException e) {
+      throw new ApiException(e);
+    }
+    if (this.asyncCall) {
+      this.asyncStreamingComplete();
+      return null;
+    } else {
+      this.streamingComplete(timeoutMillis);
+      return ByteBuffer.wrap(audioStream.toByteArray());
+    }
+  }
+
+  /** Same as {@link #call(String, long)} with no completion timeout. */
+  public ByteBuffer call(String text) {
+    return call(text, 0);
+  }
+}
diff --git a/src/test/java/com/alibaba/dashscope/TestTtsV2SpeechSynthesizerV2.java b/src/test/java/com/alibaba/dashscope/TestTtsV2SpeechSynthesizerV2.java
new file mode 100644
index 0000000..7e677a4
--- /dev/null
+++ b/src/test/java/com/alibaba/dashscope/TestTtsV2SpeechSynthesizerV2.java
@@ -0,0 +1,161 @@
+// Copyright (c) Alibaba, Inc. and its affiliates.
+
+package com.alibaba.dashscope;
+
+import static org.junit.Assert.assertEquals;
+
+import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
+import com.alibaba.dashscope.audio.ttsv2.SpeechSynthesisAudioFormat;
+import com.alibaba.dashscope.audio.ttsv2.SpeechSynthesisParam;
+import com.alibaba.dashscope.audio.ttsv2.SpeechSynthesizerV2;
+import com.alibaba.dashscope.common.ResultCallback;
+import com.alibaba.dashscope.utils.Constants;
+import com.alibaba.dashscope.utils.JsonUtils;
+import com.google.gson.JsonObject;
+import java.io.IOException;
+import java.util.ArrayList;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.Response;
+import okhttp3.WebSocket;
+import okhttp3.WebSocketListener;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okio.ByteString;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.parallel.Execution;
+import org.junit.jupiter.api.parallel.ExecutionMode;
+
+/** Exercises SpeechSynthesizerV2 streaming against a scripted MockWebServer websocket. */
+@Execution(ExecutionMode.SAME_THREAD)
+@Slf4j
+public class TestTtsV2SpeechSynthesizerV2 {
+  private static ArrayList<Byte> audioBuffer;
+  private static ResultCallback<SpeechSynthesisResult> callback =
+      new ResultCallback<SpeechSynthesisResult>() {
+        @Override
+        public void onEvent(SpeechSynthesisResult message) {
+          System.out.println("onEvent:" + message);
+          if (message.getAudioFrame() != null) {
+            for (byte b : message.getAudioFrame().array()) {
+              audioBuffer.add(b);
+            }
+          }
+        }
+
+        @Override
+        public void onComplete() {
+          // System.out.println("onComplete");
+        }
+
+        @Override
+        public void onError(Exception e) {}
+      };
+  private static MockWebServer mockServer;
+
+  @BeforeAll
+  public static void before() throws IOException {
+    audioBuffer = new ArrayList<>();
+    mockServer = new MockWebServer();
+    mockServer.start();
+    MockResponse response =
+        new MockResponse()
+            .withWebSocketUpgrade(
+                new WebSocketListener() {
+                  String task_id = "";
+
+                  @Override
+                  public void onOpen(WebSocket webSocket, Response response) {
+                    System.out.println("Mock Server onOpen");
+                    System.out.println(
+                        "Mock Server request header:" + response.request().headers());
+                    System.out.println("Mock Server response header:" + response.headers());
+                    System.out.println("Mock Server response:" + response);
+                  }
+
+                  @Override
+                  public void onMessage(WebSocket webSocket, String string) {
+                    System.out.println("mock server recv: " + string);
+                    JsonObject req = JsonUtils.parse(string);
+                    if (task_id.isEmpty()) {
+                      task_id = req.get("header").getAsJsonObject().get("task_id").getAsString();
+                    }
+                    if (string.contains("run-task")) {
+                      try {
+                        Thread.sleep(100);
+                      } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                      }
+                      webSocket.send(
+                          "{'header': {'task_id': '"
+                              + task_id
+                              + "', 'event': 'task-started', 'attributes': {}}, 'payload': {}}");
+                    } else if (string.contains("finish-task")) {
+                      try {
+                        Thread.sleep(100);
+                      } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                      }
+                      webSocket.send(
+                          "{'header': {'task_id': '"
+                              + task_id
+                              + "', 'event': 'task-finished', 'attributes': {}}, 'payload': {'output': None, 'usage': {'characters': 7}}}");
+                      webSocket.close(1000, "close by server");
+                    } else if (string.contains("continue-task")) {
+                      try {
+                        Thread.sleep(100);
+                      } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                      }
+                      byte[] binary = new byte[] {0x01, 0x01, 0x01};
+                      webSocket.send(new ByteString(binary));
+                    }
+                  }
+                });
+    mockServer.enqueue(response);
+  }
+
+  @AfterAll
+  public static void after() throws IOException {
+    System.out.println("Mock Server is closed");
+    mockServer.close();
+  }
+
+  @Test
+  public void testStreamingCall() {
+    System.out.println("############ Start Test Streaming Call ############");
+    int port = mockServer.getPort();
+    Constants.baseWebsocketApiUrl = String.format("http://127.0.0.1:%s", port);
+
+    // Get the URL
+    String url = mockServer.url("/binary").toString();
+
+    // In a real-world scenario, an HTTP request would be made here and a response received
+    System.out.println("Mock Server is running at: " + url);
+    SpeechSynthesisParam param =
+        SpeechSynthesisParam.builder()
+            .apiKey("1234")
+            .model("cosyvoice-v1")
+            .voice("longxiaochun")
+            .format(SpeechSynthesisAudioFormat.MP3_16000HZ_MONO_128KBPS)
+            .build();
+    SpeechSynthesizerV2 synthesizer = new SpeechSynthesizerV2(param, callback);
+    synthesizer.setStartedTimeout(1000);
+    synthesizer.setFirstAudioTimeout(2000);
+    for (int i = 0; i < 3; i++) {
+      synthesizer.streamingCall("今天天气怎么样?");
+    }
+    try {
+      synthesizer.streamingComplete();
+      synthesizer.close();
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+    assertEquals(9, audioBuffer.size());
+    for (int i = 0; i < 9; i++) {
+      assertEquals((byte) 0x01, (byte) audioBuffer.get(i));
+    }
+    System.out.println("############ Start Test Streaming Call Done ############");
+  }
+}