fix-percentage
parent 878da549e0
commit 4a025a5595
@@ -1,13 +1,30 @@
 from transformers import BertTokenizer, BertForSequenceClassification
 import torch
+import torch.nn.functional as F

-# load tokenizer and model weights
+# Load tokenizer and model weights
 tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
 model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')

 def detector(text):
-    # prepare the input
+    # Prepare the input
     batch = tokenizer.encode(text, return_tensors='pt')

-    # inference
-    model(batch)
+    # Inference
+    with torch.no_grad():
+        result = model(batch)
+
+    # Get logits
+    logits = result.logits
+
+    # Convert logits to probabilities using softmax
+    probabilities = F.softmax(logits, dim=1)
+
+    return probabilities[0][1].item()
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) > 1:
+        p = detector(sys.argv[1])
+        toxicity_percentage = p * 100  # Assuming index 1 is for toxic class
+        print(f"Toxicity Probability: {toxicity_percentage:.2f}%")
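For reference, the heart of this fix is the logits-to-percentage conversion. Below is a minimal standalone sketch of that arithmetic; the logit values are invented purely to illustrate the math, and treating index 1 as the toxic class is the same assumption the script itself notes in its comment.

import torch
import torch.nn.functional as F

# Toy two-class logits; these values are illustrative, not real model output.
logits = torch.tensor([[1.0, 3.0]])

# Softmax over dim=1 turns raw logits into probabilities that sum to 1 per row.
probabilities = F.softmax(logits, dim=1)   # tensor([[0.1192, 0.8808]])

# Same indexing and formatting as the updated detector()/__main__ code,
# assuming index 1 is the toxic class.
toxicity_percentage = probabilities[0][1].item() * 100
print(f"Toxicity Probability: {toxicity_percentage:.2f}%")  # Toxicity Probability: 88.08%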
@@ -1,4 +1,4 @@
 aiohttp
 redis[hiredis]
-tensorflow
+torch
 transformers
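And a quick way to exercise the change end to end. This is a hypothetical usage sketch: the module name "detector" is an assumption (file names are not visible in this extracted diff), and it requires the torch and transformers entries from the dependency hunk above to be installed. Importing the module runs its top-level from_pretrained() calls, which download the model weights on first use.

# Hypothetical usage; the module name "detector" is an assumption since the
# changed file's name is not shown in this diff.
from detector import detector

p = detector("пример текста")  # Russian for "example text"; the classifier targets Russian input
print(f"Toxicity Probability: {p * 100:.2f}%")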