From 7c135f7034077ea20e2941046666f47d880f4ec9 Mon Sep 17 00:00:00 2001 From: JOLIMAITRE Matthieu Date: Thu, 13 Jun 2024 03:35:36 +0200 Subject: [PATCH] init --- .gitignore | 3 + README.md | 5 ++ __init__.py | 14 +++++ requirements.txt | 3 + setup.sh | 10 ++++ src/online_image_node/__init__.py | 12 ++++ src/online_image_node/node.py | 97 +++++++++++++++++++++++++++++++ 7 files changed, 144 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 __init__.py create mode 100644 requirements.txt create mode 100755 setup.sh create mode 100644 src/online_image_node/__init__.py create mode 100644 src/online_image_node/node.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..717f053 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__ + +/venv diff --git a/README.md b/README.md new file mode 100644 index 0000000..fadfa12 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# Online Image Node + +A custom node for Comfyui that search for an image online with a seed input for reruns. + +Find in category 'Feur'. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..50b7767 --- /dev/null +++ b/__init__.py @@ -0,0 +1,14 @@ +import sys +import pathlib + + +project_dir = pathlib.Path(__file__).parent.resolve() +src_dir = project_dir / "src" +sys.path.insert(1, str(src_dir)) + + +from online_image_node import ( + NODE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS, + OnlineImageNode +) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0ff5efd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +duckduckgo_search==5.2.1 +numpy==1.24.2 +Pillow==10.3.0 \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100755 index 0000000..d59d79f --- /dev/null +++ b/setup.sh @@ -0,0 +1,10 @@ +#!/bin/sh +set -e +cd "$(dirname "$(realpath "$0")")" + + +python -m venv venv +. venv/bin/activate + + +pip install -r requirements.txt diff --git a/src/online_image_node/__init__.py b/src/online_image_node/__init__.py new file mode 100644 index 0000000..825a2a2 --- /dev/null +++ b/src/online_image_node/__init__.py @@ -0,0 +1,12 @@ +from .node import OnlineImageNode + + +NODE_CLASS_MAPPINGS = { + "Online Image Search": OnlineImageNode +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Online Image Search": "Online Image Search" +} + +# WEB_DIRECTORY = "./web" \ No newline at end of file diff --git a/src/online_image_node/node.py b/src/online_image_node/node.py new file mode 100644 index 0000000..f9e30c0 --- /dev/null +++ b/src/online_image_node/node.py @@ -0,0 +1,97 @@ +from dataclasses import dataclass +from pathlib import Path + +import numpy +import torch +from duckduckgo_search import DDGS +from fastdownload import download_url +from PIL import Image, ImageOps + + +@dataclass +class Query: + phrase: str + seed: int + pass + + +class OnlineImageNode: + CATEGORY = "Feur" + + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "search_phrase": ( + "STRING", + { + "default": "Feur", + "multiline": True + } + ), + "seed": ( + "NUMBER", + { + "default": 19 + } + ) + } + } + + OUTPUT_NODE = True + RETURN_TYPES = ["IMAGE"] + RETURN_NAMES = ["image"] + + + FUNCTION = "run" + def run(self, search_phrase: str, seed: int): + query = Query(search_phrase, seed) + path = self.download(query) + (image, _mask) = self.load_image(path) + return [image] + + + def download(self, query: Query): + with DDGS() as ddg: + results: list = ddg.images(query.phrase, max_results=50) + result = results[query.seed % len(results)] + filename = f"{query.phrase}_{query.seed}{Path(result['image']).suffix}" + dl_path = Path("/tmp") / filename + download_url(result["image"], dl_path) + return dl_path + + + def load_image(self, img_path: Path | str) -> tuple[torch.Tensor, torch.Tensor]: + """This code is taken from the LoadImage node code in ComfyUI. Maybe it's better to call that function directly?""" + img = Image.open(img_path) + + # If the image has exif data, rotate it to the correct orientation and remove the exif data. + img_raw = ImageOps.exif_transpose(img) + assert img_raw is not None + + # If in 32-bit mode, normalize the image appropriately. + if img_raw.mode == "I": + img_raw = img.point(lambda i: i * (1 / 255)) + + # If image is rgba, create mask. + if "A" in img_raw.getbands(): + mask = numpy.array(img_raw.getchannel("A")).astype(numpy.float32) / 255.0 + mask = 1.0 - torch.from_numpy(mask) + else: + # otherwise create a blank mask. + mask = torch.zeros( + (img_raw.height, img_raw.width), dtype=torch.float32, device="cpu" + ) + mask = mask.unsqueeze(0) # Add a batch dimension to mask + + # Convert the image to RGB. + rgb_image = img_raw.convert("RGB") + # Normalize the image's rgb values to {x | x ∈ float32, 0 <= x <= 1} + rgb_image = numpy.array(rgb_image).astype(numpy.float32) / 255.0 + # Convert the image to a tensor (torch.from_numpy gives a tensor with the format of [H, W, C]) + rgb_image = torch.from_numpy(rgb_image)[ + None, + ] # Add a batch dimension, new format is [B, H, W, C] + + return (rgb_image, mask)