diff --git a/docs/format.md b/docs/format.md index 2402406..b405bd3 100644 --- a/docs/format.md +++ b/docs/format.md @@ -107,8 +107,12 @@ convo = Conversation.from_messages( tokens = encoding.render_conversation_for_completion(convo, Role.ASSISTANT) # After receiving a token response -# Do not pass in the stop token -parsed_response = encoding.parse_messages_from_completion_tokens(new_tokens, Role.ASSISTANT) +# Do not pass in the stop token. strict=True (the default) rejects malformed headers; set strict=False to tolerate them. +parsed_response = encoding.parse_messages_from_completion_tokens( + new_tokens, + Role.ASSISTANT, + strict=True, +) ``` Additionally the openai_harmony library also includes a StreamableParser for parsing and decoding as the model is generating new tokens. This can be helpful for example to stream output and handle unicode characters during decoding. @@ -269,7 +273,7 @@ If you are not using function tool calling your developer message would just loo Where `{instructions}` is replaced with your “system prompt”. -For defining function calling tools, [check out the dedicated section](#function-calling). +For defining function calling tools, [check out the dedicated section](#function-calling). For defining an output format to be used in structured outputs, [check out this section of the guide](#structured-output). ### Reasoning @@ -301,7 +305,7 @@ And the actual answer is: 2 + 2 = 4 ``` -**Important:** +**Important:** The model has not been trained to the same safety standards in the chain-of-thought as it has for final output. We recommend not to show the chain-of-thought to your users as they might contain harmful content. [Learn more in the model card](https://openai.com/index/gpt-oss-model-card/). 
#### Handling reasoning output in subsequent sampling diff --git a/docs/python.md b/docs/python.md index 22225d9..9b7917d 100644 --- a/docs/python.md +++ b/docs/python.md @@ -107,12 +107,14 @@ Methods: - `render_conversation_for_training(conversation, config=None)` – render a conversation for training. - `render_conversation(conversation, config=None)` – render a conversation without appending a new role. - `render(message)` – render a single message into tokens. -- `parse_messages_from_completion_tokens(tokens, role=None)` – parse tokens back into `Message` objects. +- `parse_messages_from_completion_tokens(tokens, role=None, strict=True)` – parse tokens back into `Message` objects (set `strict=False` to enable permissive parsing). - `decode_utf8(tokens)` – decode tokens with the underlying tokenizer. - `stop_tokens()` / `stop_tokens_for_assistant_actions()` – lists of stop tokens. +Use `strict=False` when you need the parser to recover from malformed model output that omits markers such as `<|message|>`. + ### `StreamableParser` -Incremental parser built on top of an encoding. Construct with `StreamableParser(encoding, role)` and feed tokens via `process(token)`. Inspect state via properties like `current_content`, `current_role`, `tokens` and `state`. +Incremental parser built on top of an encoding. Construct with `StreamableParser(encoding, role)` and feed tokens via `process(token)`. Inspect state via properties like `current_content`, `current_role`, `tokens` and `state`. Pass `strict=False` to enable permissive parsing (mirrors `ParseOptions { strict: false }` on the Rust side). ### `load_harmony_encoding(name)` Return a `HarmonyEncoding` by name. Accepts either the string name or a value from the `HarmonyEncodingName` enum (`HARMONY_GPT_OSS`). 
diff --git a/docs/rust.md b/docs/rust.md index 9a6f2ee..d4084f2 100644 --- a/docs/rust.md +++ b/docs/rust.md @@ -88,12 +88,15 @@ Important methods: - `render_conversation_for_training(conversation, config)` – render a conversation for training data. - `render_conversation(conversation, config)` – render a conversation without appending a new role. - `render(message)` – render a single message into tokens. -- `parse_messages_from_completion_tokens(tokens, role)` – parse a list of tokens back into messages. +- `parse_messages_from_completion_tokens(tokens, role)` – parse a list of tokens back into messages using strict validation. +- `parse_messages_from_completion_tokens_with_options(tokens, role, options)` – parse tokens with custom `ParseOptions` (e.g. to disable strict validation). - `stop_tokens()` and `stop_tokens_for_assistant_actions()` – sets of stop tokens for sampling. +`ParseOptions` currently exposes a single field, `strict`, which defaults to `true`. Set it to `false` when you need to recover from malformed model output in downstream systems. + ### `StreamableParser` -Incremental parser that consumes tokens one by one. Create with `StreamableParser::new(encoding, role)` and feed tokens via `process`. Access information via getters like `current_content`, `current_role`, `messages`, `tokens` and `state_json`. +Incremental parser that consumes tokens one by one. Create with `StreamableParser::new(encoding, role)` and feed tokens via `process`. Access information via getters like `current_content`, `current_role`, `messages`, `tokens` and `state_json`. Use `StreamableParser::new_with_options(encoding, role, options)` when you need to override defaults such as `ParseOptions { strict: false }`. 
## registry module diff --git a/python/openai_harmony/__init__.py b/python/openai_harmony/__init__.py index 33afbd7..c8bce8f 100644 --- a/python/openai_harmony/__init__.py +++ b/python/openai_harmony/__init__.py @@ -520,10 +520,14 @@ def render( # -- Parsing ------------------------------------------------------- def parse_messages_from_completion_tokens( - self, tokens: Sequence[int], role: Optional[Role] | None = None + self, + tokens: Sequence[int], + role: Optional[Role] | None = None, + *, + strict: bool = True, ) -> List[Message]: raw_json: str = self._inner.parse_messages_from_completion_tokens( - list(tokens), None if role is None else str(role.value) + list(tokens), None if role is None else str(role.value), strict ) return [Message.from_dict(m) for m in json.loads(raw_json)] @@ -619,9 +623,15 @@ class StreamState(Enum): class StreamableParser: """Incremental parser over completion tokens.""" - def __init__(self, encoding: HarmonyEncoding, role: Role | None): + def __init__( + self, + encoding: HarmonyEncoding, + role: Role | None, + *, + strict: bool = True, + ) -> None: role_str = str(role.value) if role is not None else None - self._inner = _PyStreamableParser(encoding._inner, role_str) + self._inner = _PyStreamableParser(encoding._inner, role_str, strict) def process(self, token: int) -> "StreamableParser": self._inner.process(token) diff --git a/src/encoding.rs b/src/encoding.rs index 6a9305b..1999372 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -369,15 +369,16 @@ impl HarmonyEncoding { Ok(()) } - pub fn parse_messages_from_completion_tokens( + pub fn parse_messages_from_completion_tokens_with_options( &self, tokens: I, role: Option, + options: ParseOptions, ) -> anyhow::Result> where I: IntoIterator, { - let mut parser = StreamableParser::new(self.clone(), role)?; + let mut parser = StreamableParser::new_with_options(self.clone(), role, options)?; for token in tokens { parser.process(token)?; } @@ -385,6 +386,21 @@ impl HarmonyEncoding { 
Ok(parser.into_messages()) } + pub fn parse_messages_from_completion_tokens( + &self, + tokens: I, + role: Option, + ) -> anyhow::Result> + where + I: IntoIterator, + { + self.parse_messages_from_completion_tokens_with_options( + tokens, + role, + ParseOptions::default(), + ) + } + /// Helper to convert a JSON schema (OpenAPI style) to a TypeScript type definition. fn json_schema_to_typescript(schema: &serde_json::Value, indent: &str) -> String { // Helper to check if this schema is an enum @@ -1019,6 +1035,17 @@ impl Render for HarmonyEncoding { } } +#[derive(Clone, Copy, Debug)] +pub struct ParseOptions { + pub strict: bool, +} + +impl Default for ParseOptions { + fn default() -> Self { + Self { strict: true } + } +} + /// Incremental parser that can consume tokens one by one. /// /// It keeps track of all tokens seen so far, exposes all fully parsed messages @@ -1032,6 +1059,7 @@ pub struct StreamableParser { stop_tokens: HashSet, last_content_delta: Option, undecoded_tokens: Vec, + options: ParseOptions, } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] @@ -1049,6 +1077,15 @@ pub enum StreamState { impl StreamableParser { /// Create a new streaming parser starting with the given role. pub fn new(encoding: HarmonyEncoding, role: Option) -> anyhow::Result { + Self::new_with_options(encoding, role, ParseOptions::default()) + } + + /// Create a new streaming parser with explicit options. + pub fn new_with_options( + encoding: HarmonyEncoding, + role: Option, + options: ParseOptions, + ) -> anyhow::Result { let stop_tokens = encoding.stop_tokens()?; let (state, next_role) = match role { Some(role) => ( @@ -1068,6 +1105,7 @@ impl StreamableParser { stop_tokens, last_content_delta: None, undecoded_tokens: Vec::new(), + options, }) } @@ -1123,6 +1161,34 @@ impl StreamableParser { content_tokens: Vec::new(), }; } + Some(token) if !self.options.strict && self.stop_tokens.contains(&token) => { + // Encountered a stop token while in Header state. 
This means we have + // accumulated header tokens but never saw a <|message|> token, so the + // message is malformed. If we have a role, parse header metadata and + // treat remaining tokens as content. + if let Some(role) = next_role_clone { + if !header_tokens.is_empty() { + let decoded = + self.encoding.tokenizer().decode_utf8(header_tokens)?; + let (header, remaining_content) = + self.parse_header_from_string(decoded, Some(role), false)?; + + // Use remaining content if present, otherwise empty string + let text = remaining_content.unwrap_or_default(); + let message = Message { + author: header.author.clone(), + recipient: header.recipient.clone(), + channel: header.channel.clone(), + content_type: header.content_type.clone(), + content: vec![Content::Text(TextContent { text })], + }; + self.messages.push(message); + } + } + // Transition to ExpectStart to wait for the next message + self.state = StreamState::ExpectStart; + self.next_role = None; + } Some(token) => { header_tokens.push(token); } @@ -1194,17 +1260,18 @@ impl StreamableParser { Ok(self) } - fn parse_header_from_tokens( + /// Helper to parse header metadata from a decoded string. + /// Returns the parsed header and any remaining content after extracting header parts. + /// + /// If `parse_recipient_and_type` is true, tries to parse recipient and content_type from + /// whitespace-separated tokens (normal header parsing). If false, treats all remaining + /// text after extracting channel as content (for malformed messages). 
+ fn parse_header_from_string( &self, - header_tokens: &[Rank], + mut header_string: String, role: Option, - ) -> anyhow::Result { - let mut header_string = self - .encoding - .tokenizer() - .decode_utf8(header_tokens) - .context("could not decode header")?; - + parse_recipient_and_type: bool, + ) -> anyhow::Result<(ParsedHeader, Option)> { let mut channel: Option = None; if let Some(channel_marker) = self.encoding.mapped_format_token(FormattingToken::Channel) { if let Some(idx) = header_string.find(channel_marker) { @@ -1280,10 +1347,9 @@ impl StreamableParser { let mut recipient: Option = None; let mut content_type: Option = None; + let remaining_content: Option; - if !parts.is_empty() { - // Determine whether the last token is a content-type or part of the - // recipient specification. + if parse_recipient_and_type && !parts.is_empty() { let num_parts = parts.len(); // SAFETY: we know that there is at least one part remaining, because of is_empty check above let last_part = parts.pop().unwrap(); @@ -1308,12 +1374,21 @@ impl StreamableParser { }; } } + + // Any remaining parts are content (not header metadata) + remaining_content = if !parts.is_empty() { + Some(parts.join(" ")) + } else { + None + }; + } else { + // Treat all remaining parts as content when not parsing recipient and content type + remaining_content = if !parts.is_empty() { + Some(parts.join(" ")) + } else { + None + }; } - anyhow::ensure!( - parts.is_empty(), - "unexpected tokens remaining in message header: {:?}", - parts - ); let author = if role == Role::Tool { let name = role_str_opt; @@ -1321,12 +1396,39 @@ impl StreamableParser { } else { Author { role, name: None } }; - Ok(ParsedHeader { - author, - recipient, - channel, - content_type, - }) + Ok(( + ParsedHeader { + author, + recipient, + channel, + content_type, + }, + remaining_content, + )) + } + + fn parse_header_from_tokens( + &self, + header_tokens: &[Rank], + role: Option, + ) -> anyhow::Result { + let header_string = self + 
.encoding + .tokenizer() + .decode_utf8(header_tokens) + .context("could not decode header")?; + + let (header, remaining_content) = + self.parse_header_from_string(header_string, role, true)?; + + if remaining_content.is_some() { + anyhow::bail!( + "unexpected tokens remaining in message header: {:?}", + remaining_content + ); + } + + Ok(header) } /// Return the textual content of the current message so far. diff --git a/src/lib.rs b/src/lib.rs index acd572a..6c2fcfb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ mod registry; mod tiktoken; pub mod tiktoken_ext; -pub use encoding::{HarmonyEncoding, StreamableParser}; +pub use encoding::{HarmonyEncoding, ParseOptions, StreamableParser}; pub use registry::load_harmony_encoding; pub use registry::HarmonyEncodingName; diff --git a/src/py_module.rs b/src/py_module.rs index c5c7b0a..345a887 100644 --- a/src/py_module.rs +++ b/src/py_module.rs @@ -27,7 +27,7 @@ create_exception!(openai_harmony, HarmonyError, PyRuntimeError); use crate::{ chat::{Message, Role, ToolNamespaceConfig}, - encoding::{HarmonyEncoding, StreamableParser}, + encoding::{HarmonyEncoding, ParseOptions, StreamableParser}, load_harmony_encoding, HarmonyEncodingName, }; @@ -212,6 +212,7 @@ impl PyHarmonyEncoding { &self, tokens: Vec, role: Option<&str>, + strict: Option, ) -> PyResult { let role_parsed = if let Some(r) = role { Some(Role::try_from(r).map_err(|_| { @@ -221,9 +222,13 @@ impl PyHarmonyEncoding { None }; + let options = ParseOptions { + strict: strict.unwrap_or(true), + }; + let messages: Vec = self .inner - .parse_messages_from_completion_tokens(tokens, role_parsed) + .parse_messages_from_completion_tokens_with_options(tokens, role_parsed, options) .map_err(|e| PyErr::new::(e.to_string()))?; serde_json::to_string(&messages).map_err(|e| { @@ -297,7 +302,11 @@ impl PyHarmonyEncoding { #[pymethods] impl PyStreamableParser { #[new] - fn new(encoding: &PyHarmonyEncoding, role: Option<&str>) -> PyResult { + fn new( + encoding: 
&PyHarmonyEncoding, + role: Option<&str>, + strict: Option, + ) -> PyResult { let parsed_role = role .map(|r| { Role::try_from(r).map_err(|_| { @@ -305,8 +314,12 @@ impl PyStreamableParser { }) }) .transpose()?; - let inner = StreamableParser::new(encoding.inner.clone(), parsed_role) - .map_err(|e| PyErr::new::(e.to_string()))?; + let options = ParseOptions { + strict: strict.unwrap_or(true), + }; + let inner = + StreamableParser::new_with_options(encoding.inner.clone(), parsed_role, options) + .map_err(|e| PyErr::new::(e.to_string()))?; Ok(Self { inner }) } diff --git a/src/tests.rs b/src/tests.rs index d072d73..922be79 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -7,7 +7,7 @@ use crate::{ }, load_harmony_encoding, tiktoken::{CoreBPE, Rank}, - HarmonyEncodingName, StreamableParser, + HarmonyEncodingName, ParseOptions, StreamableParser, }; use pretty_assertions::{assert_eq, Comparison}; use serde_json::json; @@ -674,6 +674,36 @@ fn test_streamable_parser_tool_call_with_constrain_adjacent() { ); } +#[test] +fn test_missing_message_token_requires_non_strict_mode() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let malformed = "<|channel|>commentary Hello<|end|>"; + let tokens = encoding.tokenizer().encode_with_special_tokens(malformed); + + // Strict mode should continue to error on malformed headers. + let strict_result = encoding + .parse_messages_from_completion_tokens(tokens.iter().copied(), Some(Role::Assistant)); + assert!( + strict_result.is_err(), + "expected strict parser to reject malformed header" + ); + + // Non-strict mode should recover and return the accumulated message content. 
+ let parsed = encoding + .parse_messages_from_completion_tokens_with_options( + tokens.iter().copied(), + Some(Role::Assistant), + ParseOptions { strict: false }, + ) + .expect("non-strict parser should recover from malformed header"); + + assert_eq!(parsed.len(), 1); + assert_eq!( + parsed[0], + Message::from_role_and_content(Role::Assistant, "Hello").with_channel("commentary") + ); +} + #[test] fn test_tool_call_with_constrain_marker_adjacent() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); diff --git a/src/wasm_module.rs b/src/wasm_module.rs index 1b96a4f..0cbe281 100644 --- a/src/wasm_module.rs +++ b/src/wasm_module.rs @@ -2,7 +2,7 @@ use wasm_bindgen::prelude::*; use crate::{ chat::{Message, Role, ToolNamespaceConfig}, - encoding::{HarmonyEncoding, StreamableParser}, + encoding::{HarmonyEncoding, ParseOptions, StreamableParser}, load_harmony_encoding as inner_load_harmony_encoding, HarmonyEncodingName, }; @@ -166,6 +166,7 @@ impl JsHarmonyEncoding { &self, tokens: Vec, role: Option, + strict: Option, ) -> Result { let role_parsed = if let Some(r) = role { Some( @@ -175,9 +176,12 @@ impl JsHarmonyEncoding { } else { None }; + let options = ParseOptions { + strict: strict.unwrap_or(true), + }; let messages: Vec = self .inner - .parse_messages_from_completion_tokens(tokens, role_parsed) + .parse_messages_from_completion_tokens_with_options(tokens, role_parsed, options) .map_err(|e| JsValue::from_str(&e.to_string()))?; serde_json::to_string(&messages) .map_err(|e| JsValue::from_str(&format!("failed to serialise messages to JSON: {e}"))) @@ -253,11 +257,19 @@ pub struct JsStreamableParser { #[wasm_bindgen] impl JsStreamableParser { #[wasm_bindgen(constructor)] - pub fn new(encoding: &JsHarmonyEncoding, role: &str) -> Result { + pub fn new( + encoding: &JsHarmonyEncoding, + role: &str, + strict: Option, + ) -> Result { let parsed_role = Role::try_from(role) .map_err(|_| JsValue::from_str(&format!("unknown role: {role}")))?; - 
let inner = StreamableParser::new(encoding.inner.clone(), Some(parsed_role)) - .map_err(|e| JsValue::from_str(&e.to_string()))?; + let options = ParseOptions { + strict: strict.unwrap_or(true), + }; + let inner = + StreamableParser::new_with_options(encoding.inner.clone(), Some(parsed_role), options) + .map_err(|e| JsValue::from_str(&e.to_string()))?; Ok(Self { inner }) } diff --git a/tests/test_harmony.py b/tests/test_harmony.py index dd34e81..761bcef 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -981,3 +981,110 @@ def test_streamable_parser_tool_call_with_constrain_adjacent(): ] assert parser.messages == expected + + +@pytest.mark.parametrize("strict, expect_error", [(False, False), (True, True)]) +def test_streamable_parser_missing_message_token(strict: bool, expect_error: bool): + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + text = ( + "I must refuse<|end|>" + "<|start|>assistant<|channel|>analysis<|message|>We must refuse<|end|>" + "<|start|>assistant<|channel|>final<|message|>I'm sorry, but I can't help with that.<|return|>" + ) + tokens = encoding.encode(text, allowed_special="all") + parser = StreamableParser(encoding, Role.ASSISTANT, strict=strict) + + if expect_error: + with pytest.raises(HarmonyError, match="unexpected tokens remaining in message header"): + for token in tokens: + parser.process(token) + return + + for token in tokens: + parser.process(token) + + expected = [ + Message.from_role_and_content(Role.ASSISTANT, "I must refuse"), + Message.from_role_and_content(Role.ASSISTANT, "We must refuse").with_channel( + "analysis" + ), + Message.from_role_and_content( + Role.ASSISTANT, "I'm sorry, but I can't help with that." 
+ ).with_channel("final"), + ] + assert parser.messages == expected + + +@pytest.mark.parametrize("strict, expect_error", [(False, False), (True, True)]) +def test_streamable_parser_missing_message_token_other_initial_headers( + strict: bool, expect_error: bool +): + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + text = ( + "<|channel|>analysis I must refuse<|end|>" + "<|start|>assistant<|channel|>analysis<|message|>We must refuse<|end|>" + "<|start|>assistant<|channel|>final<|message|>I'm sorry, but I can't help with that.<|return|>" + ) + tokens = encoding.encode(text, allowed_special="all") + parser = StreamableParser(encoding, Role.ASSISTANT, strict=strict) + + if expect_error: + with pytest.raises(HarmonyError, match="unexpected tokens remaining in message header"): + for token in tokens: + parser.process(token) + return + + for token in tokens: + parser.process(token) + + expected = [ + Message.from_role_and_content(Role.ASSISTANT, "I must refuse").with_channel( + "analysis" + ), + Message.from_role_and_content(Role.ASSISTANT, "We must refuse").with_channel( + "analysis" + ), + Message.from_role_and_content( + Role.ASSISTANT, "I'm sorry, but I can't help with that." + ).with_channel("final"), + ] + assert parser.messages == expected + + +@pytest.mark.parametrize("strict, expect_error", [(False, False), (True, True)]) +def test_streamable_parser_missing_message_token_tool_call( + strict: bool, expect_error: bool +): + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + text = ( + "... 
Let's use the tool.<|end|>" + "<|start|>assistant to=functions.get_weather<|channel|>commentary json" + '<|message|>{"location": "Tokyo"}<|call|>' + ) + tokens = encoding.encode(text, allowed_special="all") + parser = StreamableParser(encoding, Role.ASSISTANT, strict=strict) + + if expect_error: + with pytest.raises(HarmonyError, match="unexpected tokens remaining in message header"): + for token in tokens: + parser.process(token) + return + + for token in tokens: + parser.process(token) + + expected = [ + Message.from_role_and_content( + Role.ASSISTANT, "... Let's use the tool." + ), + Message.from_role_and_content( + Role.ASSISTANT, '{"location": "Tokyo"}' + ) + .with_channel("commentary") + .with_recipient("functions.get_weather") + .with_content_type("json"), + ] + assert parser.messages == expected