Mirror of https://github.com/invoke-ai/InvokeAI.git
Make LLaVA Onevision node work with 0 images, and other minor improvements.
Commit 41de112932 (parent e9714fe476), committed by psychedelicious.
@@ -14,7 +14,7 @@ from invokeai.backend.util.devices import TorchDevice
 class LlavaOnevisionVllmInvocation(BaseInvocation):
     """Run a LLaVA OneVision VLLM model."""
 
-    image: list[ImageField] | ImageField | None = InputField(description="Input image.")
+    images: list[ImageField] | ImageField | None = InputField(default=None, description="Input image.")
     prompt: str = InputField(
         default="",
         description="Input text prompt.",
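Why default=None matters here: InvokeAI invocations are pydantic models, and in pydantic a field whose type allows None is still required unless it has a default, so the old node could not be built with no image connected. A minimal sketch, assuming pydantic v2 (the WithDefault/WithoutDefault models are illustrative, not InvokeAI code):

    from pydantic import BaseModel, ValidationError

    class WithoutDefault(BaseModel):
        images: list[str] | str | None  # None is allowed, but the field is still required

    class WithDefault(BaseModel):
        images: list[str] | str | None = None  # optional: may be omitted entirely

    WithDefault()  # builds fine with zero images
    try:
        WithoutDefault()
    except ValidationError:
        print("field required")  # the pre-fix behavior: no image, no node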
@@ -27,10 +27,10 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
     # )
 
     def _get_images(self, context: InvocationContext) -> list[Image]:
-        if self.image is None:
+        if self.images is None:
             return []
 
-        image_fields = self.image if isinstance(self.image, list) else [self.image]
+        image_fields = self.images if isinstance(self.images, list) else [self.images]
         return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]
 
     @torch.no_grad()
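This hunk only swaps the field name, but the surrounding normalization pattern is what lets one input accept zero, one, or many images. A self-contained sketch of that pattern, with ImageRef standing in for ImageField (not InvokeAI code):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ImageRef:
        image_name: str

    def normalize(images: list[ImageRef] | ImageRef | None) -> list[ImageRef]:
        if images is None:
            return []  # zero images: the node still runs with a text-only prompt
        if isinstance(images, list):
            return images  # already a list: use as-is
        return [images]  # single image: wrap it

    assert normalize(None) == []
    assert normalize(ImageRef("a.png")) == [ImageRef("a.png")]
    assert normalize([ImageRef("a.png"), ImageRef("b.png")]) == [ImageRef("a.png"), ImageRef("b.png")]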
@@ -38,8 +38,8 @@ class LlavaOnevisionModel(RawModel):
 
         conversation = [{"role": "user", "content": content}]
         prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True)
-        inputs = self._processor(images=images, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
-        output = self._vllm_model.generate(**inputs, max_new_tokens=200, do_sample=False)
+        inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
+        output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False)
         output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True)
         # The output_str will include the prompt, so we extract the response.
         response = output_str.split("assistant\n", 1)[1].strip()
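Two behavioral changes hide in this hunk. First, max_new_tokens doubles from 200 to 400, allowing longer responses. Second, "images or None" converts an empty image list into None before it reaches the Hugging Face processor: an empty list is falsy in Python, and processors treat images=None as a text-only call, whereas an empty list can be rejected. A minimal sketch of the idiom (to_processor_arg is an illustrative name, not part of the codebase):

    def to_processor_arg(images: list) -> list | None:
        # `[] or None` evaluates to None because an empty list is falsy;
        # a non-empty list is truthy and passes through unchanged.
        return images or None

    assert to_processor_arg([]) is None
    assert to_processor_arg(["img_0"]) == ["img_0"]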