-
Notifications
You must be signed in to change notification settings - Fork 213
Add Runway Gen-4 text-to-image node for ComfyUI integration #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| import os | ||
| import requests | ||
| from io import BytesIO | ||
| from PIL import Image | ||
| import sys | ||
| import os | ||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))) | ||
|
|
||
|
|
||
| from nodes import CLIPTextEncode | ||
|
|
||
| class RunwayText2Image(CLIPTextEncode): | ||
| """ | ||
| RunwayText2Image Node | ||
|
|
||
| Description: | ||
| This node sends a prompt to Runway’s /v1/text_to_image endpoint to generate an image. | ||
| It can be used in a ComfyUI graph after text-encoding nodes like CLIPTextEncode. | ||
|
|
||
| Parameters: | ||
| - prompt (str): Text prompt for image generation (default: ""). | ||
| - poll_timeout (int): Max time to wait for response in seconds (default: 60, min: 5, max: 300). | ||
|
|
||
| Returns: | ||
| - (IMAGE,): A tuple containing a PIL.Image in RGB format. | ||
|
|
||
| Requirements: | ||
| - Requires the RUNWAY_API_KEY environment variable to be set. | ||
| If absent, a RuntimeError is raised with a clear message. | ||
|
|
||
| Notes: | ||
| - You can adjust the poll_timeout parameter to shorten or extend how long the node waits for a response from Runway’s API. | ||
|
|
||
| API: | ||
| - POST https://api.dev.runwayml.com/v1/text_to_image | ||
| """ | ||
|
|
||
| @classmethod | ||
| def INPUT_TYPES(cls): | ||
| return { | ||
| "required": { | ||
| "prompt": ("STRING", {"default": ""}), | ||
| "poll_timeout": ("INT", {"default": 60, "min": 5, "max": 300}), | ||
| } | ||
| } | ||
|
|
||
| RETURN_TYPES = ("IMAGE",) | ||
| FUNCTION = "generate_image" | ||
|
|
||
| def generate_image(self, prompt, poll_timeout): | ||
| api_key = os.getenv("RUNWAY_API_KEY") | ||
| if not api_key: | ||
| raise RuntimeError("Missing environment variable: RUNWAY_API_KEY") | ||
|
|
||
| headers = {"Authorization": f"Bearer {api_key}"} | ||
| payload = {"prompt": prompt} | ||
|
|
||
| try: | ||
| response = requests.post( | ||
| "https://api.dev.runwayml.com/v1/text_to_image", | ||
| json=payload, | ||
| headers=headers, | ||
| timeout=poll_timeout | ||
| ) | ||
| response.raise_for_status() | ||
| image_url = response.json().get("image_url") | ||
| if not image_url: | ||
| raise ValueError("No image URL returned from Runway API") | ||
|
|
||
| image_data = requests.get(image_url).content | ||
| image = Image.open(BytesIO(image_data)).convert("RGB") | ||
|
|
||
| return (image,) | ||
|
|
||
| except Exception as e: | ||
| raise RuntimeError(f"RunwayText2Image generation failed: {e}") | ||
|
|
||
|
Comment on lines
+50
to
+77
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add timeout and error handling for image download. The image download request (line 70) lacks timeout and error handling, which could cause the node to hang indefinitely or fail silently. Apply this diff to improve error handling: image_url = response.json().get("image_url")
if not image_url:
raise ValueError("No image URL returned from Runway API")
- image_data = requests.get(image_url).content
+ image_response = requests.get(image_url, timeout=30)
+ image_response.raise_for_status()
+ image_data = image_response.content
image = Image.open(BytesIO(image_data)).convert("RGB")Additionally, consider more specific exception handling instead of catching all exceptions on line 75. 🤖 Prompt for AI Agents |
||
| NODE_CLASS_MAPPINGS = { | ||
| "RunwayText2Image": RunwayText2Image, | ||
| } | ||
|
|
||
| NODE_DISPLAY_NAME_MAPPINGS = { | ||
| "RunwayText2Image": "Runway Text-to-Image", | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| import pytest | ||
| from unittest import mock | ||
| from io import BytesIO | ||
| from PIL import Image | ||
| import sys | ||
| import os | ||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../custom_nodes'))) | ||
| from runway_text2img import RunwayText2Image | ||
|
|
||
|
|
||
|
|
||
| @pytest.fixture | ||
| def dummy_image_bytes(): | ||
| """Returns PNG bytes of a 1x1 black image.""" | ||
| img = Image.new("RGB", (1, 1)) | ||
| buffer = BytesIO() | ||
| img.save(buffer, format="PNG") | ||
| return buffer.getvalue() | ||
|
|
||
| @mock.patch("runway_text2img.requests.get") | ||
| @mock.patch("runway_text2img.requests.post") | ||
| @mock.patch("runway_text2img.os.getenv") | ||
| def test_runway_text2img_node_success(mock_getenv, mock_post, mock_get, dummy_image_bytes): | ||
| # Mock environment variable | ||
| mock_getenv.return_value = "fake_api_key" | ||
|
|
||
| # Mock POST response | ||
| mock_post.return_value.status_code = 200 | ||
| mock_post.return_value.json.return_value = {"image_url": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTwBchNsxAEthMtT_uv1MInGKEi4A0W2b1mx1flcpNOoUMkiy0CCnLfKF55jqIiRB9Mx-Y&usqp=CAU"} | ||
|
|
||
| # Mock GET image download | ||
| mock_get.return_value.content = dummy_image_bytes | ||
|
|
||
| # Instantiate node | ||
| node = RunwayText2Image() | ||
|
|
||
| # Run node | ||
| outputs = node.generate_image(prompt="test prompt", poll_timeout=10) | ||
|
|
||
| # Validate output | ||
| assert isinstance(outputs, tuple) | ||
| assert isinstance(outputs[0], Image.Image) | ||
| assert outputs[0].size == (1, 1) | ||
|
|
||
| @mock.patch("runway_text2img.os.getenv") | ||
| def test_runway_text2img_missing_api_key(mock_getenv): | ||
| mock_getenv.return_value = None | ||
| node = RunwayText2Image() | ||
|
|
||
| with pytest.raises(RuntimeError, match="RUNWAY_API_KEY"): | ||
| node.generate_image(prompt="test", poll_timeout=10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix duplicate import and simplify path handling.
The
osmodule is imported twice (lines 1 and 6), and the sys.path manipulation is unnecessarily complex.Apply this diff to fix the duplicate import:
import os import requests from io import BytesIO from PIL import Image import sys -import os sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))Consider whether the sys.path manipulation is necessary - using relative imports or proper package structure might be cleaner.
📝 Committable suggestion
🧰 Tools
🪛 Ruff (0.12.2)
6-6: Redefinition of unused
osfrom line 1Remove definition:
os(F811)
🤖 Prompt for AI Agents