Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
# Scala Steward: Reformat with scalafmt 3.7.12
333780ab381e31fd80351610657a52a052e96447

# Scala Steward: Reformat with scalafmt 3.8.6
76d0711b5669a39d7800bfa02df199317062de29
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = 3.8.3
version = 3.8.6
runner.dialect = scala3

align.preset = more
Expand Down
29 changes: 12 additions & 17 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,20 @@ lazy val commonSettings = Seq(
),
libraryDependencies ++= Seq(
"org.scalameta" %% "munit" % munitVersion % Test,
"org.scalameta" %% "munit-scalacheck" % "1.0.0" % Test // TODO: Align with munitVersion, once released
"org.scalameta" %% "munit-scalacheck" % "1.0.0" % Test // TODO: Align with munitVersion, once released
)
)

lazy val syncodia = project
.in(file("core"))
.settings(commonSettings)
.settings(
name := "syncodia",
libraryDependencies ++= Seq(
"org.apache.pekko" %% "pekko-http" % pekkoHttpVersion,
"org.apache.pekko" %% "pekko-stream" % pekkoVersion,
"com.lihaoyi" %% "upickle" % uSerializationVersion,
"com.lihaoyi" %% "ujson" % uSerializationVersion,
"org.scala-lang" % "scala-reflect" % scalaReflectVersion,
"com.knuddels" % "jtokkit" % jTokkitVersion
)
lazy val syncodia = project.in(file("core")).settings(commonSettings).settings(
name := "syncodia",
libraryDependencies ++= Seq(
"org.apache.pekko" %% "pekko-http" % pekkoHttpVersion,
"org.apache.pekko" %% "pekko-stream" % pekkoVersion,
"com.lihaoyi" %% "upickle" % uSerializationVersion,
"com.lihaoyi" %% "ujson" % uSerializationVersion,
"org.scala-lang" % "scala-reflect" % scalaReflectVersion,
"com.knuddels" % "jtokkit" % jTokkitVersion
)
)

lazy val examples = project
.settings(commonSettings)
.dependsOn(syncodia)
lazy val examples = project.settings(commonSettings).dependsOn(syncodia)
10 changes: 3 additions & 7 deletions core/src/main/scala/syncodia/ApiException.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ import java.io.IOException
import scala.concurrent.duration.DurationInt
import scala.concurrent.duration.FiniteDuration

enum ApiException(
val statusCode: Int,
val message: String,
val maybeRetryAfter: Option[FiniteDuration]
) extends IOException(message):
enum ApiException(val statusCode: Int, val message: String, val maybeRetryAfter: Option[FiniteDuration])
extends IOException(message):

case InvalidAuthenticationException(override val message: String) extends ApiException(401, message, None)

Expand All @@ -34,8 +31,7 @@ enum ApiException(

case RateLimitException(override val message: String) extends ApiException(429, message, Some(1000.millis))

case QuotaExceededException(override val message: String)
extends ApiException(429, message, Some(1000.millis))
case QuotaExceededException(override val message: String) extends ApiException(429, message, Some(1000.millis))

case ServerErrorException(override val message: String) extends ApiException(500, message, Some(100.millis))

Expand Down
14 changes: 4 additions & 10 deletions core/src/main/scala/syncodia/ChatFunction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,15 @@ object ChatFunction:

inline def apply[F](inline f: F, description: String = ""): ChatFunction =
val fs: FunctionSchema =
if description.nonEmpty then functionSchema(f).copy(maybeDescription = Some(description))
else functionSchema(f)
if description.nonEmpty then functionSchema(f).copy(maybeDescription = Some(description)) else functionSchema(f)

def invokeWithParams(parametersJsonString: String): (Any, String) = Try {
val json = ujson.read(parametersJsonString)
fs.readFromJson(json).asInstanceOf[Seq[Any]]
} match
case Failure(exception) =>
exception ->
s"Parameter parsing failed because ${exception.getMessage}"
case Success(parameterValues) =>
Try(Reflection.invoke(f, fs.name, parameterValues*)) match
case Failure(exception) =>
exception ->
fs.prettyFailure(parametersJsonString, exception.getMessage)
case Failure(exception) => exception -> s"Parameter parsing failed because ${exception.getMessage}"
case Success(parameterValues) => Try(Reflection.invoke(f, fs.name, parameterValues*)) match
case Failure(exception) => exception -> fs.prettyFailure(parametersJsonString, exception.getMessage)
case Success(resultValue) =>
val resultAsJson = fs.writeToJson(resultValue)
val prettyResult = fs.prettySuccess(parametersJsonString, resultAsJson.render())
Expand Down
107 changes: 44 additions & 63 deletions core/src/main/scala/syncodia/Syncodia.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

package syncodia

import com.typesafe.config.{ Config, ConfigFactory }
import com.typesafe.config.{Config, ConfigFactory}
import org.apache.pekko.NotUsed
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.http.scaladsl.*
import org.apache.pekko.http.scaladsl.model.*
import org.apache.pekko.http.scaladsl.model.ContentTypes.`application/json`
import org.apache.pekko.http.scaladsl.model.HttpMethods.POST
import org.apache.pekko.http.scaladsl.model.headers.{ Authorization, OAuth2BearerToken }
import org.apache.pekko.http.scaladsl.model.headers.{Authorization, OAuth2BearerToken}
import org.apache.pekko.http.scaladsl.model.sse.ServerSentEvent
import org.apache.pekko.http.scaladsl.unmarshalling.Unmarshal
import org.apache.pekko.http.scaladsl.unmarshalling.sse.EventStreamUnmarshalling.*
Expand All @@ -36,14 +36,13 @@ import syncodia.openai.protocol.ChatCompletionModel.GPT_35_TURBO
import syncodia.openai.protocol.SerializeJson.*
import ujson.Value.Value

import scala.concurrent.{ ExecutionContext, Future }
import scala.concurrent.duration.{ Duration, DurationInt, FiniteDuration }
import scala.util.{ Failure, Success, Try }
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{Duration, DurationInt, FiniteDuration}
import scala.util.{Failure, Success, Try}

implicit given string2Message: Conversion[String, Message] = (s: String) => Message(Role.User, s)

implicit given string2Messages: Conversion[String, Seq[Message]] =
(s: String) => Seq(Message(Role.User, s))
implicit given string2Messages: Conversion[String, Seq[Message]] = (s: String) => Seq(Message(Role.User, s))

implicit given message2messages: Conversion[Message, Seq[Message]] = (msg: Message) => Seq(msg)

Expand All @@ -60,8 +59,7 @@ object Syncodia:

val maxBackoffDelay: FiniteDuration = 10.seconds

val defaultPekkoConfig: Config = ConfigFactory
.parseString("""pekko.loglevel = "ERROR"""")
val defaultPekkoConfig: Config = ConfigFactory.parseString("""pekko.loglevel = "ERROR"""")
.withFallback(ConfigFactory.defaultApplication())

def apply(): Syncodia =
Expand All @@ -72,15 +70,13 @@ object Syncodia:

def apply(openAiApiKey: String): Syncodia = new Syncodia(openAiApiKey, None)

def apply(openAiApiKey: String, actorSystem: ActorSystem): Syncodia =
new Syncodia(openAiApiKey, Some(actorSystem))
def apply(openAiApiKey: String, actorSystem: ActorSystem): Syncodia = new Syncodia(openAiApiKey, Some(actorSystem))

end Syncodia

class Syncodia(openAiApiKey: String, maybeProvidedActorSystem: Option[ActorSystem]):

implicit val actorSystem: ActorSystem = maybeProvidedActorSystem
.getOrElse(ActorSystem("default", defaultPekkoConfig))
implicit val actorSystem: ActorSystem = maybeProvidedActorSystem.getOrElse(ActorSystem("default", defaultPekkoConfig))

implicit val executionContext: ExecutionContext = actorSystem.getDispatcher

Expand All @@ -105,38 +101,33 @@ class Syncodia(openAiApiKey: String, maybeProvidedActorSystem: Option[ActorSyste
val tryParse = Try(SerializeJson.read[ChatCompletionResponse](responseString))
tryParse match
case Success(chatCompletionResponse) => chatCompletionResponse
case Failure(e) => throw ParseException(s"Failed to parse response: $responseString", e)
case Failure(e) => throw ParseException(s"Failed to parse response: $responseString", e)
}
}

end executeChatCompletionRequest

private[syncodia] def runApiRequest(r: HttpRequest): Future[HttpResponse] = Http()
.singleRequest(r)
private[syncodia] def runApiRequest(r: HttpRequest): Future[HttpResponse] = Http().singleRequest(r)
.flatMap { response =>
response.status.intValue() match
case code if code >= 200 && code < 300 => Future.successful(response)
case errorCode =>
Unmarshal(response.entity).to[String].flatMap { responseBody =>
case errorCode => Unmarshal(response.entity).to[String].flatMap { responseBody =>
errorCode match
case 401 =>
if responseBody.contains("Invalid Authentication") then
Future.failed(ApiException.InvalidAuthenticationException(responseBody))
else if responseBody.contains("Incorrect API key provided") then
Future.failed(ApiException.IncorrectApiKeyException(responseBody))
else if responseBody
.contains("You must be a member of an organization to use the API")
then Future.failed(ApiException.NoMembershipException(responseBody))
else if responseBody.contains("You must be a member of an organization to use the API") then
Future.failed(ApiException.NoMembershipException(responseBody))
else Future.failed(ApiException.UnhandledException(401, responseBody))
case 429 if responseBody.contains("Rate limit reached") =>
Future.failed(ApiException.RateLimitException(responseBody))
case 429 if responseBody.contains("exceeded your current quota") =>
Future.failed(ApiException.QuotaExceededException(responseBody))
case 500 => Future.failed(ApiException.ServerErrorException(responseBody))
case 503 => Future.failed(ApiException.OverloadedException(responseBody))
case unhandledCode =>
Future
.failed(ApiException.UnhandledException(unhandledCode, responseBody))
case 500 => Future.failed(ApiException.ServerErrorException(responseBody))
case 503 => Future.failed(ApiException.OverloadedException(responseBody))
case unhandledCode => Future.failed(ApiException.UnhandledException(unhandledCode, responseBody))
}
}

Expand Down Expand Up @@ -203,8 +194,7 @@ class Syncodia(openAiApiKey: String, maybeProvidedActorSystem: Option[ActorSyste
case Some(f) => Map(f.name -> f)
case None => Map.empty
case seq: Seq[ChatFunction @unchecked] => seq.map(f => f.name -> f).toMap
val responseFuture =
complete(messages, model, functions, maxTokens, temperature, maxApiRetryAttempts)
val responseFuture = complete(messages, model, functions, maxTokens, temperature, maxApiRetryAttempts)
responseFuture.flatMap { response =>
val maybeMessage = response.choices.headOption.map(_.message)
if printMessages then maybeMessage.foreach(m => println(m.pretty))
Expand All @@ -215,8 +205,7 @@ class Syncodia(openAiApiKey: String, maybeProvidedActorSystem: Option[ActorSyste
val (result, isSuccess) = functionsByName.get(functionName) match
case None => s"No function with name $functionName found" -> false
case Some(chatFunction) =>
val (resultValue, resultString) = chatFunction
.invokeWithParams(functionCall.arguments)
val (resultValue, resultString) = chatFunction.invokeWithParams(functionCall.arguments)
resultString -> !resultValue.isInstanceOf[Throwable]
val resultMessage = Message(Role.Function, result, Some(functionName))
if printMessages then println(resultMessage.pretty)
Expand Down Expand Up @@ -267,20 +256,18 @@ class Syncodia(openAiApiKey: String, maybeProvidedActorSystem: Option[ActorSyste

def recExecute(recMessages: Seq[Message]): Future[ChatCompletionResponse] =
val updatedMaxTokens =
if maxTokens == -1 then -1
else maxTokens - recMessages.drop(messages.length).map(_.tokenCount).sum
val responseFuture =
execute(
recMessages,
model,
functions,
reportFunctionResult = true,
updatedMaxTokens,
temperature,
maxApiRetryAttempts,
maxFunctionCallRetryAttempts,
printMessages
)
if maxTokens == -1 then -1 else maxTokens - recMessages.drop(messages.length).map(_.tokenCount).sum
val responseFuture = execute(
recMessages,
model,
functions,
reportFunctionResult = true,
updatedMaxTokens,
temperature,
maxApiRetryAttempts,
maxFunctionCallRetryAttempts,
printMessages
)
responseFuture.flatMap { response =>
response.choices.headOption match
case Some(choice) if choice.finishReason == "function_call" =>
Expand All @@ -304,27 +291,21 @@ class Syncodia(openAiApiKey: String, maybeProvidedActorSystem: Option[ActorSyste
val body = SerializeJson.write(ccr)
val request = chatCompletionRequestTemplate.withEntity(HttpEntity(`application/json`, body))
val sseSourceFuture: Future[Source[ServerSentEvent, NotUsed]] =
retryWithExponentialBackoff(() => runApiRequest(request), maxRetryAttempts).flatMap { response =>
Unmarshal(response.entity).to[Source[ServerSentEvent, NotUsed]]
}
retryWithExponentialBackoff(() => runApiRequest(request), maxRetryAttempts)
.flatMap(response => Unmarshal(response.entity).to[Source[ServerSentEvent, NotUsed]])
val sseSource: Source[ServerSentEvent, Future[NotUsed]] = Source.futureSource(sseSourceFuture)

sseSource
.takeWhile(sse => sse.data != "[DONE]", inclusive = false)
.map { sse =>
val tryJson = Try(read[ChatCompletionDeltaResponse](sse.data))
tryJson match
case Failure(exception) =>
throw new Exception(
s"""Error when parsing
|${sse.data}
|as a ChatCompletionDeltaResponse: '${exception.getMessage}'""".stripMargin,
exception
)
case Success(parsed) => parsed
}
.asSourceWithContext(identity)
.map(parsed => parsed.completion)
sseSource.takeWhile(sse => sse.data != "[DONE]", inclusive = false).map { sse =>
val tryJson = Try(read[ChatCompletionDeltaResponse](sse.data))
tryJson match
case Failure(exception) => throw new Exception(
s"""Error when parsing
|${sse.data}
|as a ChatCompletionDeltaResponse: '${exception.getMessage}'""".stripMargin,
exception
)
case Success(parsed) => parsed
}.asSourceWithContext(identity).map(parsed => parsed.completion)

end runStreamingChatCompletionRequest

Expand Down
50 changes: 19 additions & 31 deletions core/src/main/scala/syncodia/macros/ExtractSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ object ExtractSchema:
case Block(List(e: Term), _) => extractParamNames(e)
case Inlined(_, _, e) => extractParamNames(e)
case Apply(_, params: List[Ident @unchecked]) => params.map(_.name)
case _ =>
report
.errorAndAbort(s"No paramSchemas found: ${t.show(using Printer.TreeStructure)}", f)
case _ => report.errorAndAbort(s"No paramSchemas found: ${t.show(using Printer.TreeStructure)}", f)
end extractParamNames

val tree = f.asTerm
Expand All @@ -61,12 +59,10 @@ object ExtractSchema:
val allTypeArgs = repr.typeArgs
val paramTypes = allTypeArgs.init
val returnType = allTypeArgs.last
val paramSchemas = paramNames
.zip(paramTypes)
.foldLeft(List[(String, Schema)]()) { case (paramAcc, (name, tpe)) =>
val fieldSchema = extractSchema(tpe)
paramAcc :+ name -> fieldSchema
}
val paramSchemas = paramNames.zip(paramTypes).foldLeft(List[(String, Schema)]()) { case (paramAcc, (name, tpe)) =>
val fieldSchema = extractSchema(tpe)
paramAcc :+ name -> fieldSchema
}
val returnTypeSchema = extractSchema(returnType)

Expr(FunctionSchema(fnName, None, paramSchemas, returnTypeSchema))
Expand Down Expand Up @@ -112,27 +108,20 @@ object ExtractSchema:
case _ if isEnum && ts.flags.is(Flags.Abstract) => // Enum
val children = dealiasedTpe.typeSymbol.children
val childTrees = children.map(_.tree)
val alternatives: Map[String, Option[Schema]] = children
.zip(childTrees)
.collect {
case (c, tpd: Typed) => c.name -> Some(extractSchema(tpd.tpt.tpe))
case (c, vd: ValDef) =>
val childTpe = vd.tpt.tpe
if childTpe == tpe then c.name -> None
else
val childSchema = extractSchema(childTpe)
c.name -> Some(childSchema)
case (c, cd: ClassDef) => c.name -> Some(extractSchema(cd.constructor.returnTpt.tpe))
case (_, other) =>
report.errorAndAbort(
s"Unsupported schema extraction for enum $stringRepresentationOfType:\n$other"
)
}
.toMap
val alternatives: Map[String, Option[Schema]] = children.zip(childTrees).collect {
case (c, tpd: Typed) => c.name -> Some(extractSchema(tpd.tpt.tpe))
case (c, vd: ValDef) =>
val childTpe = vd.tpt.tpe
if childTpe == tpe then c.name -> None
else
val childSchema = extractSchema(childTpe)
c.name -> Some(childSchema)
case (c, cd: ClassDef) => c.name -> Some(extractSchema(cd.constructor.returnTpt.tpe))
case (_, other) => report
.errorAndAbort(s"Unsupported schema extraction for enum $stringRepresentationOfType:\n$other")
}.toMap
SumSchema(className, alternatives)
case _ =>
report
.errorAndAbort(s"Unsupported schema extraction for $stringRepresentationOfType.")
case _ => report.errorAndAbort(s"Unsupported schema extraction for $stringRepresentationOfType.")

end extractSchema

Expand Down Expand Up @@ -193,8 +182,7 @@ object ExtractSchema:
'{ SequenceSchema(${ Expr(s.className) }, ${ Expr(s.elementSchema) }) }

given ToExpr[OptionSchema] with
def apply(s: OptionSchema)(using Quotes): Expr[OptionSchema] =
'{ OptionSchema(${ Expr(s.element) }) }
def apply(s: OptionSchema)(using Quotes): Expr[OptionSchema] = '{ OptionSchema(${ Expr(s.element) }) }

given ToExpr[MapSchema] with
def apply(s: MapSchema)(using Quotes): Expr[MapSchema] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ case class ChatCompletionResponse(
usage: Usage
) derives SerializeJson.ReadWriter:

def maybeContent: Option[String] = choices.headOption.flatMap { choice =>
Option(choice.message.content)
}
def maybeContent: Option[String] = choices.headOption.flatMap(choice => Option(choice.message.content))

def content: String = maybeContent.getOrElse("")
Loading
Loading