def _common_headers() -> dict[str, str]:
    """Build the common ``X-Msh-*`` header dictionary.

    The device name, device model and OS version are read from the local
    system, so each one is passed through ``sanitize_http_header_value`` with
    a fallback (``"device"`` for the name, ``"unknown"`` otherwise) before it
    is placed in the result.
    """
    # platform.node() may return an empty string; fall back to the socket API.
    hostname = platform.node() or socket.gethostname()
    model = _device_model()
    return {
        "X-Msh-Platform": "kimi_cli",
        "X-Msh-Version": VERSION,
        "X-Msh-Device-Name": sanitize_http_header_value(hostname, default="device"),
        "X-Msh-Device-Model": sanitize_http_header_value(model, default="unknown"),
        "X-Msh-Os-Version": sanitize_http_header_value(platform.version(), default="unknown"),
        "X-Msh-Device-Id": get_device_id(),
    }
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x1f\x7f]+")
_WHITESPACE_RE = re.compile(r"\s+")


def sanitize_http_header_value(value: str, *, default: str = "unknown") -> str:
    """Return an ASCII-safe, single-line HTTP header value.

    Some HTTP client stacks (and servers) only accept ASCII in header values.
    This helper prevents crashes when system strings (e.g., hostname) include
    Unicode characters, and removes CR/LF and other control characters so the
    value cannot span multiple header lines.

    Args:
        value: Raw text to place in a header.
        default: Returned when nothing is left after cleaning.

    Returns:
        ``value`` NFKD-normalized, with each run of control characters or
        whitespace collapsed to a single space, leading/trailing whitespace
        removed, and every remaining non-ASCII code point replaced by ``?``;
        or ``default`` if that result is empty.
    """
    # Normalize before the whitespace pass: NFKD can itself introduce spaces
    # (e.g. U+00B4 ACUTE ACCENT decomposes to U+0020 U+0301), so collapsing
    # first would let doubled spaces reach the output.
    decomposed = unicodedata.normalize("NFKD", value)

    # The control range covers \r and \n, so no separate newline replacement
    # is needed.
    without_controls = _CONTROL_CHARS_RE.sub(" ", decomposed)
    collapsed = _WHITESPACE_RE.sub(" ", without_controls).strip()

    # For pure-ASCII text this round trip is the identity; anything else
    # (emoji, combining marks left over from NFKD) becomes "?".
    ascii_only = collapsed.encode("ascii", errors="replace").decode("ascii")
    return ascii_only or default
diff --git a/tests/utils/test_string_utils.py b/tests/utils/test_string_utils.py new file mode 100644 index 00000000..506fbe20 --- /dev/null +++ b/tests/utils/test_string_utils.py @@ -0,0 +1,76 @@ +"""Tests for string utility functions.""" + +from __future__ import annotations + +import unicodedata + +import pytest + +from kimi_cli.utils.string import sanitize_http_header_value + +# These examples are intentionally explicit for posterity: NFKD is used to decompose +# compatibility characters (e.g. ① -> 1), then we derive an ASCII-safe header value. +# +# Here’s what that looks like in our exact situation and a few related examples: +# +# - andrewlouis@🏢 +# - NFKD: andrewlouis@🏢 (emoji doesn’t decompose) +# - ASCII fallback with ignore: andrewlouis@ (emoji dropped) +# - ASCII fallback with replace: andrewlouis@? (emoji becomes ?) +# - 🏢 +# - NFKD: 🏢 +# - ascii(ignore): `` (empty) +# - ascii(replace): ? +# - José +# - NFKD: José (that last “é” becomes e + a combining accent) +# - ascii(ignore): Jose (accent dropped) +# - ascii(replace): Jose? 
# - München
#   - NFKD: München (ü → u + combining diaeresis)
#   - ascii(ignore): Munchen
#   - ascii(replace): Mu?nchen
# - ①②③
#   - NFKD: 123 (circled numbers become plain digits)


@pytest.mark.parametrize(
    ("raw", "expected_replace"),
    [
        ("andrewlouis@🏢", "andrewlouis@?"),
        ("🏢", "?"),
        ("José", "Jose?"),
        ("München", "Mu?nchen"),
        ("①②③", "123"),
    ],
)
def test_sanitize_http_header_value_replace_examples(raw: str, expected_replace: str) -> None:
    """The sanitizer yields the NFKD + ``errors="replace"`` form, and it is ASCII."""
    sanitized = sanitize_http_header_value(raw, default="device")
    assert sanitized == expected_replace
    # Raises UnicodeEncodeError if any non-ASCII code point survived.
    sanitized.encode("ascii")


@pytest.mark.parametrize(
    ("raw", "expected_ignore"),
    [
        ("andrewlouis@🏢", "andrewlouis@"),
        ("🏢", ""),
        ("José", "Jose"),
        ("München", "Munchen"),
        ("①②③", "123"),
    ],
)
def test_nfkd_ascii_ignore_examples(raw: str, expected_ignore: str) -> None:
    """Pin the NFKD + ``errors="ignore"`` results quoted in the comment above."""
    decomposed = unicodedata.normalize("NFKD", raw)
    ascii_bytes = decomposed.encode("ascii", errors="ignore")
    assert ascii_bytes.decode("ascii") == expected_ignore


def test_sanitize_http_header_value_strips_controls_and_newlines() -> None:
    """CR/LF and NUL are removed; the remaining text is joined by single spaces."""
    # Separate from the Unicode/NFKD examples: this is about header injection
    # hardening and output stability. Whitespace is collapsed on purpose, so
    # the "\r\n" pair ends up as one space rather than two.
    hostile = "hi\r\nevil: 1\x00"
    assert sanitize_http_header_value(hostile, default="device") == "hi evil: 1"


def test_sanitize_http_header_value_collapses_internal_whitespace() -> None:
    """Runs of tabs/spaces inside the value collapse to a single space."""
    spaced = "a\t\tb c"
    assert sanitize_http_header_value(spaced, default="device") == "a b c"