diff --git a/README.md b/README.md index 4bee14b..9a61072 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,101 @@ # CozeWebSocket -Coze 平台 websocket quecpython 接入 + +## Table of Contents + +- [Introduction](#Introduction) +- [Features](#Features) +- [Quick_Start](#Quick_Start) + - [Prerequisites](#Prerequisites) + - [Installation](#Installation) + - [Running_Application](#Running_Application) +- [Contributing](#Contributing) +- [License](#License) +- [Support](#Support) + +## Introduction + +QuecPython introduces the Coze Platform Websocket Access solution. This solution is based on the WebSocket protocol, offering cross-platform compatibility and supporting most QuecPython modules. + +This demo uses an AI development board equipped with the EC800MCNLE module. + +## Features + +- Supports voice interruption/barge-in. +- Supports keyword-based voice wake-up. +- Uses Python for easy secondary development. + +## Quick_Start + +### Prerequisites + +Before getting started, ensure you have the following prerequisites: + +- **Hardware:** + - Contact Quectel to obtain the AI development board and accessories. + - Computer (Windows 7, Windows 10, or Windows 11) + - Speaker + - Any 2-5W speaker will work + - [Purchase link from Quectel Mall](https://www.quecmall.com/goods-detail/2c90800c94028da201948249e9f4012d) +- **Software:** + - Debugging tool [QPYcom](https://images.quectel.com/python/2022/12/QPYcom_V3.6.0.zip) + - QuecPython firmware (beta firmware is available in the `fw` directory of the repository) + - Python text editor (e.g., [VSCode](https://code.visualstudio.com/), [PyCharm](https://www.jetbrains.com/pycharm/download/)) + +### Installation + +1. **Clone the Repository:** + + ```bash + git clone https://github.com/QuecPython/CozeWebSocket.git + ``` + +2. **Flash the Firmware:** + Follow the [instructions](https://python.quectel.com/doc/Application_guide/zh/dev-tools/QPYcom/qpycom-dw.html#%E4%B8%8B%E8%BD%BD%E5%9B%BA%E4%BB%B6) to flash the firmware onto the development board. + +### Running_Application + +1. **Hardware Connection:** + This demo uses the Quectel AI development board. Contact Quectel if needed. Connect the hardware as shown below: + + + + 1. Connect the speaker + 2. Connect the antenna + 3. Insert the battery + +2. Connect to the host computer via Type-C. + +3. **Download the Code to the Device:** + + - Launch the QPYcom debugging tool. + - Connect the data cable to the computer. + - Press the **PWRKEY** button on the development board to power on the device. + - Follow the [instructions](https://developer.quectel.com/doc/quecpython/Getting_started/en/4G/first_python.html#PC与模组间的文件传输) to import all files from the `code` folder into the module's file system, preserving the directory structure. + +4. **Run the Application:** + + - Select the `File` tab. + - Choose the `coze_main.py` script. + - Right-click and select `Run` or use the `Run` shortcut button to execute the script. + +5. **After keyword wake-up, start a conversation. Refer to the runtime log:** + + ![](./media/20260202.png) + +## Contributing + +We welcome contributions to improve this project! Follow these steps to contribute: + +1. Fork this repository. +2. Create a new branch (`git checkout -b feature/your-feature`). +3. Commit your changes (`git commit -m 'Add your feature'`). +4. Push to the branch (`git push origin feature/your-feature`). +5. Open a Pull Request. + +## License + +This project is licensed under the Apache License. See the [LICENSE](https://license/) file for details. + +## Support + +If you have any questions or need support, refer to the [QuecPython Documentation](https://python.quectel.com/doc) or open an issue in this repository. diff --git a/README_zh.md b/README_zh.md new file mode 100644 index 0000000..4a7b548 --- /dev/null +++ b/README_zh.md @@ -0,0 +1,103 @@ +# CozeWebSocket +Coze 平台 websocket quecpython 接入 + +## 目录 + +- [介绍](#介绍) +- [功能特性](#功能特性) +- [快速开始](#快速开始) + - [先决条件](#先决条件) + - [安装](#安装) + - [运行应用程序](#运行应用程序) +- [贡献](#贡献) +- [许可证](#许可证) +- [支持](#支持) + +## 介绍 + +QuecPython 推出了Coze 平台 websocket quecpython 接入解决方案。该方案基于 websocket 协议,具有跨平台特性,可以适用于大部分 QuecPython 模组。 + +本案例采用搭载 EC800MCNLE 模组的 AI 开发板。 + +## 功能特性 + +- 支持语音中断/打断。 +- 支持关键词语音唤醒。 +- 使用 Python 语言,便于二次开发。 + +## 快速开始 + +### 先决条件 + +在开始之前,请确保您具备以下先决条件: + +- **硬件:** + - 联系移远官方获取 AI 开发板及配件。 + - 电脑(Windows 7、Windows 10 或 Windows 11) + - 喇叭 + - 任意 2-5W 功率的喇叭即可 + - [移远商城购买链接](https://www.quecmall.com/goods-detail/2c90800c94028da201948249e9f4012d) + +- **软件:** + - 调试工具 [QPYcom](https://images.quectel.com/python/2022/12/QPYcom_V3.6.0.zip) + - QuecPython 固件(仓库 fw 目录下有 beta 固件) + - Python 文本编辑器(例如,[VSCode](https://code.visualstudio.com/)、[Pycharm](https://www.jetbrains.com/pycharm/download/)) + +### 安装 + +1. **克隆仓库**: + + ```bash + git clone https://github.com/QuecPython/CozeWebSocket.git + ``` + +2. **烧录固件:** + 按照[说明](https://python.quectel.com/doc/Application_guide/zh/dev-tools/QPYcom/qpycom-dw.html#%E4%B8%8B%E8%BD%BD%E5%9B%BA%E4%BB%B6)将固件烧录到开发板上。 + +### 运行应用程序 + +1. **连接硬件:** + 本案例采用移远 AI 开发板,如有需要请联系官方获取。按照下图进行硬件连接: + + + + 1. 连接喇叭 + 2. 连接天线 + 3. 接入电池 + +2. **通过 Tpye-C 连接上位机** + +3. **将代码下载到设备:** + + - 启动 QPYcom 调试工具。 + - 将数据线连接到计算机。 + - 按下开发板上的 **PWRKEY** 按钮启动设备。 + - 按照[说明](https://developer.quectel.com/doc/quecpython/Getting_started/zh/4G/first_python.html#PC与模组间的文件传输)将 `code` 文件夹中的所有文件导入到模块的文件系统中,保留目录结构。 + +4. **运行应用程序:** + + - 选择 `File` 选项卡。 + - 选择 `coze_main.py` 脚本。 + - 右键单击并选择 `Run` 或使用`运行`快捷按钮执行脚本。 + +5. **关键词唤醒后,即可对话, 参考运行日志:** + + ![](./media/20260202.png) + +## 贡献 + +我们欢迎对本项目的改进做出贡献!请按照以下步骤进行贡献: + +1. Fork 此仓库。 +2. 创建一个新分支(`git checkout -b feature/your-feature`)。 +3. 提交您的更改(`git commit -m 'Add your feature'`)。 +4. 推送到分支(`git push origin feature/your-feature`)。 +5. 打开一个 Pull Request。 + +## 许可证 + +本项目使用 Apache 许可证。详细信息请参阅 [LICENSE](LICENSE) 文件。 + +## 支持 + +如果您有任何问题或需要支持,请参阅 [QuecPython 文档](https://python.quectel.com/doc) 或在本仓库中打开一个 issue。 diff --git a/fw/EC800MCNLER06A01M08_AI_WS_OCPU_QPY_BETA0927.zip b/fw/EC800MCNLER06A01M08_AI_WS_OCPU_QPY_BETA0927.zip new file mode 100644 index 0000000..6a81d2f Binary files /dev/null and b/fw/EC800MCNLER06A01M08_AI_WS_OCPU_QPY_BETA0927.zip differ diff --git a/media/20250425131903.jpg b/media/20250425131903.jpg new file mode 100644 index 0000000..bac934a Binary files /dev/null and b/media/20250425131903.jpg differ diff --git a/media/20260202.png b/media/20260202.png new file mode 100644 index 0000000..81e1bf0 Binary files /dev/null and b/media/20260202.png differ diff --git a/src/coze.py b/src/coze.py new file mode 100644 index 0000000..6f73786 --- /dev/null +++ b/src/coze.py @@ -0,0 +1,154 @@ +from usr import uwebsocket +import _thread +from usr import packet #update, append +import ujson +import ubinascii +from usr.media import singleton_media +import utime +from queue import Queue + +class cozews(): + def __init__(self, url, auth, callback=None): + + self.media = singleton_media('pcma', 4) + if self.media is None: + print('media is busy, please stop it first') + return + self.audio_queue = Queue() + + self.url = url + self.headers = {"Authorization": "Bearer " + auth} + + self.ws_recv_task_id = None + self.ws_audio_uplink_handler_id = None + self.ws_audio_downlink_handler_id = None + self.isactive = False + self.volume = 8 + self.callback = callback + + if self.callback: + self.event_queue = Queue() + self.ws_callback_event_id = _thread.start_new_thread(self.ws_server_event_handler, ()) + + def start(self): + if self.media.is_idle() is False: + print('media is busy, please stop it first') + return + self.client = uwebsocket.Client.connect(self.url, self.headers) + msg = ujson.dumps(packet.update) + self.client.send(msg) + + # ws recv task + self.ws_recv_task_id = _thread.start_new_thread(self.ws_recv_task, ()) + + def stop(self): + self.stop_audio_stream() + + if self.ws_recv_task_id: + _thread.stop_thread(self.ws_recv_task_id) + self.ws_recv_task_id = None + + self.client.close() + self.isactive = False + + def ws_audio_uplink_handler(self): + msg = packet.append + + while True: + try: + t1 = utime.ticks_ms() + data = b"".join([self.media.pcma_read() for _ in range(5)]) + t2 = utime.ticks_ms() + if len(data) > 0: + msg['data']['delta'] = ubinascii.b2a_base64(data).strip() + payload = ujson.dumps(msg) + #print('up {}ms/{}'.format(t2 - t1, len(payload))) + self.client.send(payload) + utime.sleep_ms(1) + except Exception as e: + print("Error in ws_audio_uplink_handler: {}".format(e)) + + def ws_audio_downlink_handler(self): + while True: + recv_data = self.audio_queue.get() + start,end = ujson.search(recv_data, 'content') + data = ubinascii.a2b_base64(recv_data[start:end]) + self.media.pcma_write(data) + utime.sleep_ms(1) + + def ws_server_event_handler(self): + while True: + recv_data = self.event_queue.get() + self.callback(self, recv_data) + utime.sleep_ms(1) + + def start_audio_stream(self): + self.media.start() + self.media.set_volume(self.volume) + + self.ws_audio_uplink_handler_id = _thread.start_new_thread(self.ws_audio_uplink_handler, ()) + self.ws_audio_downlink_handler_id = _thread.start_new_thread(self.ws_audio_downlink_handler, ()) + self.isactive = True + + def stop_audio_stream(self): + if self.ws_audio_uplink_handler_id: + _thread.stop_thread(self.ws_audio_uplink_handler_id) + self.ws_audio_uplink_handler_id = None + if self.ws_audio_downlink_handler_id: + _thread.stop_thread(self.ws_audio_downlink_handler_id) + self.ws_audio_downlink_handler_id = None + self.media.stop() + + def ws_recv_task(self): + while True: + try: + recv_data = self.client.recv(4096) + #print('recv_data_{}: {}'.format(len(recv_data), recv_data)) + if recv_data is None or len(recv_data) <= 1: + print('illegal data {}'.format(recv_data)) + continue + if packet.EventType.CONVERSATION_AUDIO_DELTA in recv_data: + self.audio_queue.put(recv_data) + else: + if self.callback: + self.event_queue.put(recv_data) + except Exception as e: + if "EIO" in str(e): + if self.isactive: + self.stop_audio_stream() + self.client.close() + self.isactive = False + msg = '{"event_type": "client.disconnected"}' + if self.callback: + #self.callback(self, msg) + self.event_queue.put(msg) + break + else: + if recv_data is not None: + print('recv error[{}] |{}|'.format(len(recv_data), recv_data)) + print('ws error |{}|'.format(e)) + utime.sleep_ms(1) + + def active(self): + return self.isactive + + def config(self, arg = None, **kwargs): + if arg != None: + if arg == 'volume': + if self.isactive is False: + return self.volume + return self.media.get_volume() + + for key, value in kwargs.items(): + if key == 'volume': + self.volume = value + if self.isactive is False: + continue + self.media.set_volume(value) + + def interrupted(self): + if self.isactive is False: + return + # 打断对话 + msg = ujson.dumps(packet.cancel) + self.client.send(msg) diff --git a/src/coze_main.py b/src/coze_main.py new file mode 100644 index 0000000..554a079 --- /dev/null +++ b/src/coze_main.py @@ -0,0 +1,44 @@ +from usr.coze import cozews +from usr import packet +import ujson + +def callback(coze, msg): + start,end = ujson.search(msg, 'event_type') + event = msg[start:end] + if event == packet.EventType.CHAT_CREATED: + coze.start_audio_stream() + print('connect server success...') + elif event == packet.EventType.DISCONNECTED: + print('server disconnected...') + elif event == packet.EventType.CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED: + start,end = ujson.search(msg, 'content') + print('ASR {}'.format(msg[start:end])) + elif event == packet.EventType.CONVERSATION_MESSAGE_COMPLETED: + start,end = ujson.search(msg, 'content_type') + content_type = msg[start:end] + start,end = ujson.search(msg, 'type') + type = msg[start:end] + if content_type == 'text' and type == 'answer': + start,end = ujson.search(msg, 'content') + print('TTS {}'.format(msg[start:end])) + elif event == packet.EventType.CONVERSATION_CHAT_FAILED: + start,end = ujson.search(msg, 'last_error') + print('failed {}'.format(msg[start:end])) + elif event == packet.EventType.SERVER_ERROR: + start,end = ujson.search(msg, 'msg') + print('error {}'.format(msg[start:end])) + else: + print('unkown event_type: {}'.format(msg['event_type'])) + +#url = "ws://183.201.115.203/v1/chat?bot_id=7511922148273831962" +url = "wss://ws.coze.cn/v1/chat?bot_id=7595096935447724032" + +#auth = "pat_eSuCmnooG6PLDildBu9ghH0OapGEkTg4wxTKNekj9AgXKAajIb0YQpgQ464k5J5x" # Replace with your actual auth token +auth = "pat_bgE5pSWNDM7XnfLi0TGEEyMXX9BcqUmJU3lEFXHWgaWpFbqjgvrh48HjqRoPwj9y" # Replace with your actual auth token + +coze = cozews(url, auth, callback) + +coze.config(volume=11) +coze.start() + +print('config done') diff --git a/src/media.py b/src/media.py new file mode 100644 index 0000000..c34da39 --- /dev/null +++ b/src/media.py @@ -0,0 +1,84 @@ + +import audio +import G711 + +singleton_media_obj = None + +def singleton_media(name, type): + global singleton_media_obj + if singleton_media_obj is None: + singleton_media_obj = media(name, type) + return singleton_media_obj + + if singleton_media_obj.is_idle(): + print('{} is using media'.format(singleton_media_obj.name)) + return None + else: + singleton_media_obj.set_media_config(name, type) + return singleton_media_obj + +class media: + MEDIA_TYPE_AUDIO = 1 + MEDIA_TYPE_PCM = 2 + MEDIA_TYPE_RECORD = 3 + MEDIA_TYPE_PCMA = 4 + + def __init__(self, name, type): + self.name = name + self.type = type + self.pcm = None + self.pcma = None + self.audio = audio.Audio(0) + self.audio.set_pa(29) + def set_media_config(self, name, type): + self.name = name + self.type = type + + def is_idle(self): + if self.pcma: + return False + return True + + def start(self): + if self.type == self.MEDIA_TYPE_PCMA: + self.pcm = audio.Audio.PCM(1, 1, 8000, 2, 1, 5) + self.pcma = G711(self.pcm) + else: + raise('unkown audio type') + + def stop(self): + if self.type == self.MEDIA_TYPE_PCMA: + if self.pcma: + del self.pcma + self.g711 = None + self.pcm.close() + del self.pcm + self.pcm = None + else: + raise('wrong audio type') + self.name = None + self.type = None + + def pcma_read(self): + #read = self.pcma.read(0) + #print('read: {}'.format(read)) + #return read + return self.pcma.read(0) + + def pcma_write(self, payload): + # print('write: {}'.format(payload)) + return self.pcma.write(payload, 0) + + + def set_volume(self, value): + if self.type == self.MEDIA_TYPE_PCMA: + return self.pcm.setVolume(value) + else: + raise('wrong audio type') + + def get_volume(self): + if self.type == self.MEDIA_TYPE_PCMA: + return self.pcm.getVolume() + else: + raise('wrong audio type') + diff --git a/src/packet.py b/src/packet.py new file mode 100644 index 0000000..9745e28 --- /dev/null +++ b/src/packet.py @@ -0,0 +1,155 @@ +class EventType: + ALL = 'realtime.event' # 所有事件 + CONNECTED = 'client.connected' # 客户端已连接 + CONNECTING = 'client.connecting' # 客户端连接中 + INTERRUPTED = 'client.interrupted' # 客户端已中断 + DISCONNECTED = 'client.disconnected' # 客户端已断开 + ERROR = 'client.error' # 客户端发生错误 + + # 音频控制事件 + AUDIO_UNMUTED = 'client.audio.unmuted' # 音频已取消静音 + AUDIO_MUTED = 'client.audio.muted' # 音频已静音 + AUDIO_INPUT_DUMP = 'client.audio.input.dump' # 音频输入数据导出 + + # 设备变更事件 + AUDIO_INPUT_DEVICE_CHANGED = 'client.input.device.changed' # 音频输入设备已改变 + AUDIO_OUTPUT_DEVICE_CHANGED = 'client.output.device.changed' # 音频输出设备已改变 + + # 降噪控制事件 + DENOISER_ENABLED = 'client.denoiser.enabled' # 降噪已启用 + DENOISER_DISABLED = 'client.denoiser.disabled' # 降噪已禁用 + + # 服务端对话事件 + CHAT_CREATED = 'chat.created' # 对话已创建 + CHAT_UPDATED = 'chat.updated' # 对话已更新 + + # 会话状态事件 + CONVERSATION_CHAT_CREATED = 'conversation.chat.created' # 会话对话已创建 + CONVERSATION_CHAT_IN_PROGRESS = 'conversation.chat.in.progress' # 对话进行中 + CONVERSATION_CHAT_COMPLETED = 'conversation.chat.completed' # 对话已完成 + CONVERSATION_CHAT_FAILED = 'conversation.chat.failed' # 对话失败 + CONVERSATION_CHAT_CANCELLED = 'conversation.chat.cancelled' # 对话已取消 + CONVERSATION_CHAT_REQUIRES_ACTION = 'conversation.chat.requires_action' # 对话需要端插件响应 + + # 消息事件 + CONVERSATION_MESSAGE_DELTA = 'conversation.message.delta' # 文本消息增量返回 + CONVERSATION_MESSAGE_COMPLETED = 'conversation.message.completed' # 文本消息完成 + + # 音频事件 + CONVERSATION_AUDIO_DELTA = 'conversation.audio.delta' # 语音消息增量返回 + CONVERSATION_AUDIO_COMPLETED = 'conversation.audio.completed' # 语音回复完成 + + # 语音识别事件 + CONVERSATION_AUDIO_TRANSCRIPT_UPDATE = 'conversation.audio_transcript.update' # 用户语音识别实时字幕更新 + CONVERSATION_AUDIO_TRANSCRIPT_COMPLETED = 'conversation.audio_transcript.completed' # 用户语音识别完成 + + # 语音检测事件 + INPUT_AUDIO_BUFFER_SPEECH_STARTED = 'input_audio_buffer.speech_started' # 检测到用户开始说话 + INPUT_AUDIO_BUFFER_SPEECH_STOPPED = 'input_audio_buffer.speech_stopped' # 检测到用户停止说话 + + # 缓冲区事件 + INPUT_AUDIO_BUFFER_COMPLETED = 'input_audio_buffer.completed' # 语音输入缓冲区提交完成 + INPUT_AUDIO_BUFFER_CLEARED = 'input_audio_buffer.cleared' # 语音输入缓冲区已清除 + + # 其他事件 + SERVER_ERROR = 'error' # 服务端错误 + CONVERSATION_CLEARED = 'conversation.cleared' # 对话上下文已清除 + DUMP_AUDIO = 'dump.audio' # 音频导出 + + +update = { + "id": "event_id_123456", + "event_type": "chat.update", + "data": { + "need_play_prologue": True, + "chat_config": { + "auto_save_history": True, + "user_id": "quecpython_user", + }, + "input_audio": { + "format": "pcm", + "codec": "g711a", + "sample_rate": 8000, + "channel": 1, + "bit_depth": 16 + }, + "output_audio": { + "codec": "g711a", + "pcm_config": { + "sample_rate": 8000, + "frame_size_ms": 100, + "limit_config": { + "period": 1, + "max_frame_num": 11 + }, + }, + "speech_rate": 0, + }, + "turn_detection": { + "type": "server_vad", + "interrupt_config": { + "mode": "keyword_contains", + "keywords": [ + "闭嘴", + "你好扣子" + ] + } + }, + "asr_config":{ + "enable_ddc": True, + "hot_words":[ + "闭嘴", + "你好扣子" + ] + }, + "event_subscriptions": [ + "error", + "conversation.audio_transcript.completed", + "conversation.message.completed", + "conversation.audio.delta", + "conversation.chat.failed", + "conversation.chat.cancelled" + ] + } +} + +interrupt = { + "id": "event_id_123457", + "event_type": "chat.update", + "data": { + "turn_detection": { + "type":"server_vad", + "interrupt_config": { + "mode": "keyword_contains", + "keywords": [ + "闭嘴", + "你好扣子" + ] + } + }, + "asr_config":{ + "hot_words":[ + "闭嘴", + "你好扣子" + ] + } + } +} + + +append = { + "id": "event_id_123458", + "event_type": "input_audio_buffer.append", + "data": { + "delta": "base64EncodedAudioDelta" + } +} + +cancel = { + "id": "event_id_123459", + "event_type": "conversation.chat.cancel" +} + +disconnected = { + "event_type": "client.disconnected" +} diff --git a/src/uwebsocket.py b/src/uwebsocket.py new file mode 100644 index 0000000..a3288f9 --- /dev/null +++ b/src/uwebsocket.py @@ -0,0 +1,186 @@ +import log +import usocket as socket +import ubinascii as binascii +import urandom as random +import log +import ure as re +import ustruct as struct +import urandom as random +import usocket as socket +import websocket +from ucollections import namedtuple +import dataCall + +LOGGER = log.getLogger(__name__) + +URL_RE = re.compile(r'(wss|ws)://([A-Za-z0-9-\.]+)(?:\:([0-9]+))?(/.+)?') +URI = namedtuple('URI', ('protocol', 'hostname', 'port', 'path')) + + +def urlparse(uri): + """Parse ws:// URLs""" + match = URL_RE.match(uri) + if match: + protocol = match.group(1) + host = match.group(2) + port = match.group(3) + path = match.group(4) + + if protocol == 'wss': + if port is None: + port = 443 + elif protocol == 'ws': + if port is None: + port = 80 + else: + raise ValueError('Scheme {} is invalid'.format(protocol)) + + return URI(protocol, host, int(port), path) + + +class NoDataException(Exception): + pass + + +class ConnectionClosed(Exception): + pass + + +class Websocket(object): + """ + Basis of the Websocket protocol. + + This can probably be replaced with the C-based websocket module, but + this one currently supports more options. + """ + is_client = False + + def __init__(self, sock, debug=False): + self.sock = sock + self.ws = websocket.websocket(sock) + self.open = True + self.debug = debug + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + def settimeout(self, timeout): + self.sock.settimeout(timeout) + + def read_frame(self, sz): + return self.ws.read(sz) + + def write_frame(self, data=b''): + return self.ws.write(data) + + def recv(self, max_size=2048): + """ + Receive data from the websocket. + + This is slightly different from 'websockets' in that it doesn't + fire off a routine to process frames and put the data in a queue. + If you don't call recv() sufficiently often you won't process control + frames. + """ + assert self.open + + return self.read_frame(max_size) + + def send(self, buf): + """Send data to the websocket.""" + + assert self.open + + if isinstance(buf, str): + buf = buf.encode('utf-8') + self.ws.ioctl(9, 1) + elif isinstance(buf, bytes): + self.ws.ioctl(9, 2) + else: + raise TypeError() + return self.write_frame(buf) + + def close(self): + """Close the websocket.""" + if not self.open: + return + + try: + self.ws.ioctl(4) + except Exception as e: + if self.debug: LOGGER.info("websocekt close:%s"%(str(e))) + self._close() + + def _close(self): + if self.debug: LOGGER.info("Connection closed") + self.open = False + self.sock.close() + + +class WebsocketClient(Websocket): + is_client = True + + +class Client(object): + + @staticmethod + def connect(uri, headers=None, debug=False): + """ + Connect a websocket. + :param uri: example ws://172.16.185.123/ + :param headers: k, v of header + :param debug: allow output log + :return: + """ + if not headers: + headers = dict() + if not isinstance(headers, dict): + raise Exception("headers must be dict type but {} you given.".format(type(headers))) + + uri = urlparse(uri) + assert uri + + if debug: LOGGER.info("open connection %s:%s", + uri.hostname, uri.port) + + sock = socket.socket() + addr = socket.getaddrinfo(uri.hostname, uri.port, socket.AF_INET) + sock.connect(addr[0][4]) + + if uri.protocol == 'wss': + import ussl + sock = ussl.wrap_socket(sock) + + def send_header(header, *args): + if debug: LOGGER.info(str(header), *args) + sock.write(header % args + '\r\n') + + # Sec-WebSocket-Key is 16 bytes of random base64 encoded + key = binascii.b2a_base64(bytes(random.getrandbits(8) for _ in range(16)))[:-1] + send_header(b'GET %s HTTP/1.1', uri.path or '/') + send_header(b'Host: %s:%s', 'ws.coze.cn', uri.port) + send_header(b'Connection: Upgrade') + send_header(b'Upgrade: websocket') + send_header(b'Sec-WebSocket-Key: %s', key) + send_header(b'Sec-WebSocket-Version: 13') + send_header(b'Origin: http://{hostname}:{port}'.format( + hostname=uri.hostname, + port=uri.port) + ) + for k, v in headers.items(): + send_header('{}:{}'.format(k, v).encode()) + send_header(b'') + + header = sock.readline()[:-2] + assert header.startswith(b'HTTP/1.1 101 '), header + + # We don't (currently) need these headers + # FIXME: should we check the return key? + while header: + if debug: LOGGER.info(str(header)) + header = sock.readline()[:-2] + + return WebsocketClient(sock, debug)