Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.alibaba.dashscope.audio.protocol;

import java.nio.ByteBuffer;
import okhttp3.WebSocket;

/** @author songsong.shao */
public interface AudioWebsocketCallback {

void onOpen();

void onMessage(WebSocket webSocket, String text);

void onMessage(WebSocket webSocket, ByteBuffer buffer);

void onError(WebSocket webSocket, Throwable t);

void onClose(int code, String reason);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package com.alibaba.dashscope.audio.protocol;

import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.protocol.DashScopeHeaders;
import com.alibaba.dashscope.protocol.okhttp.OkHttpClientFactory;
import com.alibaba.dashscope.utils.ApiKey;
import com.alibaba.dashscope.utils.Constants;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okio.ByteString;

/** @author songsong.shao */
@Slf4j
public class AudioWebsocketRequest extends WebSocketListener {

private OkHttpClient client;
private WebSocket websocktetClient;
private AtomicBoolean isOpen = new AtomicBoolean(false);
private AtomicReference<CountDownLatch> connectLatch = new AtomicReference<>(null);
private AtomicBoolean isClosed = new AtomicBoolean(false);
private AudioWebsocketCallback callback;
private Integer connectTimeout = 5000;

public boolean isOpen() {
return isOpen.get();
}

public boolean isClosed() {
return isClosed.get();
}

public void checkStatus() {
if (this.isClosed.get()) {
throw new RuntimeException("Websocket is already closed!");
}
}

public void connect(
String apiKey,
String workspace,
Map<String, String> customHeaders,
String baseWebSocketUrl,
AudioWebsocketCallback callback)
throws NoApiKeyException, InterruptedException, RuntimeException {
Request request =
buildConnectionRequest(
ApiKey.getApiKey(apiKey), false, workspace, customHeaders, baseWebSocketUrl);
this.callback = callback;
client = OkHttpClientFactory.getOkHttpClient();
websocktetClient = client.newWebSocket(request, this);
connectLatch.set(new CountDownLatch(1));
boolean result = connectLatch.get().await(connectTimeout, TimeUnit.MILLISECONDS);
if (!result) {
throw new RuntimeException(
"TimeoutError: waiting for websocket connect more than" + connectTimeout + " ms.");
}
}

private Request buildConnectionRequest(
String apiKey,
boolean isSecurityCheck,
String workspace,
Map<String, String> customHeaders,
String baseWebSocketUrl)
throws NoApiKeyException {
// build the request builder.
Request.Builder bd = new Request.Builder();
bd.headers(
Headers.of(
DashScopeHeaders.buildWebSocketHeaders(
apiKey, isSecurityCheck, workspace, customHeaders)));
String url = Constants.baseWebsocketApiUrl;
if (baseWebSocketUrl != null) {
url = baseWebSocketUrl;
}
Request request = bd.url(url).build();
return request;
}

private void sendMessage(String message, boolean enableLog) {
checkStatus();
if (enableLog) {
log.debug("send message: " + message);
}
if (!websocktetClient.send(message)) {
log.warn("Failed to enqueue websocket text message for sending.");
}
}

public void close() {
this.close(1000, "bye");
}

public void close(int code, String reason) {
checkStatus();
websocktetClient.close(code, reason);
isClosed.set(true);
}

public void sendTextMessage(String message) {
checkStatus();
this.sendMessage(message, true);
}

public void sendBinaryMessage(ByteString rawData) {
checkStatus();
if (!websocktetClient.send(rawData)) {
log.warn("Failed to enqueue websocket binary message for sending.");
}
}

@Override
public void onOpen(WebSocket webSocket, Response response) {
isOpen.set(true);
connectLatch.get().countDown();
log.debug("WebSocket opened");
callback.onOpen();
}

@Override
public void onMessage(WebSocket webSocket, String text) {
callback.onMessage(webSocket, text);
}

@Override
public void onMessage(WebSocket webSocket, ByteString bytes) {
log.debug("Received binary message");
callback.onMessage(webSocket, bytes.asByteBuffer());
}

@Override
public void onClosed(WebSocket webSocket, int code, String reason) {
isOpen.set(false);
isClosed.set(true);
connectLatch.get().countDown();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The call to connectLatch.get() could result in a NullPointerException if onClosed is called before connect() has initialized the latch. It's safer to add a null check, similar to what's done in onFailure.

Suggested change
connectLatch.get().countDown();
if (connectLatch.get() != null) {
connectLatch.get().countDown();
}

log.debug("WebSocket closed");
callback.onClose(code, reason);
}

@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
log.error("WebSocket failed: " + t.getMessage());
if (connectLatch.get() != null) {
connectLatch.get().countDown();
}
if (callback != null) {
callback.onError(webSocket, t);
} else {
throw new RuntimeException(t);
}
}
}
Loading