nospaces-recheck-patch

This commit is contained in:
Untone 2024-09-28 10:14:38 +03:00
parent 32ce2e17c4
commit a9e37b7a4e
2 changed files with 14 additions and 6 deletions

View File

@@ -1,6 +1,5 @@
import aiohttp import aiohttp
import aiofiles import aiofiles
import os
import json import json
from urllib.parse import urlencode from urllib.parse import urlencode
from bot.config import BOT_TOKEN from bot.config import BOT_TOKEN
@@ -36,11 +35,10 @@ async def telegram_api(endpoint: str, json_data=None, **kwargs):
async def download_file(file_id): async def download_file(file_id):
"""Asynchronously download a file from Telegram and yield the temporary file path.""" """Asynchronously download a file from Telegram and yield the temporary file path."""
download_url = f"{api_base}/{file_path}"
# Get the file path of the file using the telegram_api method # Get the file path of the file using the telegram_api method
file = await telegram_api("getFile", file_id=file_id) file = await telegram_api("getFile", file_id=file_id)
file_path = file["result"]["file_path"] file_path = file["result"]["file_path"]
download_url = f"{api_base}/{file_path}"
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(download_url) as response: async with session.get(download_url) as response:

View File

@@ -2,7 +2,7 @@ import logging
import math import math
from state.redis import redis from state.redis import redis
from state.scan import get_average_pattern from state.scan import get_average_pattern
from bot.api import telegram_api from bot.api import telegram_api, download_file
from bot.config import FEEDBACK_CHAT_ID from bot.config import FEEDBACK_CHAT_ID
from handlers.handle_private import handle_private from handlers.handle_private import handle_private
from nlp.toxicity_detector import detector from nlp.toxicity_detector import detector
@@ -17,7 +17,7 @@ async def messages_routing(msg, state):
uid = msg["from"]["id"] uid = msg["from"]["id"]
text = msg.get("caption", msg.get("text", "")) text = msg.get("caption", msg.get("text", ""))
for photo in msg.get("photo", []) for photo in msg.get("photo", []):
file_id = photo.get("file_id") file_id = photo.get("file_id")
if file_id: if file_id:
async for temp_file_path in download_file(file_id): async for temp_file_path in download_file(file_id):
@@ -106,8 +106,18 @@ async def messages_routing(msg, state):
except Exception: except Exception:
pass pass
else: else:
toxic_score = detector(normalize(text)) normalized_text = normalize(text)
toxic_score = detector(normalized_text)
toxic_perc = math.floor(toxic_score * 100) toxic_perc = math.floor(toxic_score * 100)
if toxic_perc > 49:
logger.info('re-check this one...')
nospaces_text = text.replace(' ', '')
nospaces_text_score = detector(nospaces_text)
logger.info(f'no spaces text toxic: {nospaces_text_score}')
if nospaces_text_score > toxic_score:
toxic_score = nospaces_text + 10
await redis.set(f"toxic:{cid}", mid) await redis.set(f"toxic:{cid}", mid)
await redis.set(f"toxic:{cid}:{uid}:{mid}", toxic_perc, ex=60 * 60 * 24 * 3) await redis.set(f"toxic:{cid}:{uid}:{mid}", toxic_perc, ex=60 * 60 * 24 * 3)
logger.info(f"\ntext: {text}\ntoxic: {toxic_perc}%") logger.info(f"\ntext: {text}\ntoxic: {toxic_perc}%")