435 lines
18 KiB
Python
435 lines
18 KiB
Python
import os
|
||
import asyncio
|
||
import logging
|
||
import json
|
||
from io import BytesIO
|
||
from typing import Dict, List, Any, Optional
|
||
import re
|
||
|
||
from aiogram import Bot, Dispatcher, types, F
|
||
from aiogram.filters import Command
|
||
from aiogram.types import Message
|
||
from aiogram.enums import ParseMode
|
||
from dotenv import load_dotenv
|
||
from ollama import AsyncClient
|
||
import asyncpg
|
||
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton
|
||
|
||
# Настройка логирования
|
||
logging.basicConfig(level=logging.INFO)
|
||
|
||
# Загрузка переменных окружения
|
||
load_dotenv()
|
||
|
||
# Инициализация бота и диспетчера
|
||
bot = Bot(token=os.getenv("BOT_TOKEN"))
|
||
dp = Dispatcher()
|
||
|
||
# Инициализация клиента Ollama
|
||
ollama_client = AsyncClient(host=os.getenv("OLLAMA_HOST"))
|
||
model_name = os.getenv("OLLAMA_MODEL")
|
||
|
||
# Соединение с PostgreSQL
|
||
postgres_pool: Optional[asyncpg.Pool] = None
|
||
|
||
# Словарь для хранения контекста диалога для каждого пользователя в памяти (кэш)
|
||
user_messages: Dict[int, List[Dict[str, Any]]] = {}
|
||
# Словарь для хранения задач обновления статуса печати
|
||
typing_tasks: Dict[int, asyncio.Task] = {}
|
||
|
||
|
||
async def init_db():
|
||
"""Инициализация базы данных"""
|
||
global postgres_pool
|
||
|
||
# Создаем пул соединений с базой данных
|
||
postgres_pool = await asyncpg.create_pool(
|
||
host=os.getenv("POSTGRES_HOST"),
|
||
port=os.getenv("POSTGRES_PORT"),
|
||
database=os.getenv("POSTGRES_DB"),
|
||
user=os.getenv("POSTGRES_USER"),
|
||
password=os.getenv("POSTGRES_PASSWORD")
|
||
)
|
||
|
||
# Создаем таблицу, если она не существует
|
||
async with postgres_pool.acquire() as conn:
|
||
await conn.execute('''
|
||
CREATE TABLE IF NOT EXISTS messages (
|
||
id SERIAL PRIMARY KEY,
|
||
user_id BIGINT NOT NULL,
|
||
message_data JSONB NOT NULL,
|
||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||
)
|
||
''')
|
||
|
||
# Создаем индекс для быстрого поиска по user_id
|
||
await conn.execute('''
|
||
CREATE INDEX IF NOT EXISTS idx_messages_user_id ON messages (user_id)
|
||
''')
|
||
|
||
|
||
async def save_message_to_db(user_id: int, message_data: Dict[str, Any]):
|
||
"""Сохранение сообщения в базу данных"""
|
||
if not postgres_pool:
|
||
logging.error("Пул соединений PostgreSQL не инициализирован")
|
||
return
|
||
|
||
# Преобразуем изображения в строку для хранения
|
||
if "images" in message_data:
|
||
# Обходим проблему хранения бинарных данных - просто удаляем изображения
|
||
# Для реального приложения можно реализовать хранение изображений в отдельной таблице
|
||
message_data_copy = message_data.copy()
|
||
message_data_copy["has_image"] = True
|
||
message_data_copy.pop("images", None)
|
||
else:
|
||
message_data_copy = message_data
|
||
|
||
try:
|
||
async with postgres_pool.acquire() as conn:
|
||
await conn.execute(
|
||
'''
|
||
INSERT INTO messages (user_id, message_data)
|
||
VALUES ($1, $2)
|
||
''',
|
||
user_id, json.dumps(message_data_copy)
|
||
)
|
||
except Exception as e:
|
||
logging.error(f"Ошибка сохранения сообщения в БД: {e}")
|
||
|
||
|
||
async def load_messages_from_db(user_id: int) -> List[Dict[str, Any]]:
|
||
"""Загрузка сообщений из базы данных"""
|
||
if not postgres_pool:
|
||
logging.error("Пул соединений PostgreSQL не инициализирован")
|
||
return []
|
||
|
||
try:
|
||
async with postgres_pool.acquire() as conn:
|
||
rows = await conn.fetch(
|
||
'''
|
||
SELECT message_data
|
||
FROM messages
|
||
WHERE user_id = $1
|
||
ORDER BY created_at ASC
|
||
''',
|
||
user_id
|
||
)
|
||
|
||
# Преобразуем строки JSON обратно в словари
|
||
return [json.loads(row['message_data']) for row in rows]
|
||
except Exception as e:
|
||
logging.error(f"Ошибка загрузки сообщений из БД: {e}")
|
||
return []
|
||
|
||
|
||
async def clear_messages_from_db(user_id: int):
|
||
"""Очистка сообщений пользователя из базы данных"""
|
||
if not postgres_pool:
|
||
logging.error("Пул соединений PostgreSQL не инициализирован")
|
||
return
|
||
|
||
try:
|
||
async with postgres_pool.acquire() as conn:
|
||
await conn.execute(
|
||
'''
|
||
DELETE FROM messages
|
||
WHERE user_id = $1
|
||
''',
|
||
user_id
|
||
)
|
||
except Exception as e:
|
||
logging.error(f"Ошибка очистки сообщений из БД: {e}")
|
||
|
||
|
||
async def keep_typing(chat_id: int):
|
||
"""Функция для поддержания статуса 'печатает...' до окончания генерации ответа"""
|
||
try:
|
||
while True:
|
||
await bot.send_chat_action(chat_id=chat_id, action="typing")
|
||
await asyncio.sleep(4) # Обновляем каждые 4 секунды (статус активен 5 секунд)
|
||
except asyncio.CancelledError:
|
||
# Задача отменена, завершаем работу
|
||
pass
|
||
except Exception as e:
|
||
logging.error(f"Ошибка в keep_typing: {e}")
|
||
|
||
|
||
def escape_markdown(text: str) -> str:
|
||
"""Экранирует специальные символы Markdown, если они не являются частью форматирования"""
|
||
# Регулярное выражение для поиска уже отформатированного текста
|
||
md_patterns = [
|
||
r'`[^`]+`', # Код
|
||
r'\*\*[^*]+\*\*', # Жирный
|
||
r'\*[^*]+\*', # Курсив
|
||
r'_[^_]+_', # Курсив (альтернативный)
|
||
r'__[^_]+__', # Жирный (альтернативный)
|
||
r'\[[^\]]+\]\([^)]+\)', # Ссылки
|
||
]
|
||
|
||
# Собираем все совпадения с шаблонами
|
||
matches = []
|
||
for pattern in md_patterns:
|
||
matches.extend([(m.start(), m.end()) for m in re.finditer(pattern, text)])
|
||
|
||
# Сортируем по началу совпадения
|
||
matches.sort()
|
||
|
||
# Если нет совпадений, просто экранируем все специальные символы
|
||
if not matches:
|
||
return re.sub(r'([_*\[\]()~`>#+-=|{}.!\\])', r'\\\1', text)
|
||
|
||
# Иначе экранируем только те символы, которые не входят в уже отформатированный текст
|
||
result = ""
|
||
last_pos = 0
|
||
|
||
for start, end in matches:
|
||
# Экранируем символы до начала форматирования
|
||
if start > last_pos:
|
||
result += re.sub(r'([_*\[\]()~`>#+-=|{}.!\\])', r'\\\1', text[last_pos:start])
|
||
|
||
# Добавляем отформатированный текст без изменений
|
||
result += text[start:end]
|
||
last_pos = end
|
||
|
||
# Экранируем символы после последнего форматирования
|
||
if last_pos < len(text):
|
||
result += re.sub(r'([_*\[\]()~`>#+-=|{}.!\\])', r'\\\1', text[last_pos:])
|
||
|
||
return result
|
||
|
||
|
||
@dp.message(Command("start"))
|
||
async def cmd_start(message: Message):
|
||
"""Обработчик команды /start"""
|
||
user_id = message.from_user.id
|
||
|
||
# Загружаем контекст из БД при первом запуске
|
||
messages = await load_messages_from_db(user_id)
|
||
if not messages:
|
||
# Если сообщений нет, создаем новый список
|
||
user_messages[user_id] = []
|
||
else:
|
||
# Если есть, сохраняем их в кэше
|
||
user_messages[user_id] = messages
|
||
|
||
welcome_message = (
|
||
"Привет! Я бот на базе Gemma. Отправьте мне сообщение или фото, и я отвечу вам.\n\n"
|
||
"Поддерживается *Markdown* форматирование:\n"
|
||
"• *жирный текст* (звездочки)\n"
|
||
"• _курсив_ (нижнее подчеркивание)\n"
|
||
"• `код` (обратные кавычки)\n"
|
||
"• [ссылки](https://example.com)\n\n"
|
||
"Используйте /clear для очистки контекста диалога."
|
||
)
|
||
|
||
# Создаем инлайн-кнопку для очистки истории
|
||
keyboard = InlineKeyboardMarkup(inline_keyboard=[
|
||
[InlineKeyboardButton(text="🗑 Очистить историю сообщений", callback_data="clear_history")]
|
||
])
|
||
|
||
await message.answer(welcome_message, parse_mode=ParseMode.MARKDOWN, reply_markup=keyboard)
|
||
|
||
|
||
@dp.message(Command("clear"))
|
||
async def cmd_clear(message: Message):
|
||
"""Очистка контекста диалога"""
|
||
user_id = message.from_user.id
|
||
|
||
# Очищаем контекст в памяти
|
||
user_messages[user_id] = []
|
||
|
||
# Очищаем контекст в БД
|
||
await clear_messages_from_db(user_id)
|
||
|
||
await message.answer("Контекст диалога очищен.")
|
||
|
||
|
||
@dp.message(F.photo)
|
||
async def handle_photo(message: Message):
|
||
"""Обработчик фотографий"""
|
||
user_id = message.from_user.id
|
||
|
||
# Инициализация контекста пользователя, если он еще не существует
|
||
if user_id not in user_messages:
|
||
# Пытаемся загрузить контекст из БД
|
||
messages = await load_messages_from_db(user_id)
|
||
user_messages[user_id] = messages or []
|
||
|
||
# Получаем фото наилучшего качества
|
||
photo = message.photo[-1]
|
||
file_info = await bot.get_file(photo.file_id)
|
||
file_content = await bot.download_file(file_info.file_path)
|
||
|
||
# Конвертируем фото в Base64
|
||
image_data = file_content.read()
|
||
image_io = BytesIO(image_data)
|
||
|
||
# Получаем подпись к фото (если есть)
|
||
caption = message.caption or "Проанализируй эту фотографию"
|
||
|
||
# Создаем сообщение с изображением
|
||
user_message = {
|
||
"role": "user",
|
||
"content": caption,
|
||
"images": [image_io.getvalue()]
|
||
}
|
||
|
||
# Добавляем сообщение в контекст
|
||
user_messages[user_id].append(user_message)
|
||
|
||
# Сохраняем сообщение в БД
|
||
await save_message_to_db(user_id, user_message)
|
||
|
||
# Запускаем задачу обновления статуса печати
|
||
typing_tasks[user_id] = asyncio.create_task(keep_typing(user_id))
|
||
|
||
try:
|
||
# Получаем ответ от Gemma
|
||
answer = ""
|
||
async for part in await ollama_client.chat(
|
||
model=model_name,
|
||
messages=user_messages[user_id],
|
||
stream=True
|
||
):
|
||
answer += part['message']['content']
|
||
|
||
# Останавливаем задачу обновления статуса печати
|
||
if user_id in typing_tasks:
|
||
typing_tasks[user_id].cancel()
|
||
del typing_tasks[user_id]
|
||
|
||
# Создаем ответ ассистента
|
||
assistant_message = {"role": "assistant", "content": answer}
|
||
|
||
# Добавляем ответ в историю
|
||
user_messages[user_id].append(assistant_message)
|
||
|
||
# Сохраняем ответ в БД
|
||
await save_message_to_db(user_id, assistant_message)
|
||
|
||
# Создаем инлайн-кнопку для очистки истории
|
||
keyboard = InlineKeyboardMarkup(inline_keyboard=[
|
||
[InlineKeyboardButton(text="🗑 Очистить историю сообщений", callback_data="clear_history")]
|
||
])
|
||
|
||
# Отправляем ответ пользователю с Markdown-форматированием и кнопкой
|
||
safe_answer = escape_markdown(answer)
|
||
await message.answer(safe_answer, parse_mode=ParseMode.MARKDOWN, reply_markup=keyboard)
|
||
|
||
except Exception as e:
|
||
# Останавливаем задачу обновления статуса печати в случае ошибки
|
||
if user_id in typing_tasks:
|
||
typing_tasks[user_id].cancel()
|
||
del typing_tasks[user_id]
|
||
|
||
logging.error(f"Ошибка при обработке фото: {e}")
|
||
await message.answer(f"Произошла ошибка при обработке фото: {e}")
|
||
|
||
|
||
@dp.message(F.text)
|
||
async def handle_text(message: Message):
|
||
"""Обработчик текстовых сообщений"""
|
||
user_id = message.from_user.id
|
||
|
||
# Инициализация контекста пользователя, если он еще не существует
|
||
if user_id not in user_messages:
|
||
# Пытаемся загрузить контекст из БД
|
||
messages = await load_messages_from_db(user_id)
|
||
user_messages[user_id] = messages or []
|
||
|
||
# Создаем сообщение пользователя
|
||
user_message = {"role": "user", "content": message.text}
|
||
|
||
# Добавляем сообщение пользователя в контекст
|
||
user_messages[user_id].append(user_message)
|
||
|
||
# Сохраняем сообщение в БД
|
||
await save_message_to_db(user_id, user_message)
|
||
|
||
# Запускаем задачу обновления статуса печати
|
||
typing_tasks[user_id] = asyncio.create_task(keep_typing(user_id))
|
||
|
||
try:
|
||
# Получаем ответ от Gemma
|
||
answer = ""
|
||
async for part in await ollama_client.chat(
|
||
model=model_name,
|
||
messages=user_messages[user_id],
|
||
stream=True
|
||
):
|
||
answer += part['message']['content']
|
||
|
||
# Останавливаем задачу обновления статуса печати
|
||
if user_id in typing_tasks:
|
||
typing_tasks[user_id].cancel()
|
||
del typing_tasks[user_id]
|
||
|
||
# Создаем ответ ассистента
|
||
assistant_message = {"role": "assistant", "content": answer}
|
||
|
||
# Добавляем ответ в историю
|
||
user_messages[user_id].append(assistant_message)
|
||
|
||
# Сохраняем ответ в БД
|
||
await save_message_to_db(user_id, assistant_message)
|
||
|
||
# Создаем инлайн-кнопку для очистки истории
|
||
keyboard = InlineKeyboardMarkup(inline_keyboard=[
|
||
[InlineKeyboardButton(text="🗑 Очистить историю сообщений", callback_data="clear_history")]
|
||
])
|
||
|
||
# Отправляем ответ пользователю с Markdown-форматированием и кнопкой
|
||
safe_answer = escape_markdown(answer)
|
||
await message.answer(safe_answer, parse_mode=ParseMode.MARKDOWN, reply_markup=keyboard)
|
||
|
||
except Exception as e:
|
||
# Останавливаем задачу обновления статуса печати в случае ошибки
|
||
if user_id in typing_tasks:
|
||
typing_tasks[user_id].cancel()
|
||
del typing_tasks[user_id]
|
||
|
||
logging.error(f"Ошибка при обработке сообщения: {e}")
|
||
await message.answer(f"Произошла ошибка при обработке сообщения: {e}")
|
||
|
||
|
||
@dp.message()
|
||
async def handle_other(message: Message):
|
||
"""Обработчик всех остальных типов сообщений"""
|
||
await message.answer("Я могу обрабатывать только текст и фотографии. Пожалуйста, отправьте текст или одну фотографию.")
|
||
|
||
|
||
# Добавляем обработчик нажатия на кнопку очистки истории
|
||
@dp.callback_query(F.data == "clear_history")
|
||
async def clear_history_callback(callback: types.CallbackQuery):
|
||
"""Обработчик кнопки очистки истории"""
|
||
user_id = callback.from_user.id
|
||
|
||
# Очищаем контекст в памяти
|
||
user_messages[user_id] = []
|
||
|
||
# Очищаем контекст в БД
|
||
await clear_messages_from_db(user_id)
|
||
|
||
# Уведомляем пользователя
|
||
await callback.answer("Контекст диалога очищен!")
|
||
|
||
# Отправляем сообщение в чат
|
||
await callback.message.answer("Контекст диалога очищен. Вы можете начать новый разговор.")
|
||
|
||
|
||
async def main():
|
||
"""Запуск бота"""
|
||
# Инициализируем подключение к базе данных
|
||
await init_db()
|
||
|
||
try:
|
||
# Запускаем бота
|
||
await dp.start_polling(bot)
|
||
finally:
|
||
# Закрываем пул соединений при завершении
|
||
if postgres_pool:
|
||
await postgres_pool.close()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main()) |