Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions astrbot/core/provider/sources/anthropic_source.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64
import json
from collections.abc import AsyncGenerator
from mimetypes import guess_type

import anthropic
from anthropic import AsyncAnthropic
Expand Down Expand Up @@ -458,6 +457,18 @@ async def text_chat_stream(
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response

def _detect_image_mime_type(self, data: bytes) -> str:
"""根据图片二进制数据的 magic bytes 检测 MIME 类型"""
if data[:8] == b"\x89PNG\r\n\x1a\n":
return "image/png"
if data[:2] == b"\xff\xd8":
return "image/jpeg"
if data[:6] in (b"GIF87a", b"GIF89a"):
return "image/gif"
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
return "image/jpeg"

async def assemble_context(
self,
text: str,
Expand All @@ -469,22 +480,17 @@ async def assemble_context(
async def resolve_image_url(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
image_data, mime_type = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
image_data, mime_type = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
image_data, mime_type = await self.encode_image_bs64(image_url)

if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None

# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine

return {
"type": "image",
"source": {
Expand Down Expand Up @@ -542,14 +548,22 @@ async def resolve_image_url(image_url: str) -> dict | None:
# 否则返回多模态格式
return {"role": "user", "content": content}

async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
async def encode_image_bs64(self, image_url: str) -> tuple[str, str]:
"""将图片转换为 base64,同时检测实际 MIME 类型"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
raw_base64 = image_url.replace("base64://", "")
try:
image_bytes = base64.b64decode(raw_base64)
mime_type = self._detect_image_mime_type(image_bytes)
except Exception:
mime_type = "image/jpeg"
Comment on lines +554 to +559
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): 用于 MIME 检测的 Base64 解码可以使用 validate=True,以更可靠地捕获非法输入。

在没有使用 validate=True 时,base64.b64decode(raw_base64) 可能会悄悄接受不合法的输入并生成无意义的字节数据,随后 _detect_image_mime_type 可能会对这些数据做出错误判断。使用 validate=True 会在数据非法时抛出异常,从而走到你当前的 except 分支,实现更可预期的行为。

Suggested change
raw_base64 = image_url.replace("base64://", "")
try:
image_bytes = base64.b64decode(raw_base64)
mime_type = self._detect_image_mime_type(image_bytes)
except Exception:
mime_type = "image/jpeg"
raw_base64 = image_url.replace("base64://", "")
try:
# 使用 validate=True 以便在遇到非法 base64 数据时抛出异常
image_bytes = base64.b64decode(raw_base64, validate=True)
mime_type = self._detect_image_mime_type(image_bytes)
except Exception:
mime_type = "image/jpeg"
Original comment in English

suggestion (bug_risk): Base64 decode for MIME detection could use validate=True to catch malformed inputs more reliably.

Without validate=True, base64.b64decode(raw_base64) discards any characters outside the base64 alphabet before decoding, so malformed input can be silently accepted and turned into garbage bytes that _detect_image_mime_type may then misclassify. With validate=True, such input raises binascii.Error instead and is routed cleanly through your existing except branch, giving more predictable behavior.

Suggested change
raw_base64 = image_url.replace("base64://", "")
try:
image_bytes = base64.b64decode(raw_base64)
mime_type = self._detect_image_mime_type(image_bytes)
except Exception:
mime_type = "image/jpeg"
raw_base64 = image_url.replace("base64://", "")
try:
# 使用 validate=True 以便在遇到非法 base64 数据时抛出异常
image_bytes = base64.b64decode(raw_base64, validate=True)
mime_type = self._detect_image_mime_type(image_bytes)
except Exception:
mime_type = "image/jpeg"

return f"data:{mime_type};base64,{raw_base64}", mime_type
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
image_bytes = f.read()
mime_type = self._detect_image_mime_type(image_bytes)
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:{mime_type};base64,{image_bs64}", mime_type
return "", "image/jpeg"

def get_current_key(self) -> str:
    """Return the value of ``self.chosen_api_key``."""
    current_key = self.chosen_api_key
    return current_key
Expand Down
Loading