.
This commit is contained in:
@@ -3,8 +3,7 @@ import math
|
|||||||
|
|
||||||
from bot.api import telegram_api
|
from bot.api import telegram_api
|
||||||
from bot.config import FEEDBACK_CHAT_ID
|
from bot.config import FEEDBACK_CHAT_ID
|
||||||
from nlp.toxicity import text2toxicity
|
from nlp.toxicity_detector import detector
|
||||||
from nlp.replying import get_toxic_reply
|
|
||||||
from handlers.handle_private import handle_private
|
from handlers.handle_private import handle_private
|
||||||
|
|
||||||
logger = logging.getLogger('handlers.messages_routing')
|
logger = logging.getLogger('handlers.messages_routing')
|
||||||
@@ -31,18 +30,26 @@ async def messages_routing(msg, state):
|
|||||||
if reply_chat_id != FEEDBACK_CHAT_ID:
|
if reply_chat_id != FEEDBACK_CHAT_ID:
|
||||||
await telegram_api("sendMessage", chat_id=reply_chat_id, text=text, reply_to=reply_msg.get("message_id"))
|
await telegram_api("sendMessage", chat_id=reply_chat_id, text=text, reply_to=reply_msg.get("message_id"))
|
||||||
|
|
||||||
|
# TODO: implement text2toxicity with https://huggingface.co/s-nlp/russian_toxicity_classifier
|
||||||
elif bool(text):
|
elif bool(text):
|
||||||
toxic_score = text2toxicity(text)
|
mid = msg.get("message_id")
|
||||||
|
non_toxic_score, toxic_score = detector(text)
|
||||||
logger.info(f'\ntext: {text}\ntoxic: {math.floor(toxic_score*100)}%')
|
logger.info(f'\ntext: {text}\ntoxic: {math.floor(toxic_score*100)}%')
|
||||||
if toxic_score > 0.71:
|
if toxic_score > 0.71:
|
||||||
toxic_reply = get_toxic_reply(toxic_score)
|
if toxic_score > 0.85:
|
||||||
await telegram_api(
|
await telegram_api(
|
||||||
"setMessageReaction",
|
"deletemessage",
|
||||||
chat_id=cid,
|
chat_id=cid,
|
||||||
is_big=True,
|
message_id=mid
|
||||||
message_id=msg.get("message_id"),
|
)
|
||||||
reaction=f'[{{"type":"emoji", "emoji":"{toxic_reply}"}}]'
|
else:
|
||||||
)
|
await telegram_api(
|
||||||
|
"setMessageReaction",
|
||||||
|
chat_id=cid,
|
||||||
|
is_big=True,
|
||||||
|
message_id=mid,
|
||||||
|
reaction=f'[{{"type":"emoji", "emoji":"🙉"}}]'
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
16
main.py
16
main.py
@@ -1,22 +1,24 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from bot.api import telegram_api
|
|
||||||
from bot.config import FEEDBACK_CHAT_ID
|
|
||||||
from handlers.handle_join_request import handle_join_request, handle_reaction_on_request
|
|
||||||
from handlers.messages_routing import messages_routing
|
from handlers.messages_routing import messages_routing
|
||||||
|
from handlers.handle_join_request import handle_join_request, handle_reaction_on_request
|
||||||
|
from bot.config import BOT_TOKEN, FEEDBACK_CHAT_ID
|
||||||
|
from bot.api import telegram_api
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logger = logging.getLogger('main')
|
logger = logging.getLogger(__name__)
|
||||||
state = dict()
|
state = dict()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def start():
|
||||||
logger.info("\tstarted")
|
logger.info("\tstarted")
|
||||||
async with ClientSession() as session:
|
async with ClientSession() as session:
|
||||||
offset = 0 # начальное значение offset
|
offset = 0 # начальное значение offset
|
||||||
while True:
|
while True:
|
||||||
response = await telegram_api("getUpdates", offset=offset, allowed_updates=['message', 'edited_message', 'message_reaction','chat_join_request', 'chat_member'])
|
response = await telegram_api("getUpdates", offset=offset, allowed_updates=['message', 'message_reaction'])
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
result = response.get("result", [])
|
result = response.get("result", [])
|
||||||
for update in result:
|
for update in result:
|
||||||
@@ -47,4 +49,4 @@ async def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Запуск асинхронного цикла
|
# Запуск асинхронного цикла
|
||||||
asyncio.run(main())
|
asyncio.run(start())
|
||||||
|
13
nlp/toxycity_detector.py
Normal file
13
nlp/toxycity_detector.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from transformers import BertTokenizer, BertForSequenceClassification
|
||||||
|
|
||||||
|
# load tokenizer and model weights
|
||||||
|
tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
||||||
|
model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
||||||
|
|
||||||
|
|
||||||
|
def detector(text):
|
||||||
|
# prepare the input
|
||||||
|
batch = tokenizer.encode(text, return_tensors='pt')
|
||||||
|
|
||||||
|
# inference
|
||||||
|
model(batch)
|
@@ -1,2 +1,4 @@
|
|||||||
aiohttp
|
aiohttp
|
||||||
redis[hiredis]
|
redis[hiredis]
|
||||||
|
tensorflow
|
||||||
|
transformers
|
Reference in New Issue
Block a user