gemma-bot/bot.py

435 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import asyncio
import logging
import json
from io import BytesIO
from typing import Dict, List, Any, Optional
import re
from aiogram import Bot, Dispatcher, types, F
from aiogram.filters import Command
from aiogram.types import Message
from aiogram.enums import ParseMode
from dotenv import load_dotenv
from ollama import AsyncClient
import asyncpg
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton
# Настройка логирования
logging.basicConfig(level=logging.INFO)
# Загрузка переменных окружения
load_dotenv()
# Инициализация бота и диспетчера
bot = Bot(token=os.getenv("BOT_TOKEN"))
dp = Dispatcher()
# Инициализация клиента Ollama
ollama_client = AsyncClient(host=os.getenv("OLLAMA_HOST"))
model_name = os.getenv("OLLAMA_MODEL")
# Соединение с PostgreSQL
postgres_pool: Optional[asyncpg.Pool] = None
# Словарь для хранения контекста диалога для каждого пользователя в памяти (кэш)
user_messages: Dict[int, List[Dict[str, Any]]] = {}
# Словарь для хранения задач обновления статуса печати
typing_tasks: Dict[int, asyncio.Task] = {}
async def init_db():
"""Инициализация базы данных"""
global postgres_pool
# Создаем пул соединений с базой данных
postgres_pool = await asyncpg.create_pool(
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
database=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD")
)
# Создаем таблицу, если она не существует
async with postgres_pool.acquire() as conn:
await conn.execute('''
CREATE TABLE IF NOT EXISTS messages (
id SERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
message_data JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
)
''')
# Создаем индекс для быстрого поиска по user_id
await conn.execute('''
CREATE INDEX IF NOT EXISTS idx_messages_user_id ON messages (user_id)
''')
async def save_message_to_db(user_id: int, message_data: Dict[str, Any]):
"""Сохранение сообщения в базу данных"""
if not postgres_pool:
logging.error("Пул соединений PostgreSQL не инициализирован")
return
# Преобразуем изображения в строку для хранения
if "images" in message_data:
# Обходим проблему хранения бинарных данных - просто удаляем изображения
# Для реального приложения можно реализовать хранение изображений в отдельной таблице
message_data_copy = message_data.copy()
message_data_copy["has_image"] = True
message_data_copy.pop("images", None)
else:
message_data_copy = message_data
try:
async with postgres_pool.acquire() as conn:
await conn.execute(
'''
INSERT INTO messages (user_id, message_data)
VALUES ($1, $2)
''',
user_id, json.dumps(message_data_copy)
)
except Exception as e:
logging.error(f"Ошибка сохранения сообщения в БД: {e}")
async def load_messages_from_db(user_id: int) -> List[Dict[str, Any]]:
"""Загрузка сообщений из базы данных"""
if not postgres_pool:
logging.error("Пул соединений PostgreSQL не инициализирован")
return []
try:
async with postgres_pool.acquire() as conn:
rows = await conn.fetch(
'''
SELECT message_data
FROM messages
WHERE user_id = $1
ORDER BY created_at ASC
''',
user_id
)
# Преобразуем строки JSON обратно в словари
return [json.loads(row['message_data']) for row in rows]
except Exception as e:
logging.error(f"Ошибка загрузки сообщений из БД: {e}")
return []
async def clear_messages_from_db(user_id: int):
"""Очистка сообщений пользователя из базы данных"""
if not postgres_pool:
logging.error("Пул соединений PostgreSQL не инициализирован")
return
try:
async with postgres_pool.acquire() as conn:
await conn.execute(
'''
DELETE FROM messages
WHERE user_id = $1
''',
user_id
)
except Exception as e:
logging.error(f"Ошибка очистки сообщений из БД: {e}")
async def keep_typing(chat_id: int):
"""Функция для поддержания статуса 'печатает...' до окончания генерации ответа"""
try:
while True:
await bot.send_chat_action(chat_id=chat_id, action="typing")
await asyncio.sleep(4) # Обновляем каждые 4 секунды (статус активен 5 секунд)
except asyncio.CancelledError:
# Задача отменена, завершаем работу
pass
except Exception as e:
logging.error(f"Ошибка в keep_typing: {e}")
def escape_markdown(text: str) -> str:
"""Экранирует специальные символы Markdown, если они не являются частью форматирования"""
# Регулярное выражение для поиска уже отформатированного текста
md_patterns = [
r'`[^`]+`', # Код
r'\*\*[^*]+\*\*', # Жирный
r'\*[^*]+\*', # Курсив
r'_[^_]+_', # Курсив (альтернативный)
r'__[^_]+__', # Жирный (альтернативный)
r'\[[^\]]+\]\([^)]+\)', # Ссылки
]
# Собираем все совпадения с шаблонами
matches = []
for pattern in md_patterns:
matches.extend([(m.start(), m.end()) for m in re.finditer(pattern, text)])
# Сортируем по началу совпадения
matches.sort()
# Если нет совпадений, просто экранируем все специальные символы
if not matches:
return re.sub(r'([_*\[\]()~`>#+-=|{}.!\\])', r'\\\1', text)
# Иначе экранируем только те символы, которые не входят в уже отформатированный текст
result = ""
last_pos = 0
for start, end in matches:
# Экранируем символы до начала форматирования
if start > last_pos:
result += re.sub(r'([_*\[\]()~`>#+-=|{}.!\\])', r'\\\1', text[last_pos:start])
# Добавляем отформатированный текст без изменений
result += text[start:end]
last_pos = end
# Экранируем символы после последнего форматирования
if last_pos < len(text):
result += re.sub(r'([_*\[\]()~`>#+-=|{}.!\\])', r'\\\1', text[last_pos:])
return result
@dp.message(Command("start"))
async def cmd_start(message: Message):
"""Обработчик команды /start"""
user_id = message.from_user.id
# Загружаем контекст из БД при первом запуске
messages = await load_messages_from_db(user_id)
if not messages:
# Если сообщений нет, создаем новый список
user_messages[user_id] = []
else:
# Если есть, сохраняем их в кэше
user_messages[user_id] = messages
welcome_message = (
"Привет! Я бот на базе Gemma. Отправьте мне сообщение или фото, и я отвечу вам.\n\n"
"Поддерживается *Markdown* форматирование:\n"
"• *жирный текст* (звездочки)\n"
"• _курсив_ (нижнее подчеркивание)\n"
"• `код` (обратные кавычки)\n"
"• [ссылки](https://example.com)\n\n"
"Используйте /clear для очистки контекста диалога."
)
# Создаем инлайн-кнопку для очистки истории
keyboard = InlineKeyboardMarkup(inline_keyboard=[
[InlineKeyboardButton(text="🗑 Очистить историю сообщений", callback_data="clear_history")]
])
await message.answer(welcome_message, parse_mode=ParseMode.MARKDOWN, reply_markup=keyboard)
@dp.message(Command("clear"))
async def cmd_clear(message: Message):
"""Очистка контекста диалога"""
user_id = message.from_user.id
# Очищаем контекст в памяти
user_messages[user_id] = []
# Очищаем контекст в БД
await clear_messages_from_db(user_id)
await message.answer("Контекст диалога очищен.")
@dp.message(F.photo)
async def handle_photo(message: Message):
"""Обработчик фотографий"""
user_id = message.from_user.id
# Инициализация контекста пользователя, если он еще не существует
if user_id not in user_messages:
# Пытаемся загрузить контекст из БД
messages = await load_messages_from_db(user_id)
user_messages[user_id] = messages or []
# Получаем фото наилучшего качества
photo = message.photo[-1]
file_info = await bot.get_file(photo.file_id)
file_content = await bot.download_file(file_info.file_path)
# Конвертируем фото в Base64
image_data = file_content.read()
image_io = BytesIO(image_data)
# Получаем подпись к фото (если есть)
caption = message.caption or "Проанализируй эту фотографию"
# Создаем сообщение с изображением
user_message = {
"role": "user",
"content": caption,
"images": [image_io.getvalue()]
}
# Добавляем сообщение в контекст
user_messages[user_id].append(user_message)
# Сохраняем сообщение в БД
await save_message_to_db(user_id, user_message)
# Запускаем задачу обновления статуса печати
typing_tasks[user_id] = asyncio.create_task(keep_typing(user_id))
try:
# Получаем ответ от Gemma
answer = ""
async for part in await ollama_client.chat(
model=model_name,
messages=user_messages[user_id],
stream=True
):
answer += part['message']['content']
# Останавливаем задачу обновления статуса печати
if user_id in typing_tasks:
typing_tasks[user_id].cancel()
del typing_tasks[user_id]
# Создаем ответ ассистента
assistant_message = {"role": "assistant", "content": answer}
# Добавляем ответ в историю
user_messages[user_id].append(assistant_message)
# Сохраняем ответ в БД
await save_message_to_db(user_id, assistant_message)
# Создаем инлайн-кнопку для очистки истории
keyboard = InlineKeyboardMarkup(inline_keyboard=[
[InlineKeyboardButton(text="🗑 Очистить историю сообщений", callback_data="clear_history")]
])
# Отправляем ответ пользователю с Markdown-форматированием и кнопкой
safe_answer = escape_markdown(answer)
await message.answer(safe_answer, parse_mode=ParseMode.MARKDOWN, reply_markup=keyboard)
except Exception as e:
# Останавливаем задачу обновления статуса печати в случае ошибки
if user_id in typing_tasks:
typing_tasks[user_id].cancel()
del typing_tasks[user_id]
logging.error(f"Ошибка при обработке фото: {e}")
await message.answer(f"Произошла ошибка при обработке фото: {e}")
@dp.message(F.text)
async def handle_text(message: Message):
"""Обработчик текстовых сообщений"""
user_id = message.from_user.id
# Инициализация контекста пользователя, если он еще не существует
if user_id not in user_messages:
# Пытаемся загрузить контекст из БД
messages = await load_messages_from_db(user_id)
user_messages[user_id] = messages or []
# Создаем сообщение пользователя
user_message = {"role": "user", "content": message.text}
# Добавляем сообщение пользователя в контекст
user_messages[user_id].append(user_message)
# Сохраняем сообщение в БД
await save_message_to_db(user_id, user_message)
# Запускаем задачу обновления статуса печати
typing_tasks[user_id] = asyncio.create_task(keep_typing(user_id))
try:
# Получаем ответ от Gemma
answer = ""
async for part in await ollama_client.chat(
model=model_name,
messages=user_messages[user_id],
stream=True
):
answer += part['message']['content']
# Останавливаем задачу обновления статуса печати
if user_id in typing_tasks:
typing_tasks[user_id].cancel()
del typing_tasks[user_id]
# Создаем ответ ассистента
assistant_message = {"role": "assistant", "content": answer}
# Добавляем ответ в историю
user_messages[user_id].append(assistant_message)
# Сохраняем ответ в БД
await save_message_to_db(user_id, assistant_message)
# Создаем инлайн-кнопку для очистки истории
keyboard = InlineKeyboardMarkup(inline_keyboard=[
[InlineKeyboardButton(text="🗑 Очистить историю сообщений", callback_data="clear_history")]
])
# Отправляем ответ пользователю с Markdown-форматированием и кнопкой
safe_answer = escape_markdown(answer)
await message.answer(safe_answer, parse_mode=ParseMode.MARKDOWN, reply_markup=keyboard)
except Exception as e:
# Останавливаем задачу обновления статуса печати в случае ошибки
if user_id in typing_tasks:
typing_tasks[user_id].cancel()
del typing_tasks[user_id]
logging.error(f"Ошибка при обработке сообщения: {e}")
await message.answer(f"Произошла ошибка при обработке сообщения: {e}")
@dp.message()
async def handle_other(message: Message):
"""Обработчик всех остальных типов сообщений"""
await message.answer("Я могу обрабатывать только текст и фотографии. Пожалуйста, отправьте текст или одну фотографию.")
# Добавляем обработчик нажатия на кнопку очистки истории
@dp.callback_query(F.data == "clear_history")
async def clear_history_callback(callback: types.CallbackQuery):
"""Обработчик кнопки очистки истории"""
user_id = callback.from_user.id
# Очищаем контекст в памяти
user_messages[user_id] = []
# Очищаем контекст в БД
await clear_messages_from_db(user_id)
# Уведомляем пользователя
await callback.answer("Контекст диалога очищен!")
# Отправляем сообщение в чат
await callback.message.answer("Контекст диалога очищен. Вы можете начать новый разговор.")
async def main():
"""Запуск бота"""
# Инициализируем подключение к базе данных
await init_db()
try:
# Запускаем бота
await dp.start_polling(bot)
finally:
# Закрываем пул соединений при завершении
if postgres_pool:
await postgres_pool.close()
if __name__ == "__main__":
asyncio.run(main())