This commit is contained in:
JOLIMAITRE Matthieu 2024-06-13 03:35:36 +02:00
commit 7c135f7034
7 changed files with 144 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
__pycache__
/venv

5
README.md Normal file
View file

@ -0,0 +1,5 @@
# Online Image Node
A custom node for Comfyui that search for an image online with a seed input for reruns.
Find in category 'Feur'.

14
__init__.py Normal file
View file

@ -0,0 +1,14 @@
import sys
import pathlib
project_dir = pathlib.Path(__file__).parent.resolve()
src_dir = project_dir / "src"
sys.path.insert(1, str(src_dir))
from online_image_node import (
NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS,
OnlineImageNode
)

3
requirements.txt Normal file
View file

@ -0,0 +1,3 @@
duckduckgo_search==5.2.1
numpy==1.24.2
Pillow==10.3.0

10
setup.sh Executable file
View file

@ -0,0 +1,10 @@
#!/bin/sh
set -e
cd "$(dirname "$(realpath "$0")")"
python -m venv venv
. venv/bin/activate
pip install -r requirements.txt

View file

@ -0,0 +1,12 @@
from .node import OnlineImageNode
NODE_CLASS_MAPPINGS = {
"Online Image Search": OnlineImageNode
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Online Image Search": "Online Image Search"
}
# WEB_DIRECTORY = "./web"

View file

@ -0,0 +1,97 @@
from dataclasses import dataclass
from pathlib import Path
import numpy
import torch
from duckduckgo_search import DDGS
from fastdownload import download_url
from PIL import Image, ImageOps
@dataclass
class Query:
phrase: str
seed: int
pass
class OnlineImageNode:
CATEGORY = "Feur"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"search_phrase": (
"STRING",
{
"default": "Feur",
"multiline": True
}
),
"seed": (
"NUMBER",
{
"default": 19
}
)
}
}
OUTPUT_NODE = True
RETURN_TYPES = ["IMAGE"]
RETURN_NAMES = ["image"]
FUNCTION = "run"
def run(self, search_phrase: str, seed: int):
query = Query(search_phrase, seed)
path = self.download(query)
(image, _mask) = self.load_image(path)
return [image]
def download(self, query: Query):
with DDGS() as ddg:
results: list = ddg.images(query.phrase, max_results=50)
result = results[query.seed % len(results)]
filename = f"{query.phrase}_{query.seed}{Path(result['image']).suffix}"
dl_path = Path("/tmp") / filename
download_url(result["image"], dl_path)
return dl_path
def load_image(self, img_path: Path | str) -> tuple[torch.Tensor, torch.Tensor]:
"""This code is taken from the LoadImage node code in ComfyUI. Maybe it's better to call that function directly?"""
img = Image.open(img_path)
# If the image has exif data, rotate it to the correct orientation and remove the exif data.
img_raw = ImageOps.exif_transpose(img)
assert img_raw is not None
# If in 32-bit mode, normalize the image appropriately.
if img_raw.mode == "I":
img_raw = img.point(lambda i: i * (1 / 255))
# If image is rgba, create mask.
if "A" in img_raw.getbands():
mask = numpy.array(img_raw.getchannel("A")).astype(numpy.float32) / 255.0
mask = 1.0 - torch.from_numpy(mask)
else:
# otherwise create a blank mask.
mask = torch.zeros(
(img_raw.height, img_raw.width), dtype=torch.float32, device="cpu"
)
mask = mask.unsqueeze(0) # Add a batch dimension to mask
# Convert the image to RGB.
rgb_image = img_raw.convert("RGB")
# Normalize the image's rgb values to {x | x ∈ float32, 0 <= x <= 1}
rgb_image = numpy.array(rgb_image).astype(numpy.float32) / 255.0
# Convert the image to a tensor (torch.from_numpy gives a tensor with the format of [H, W, C])
rgb_image = torch.from_numpy(rgb_image)[
None,
] # Add a batch dimension, new format is [B, H, W, C]
return (rgb_image, mask)