fix-percentage
This commit is contained in:
parent
878da549e0
commit
4a025a5595
|
@ -1,13 +1,30 @@
|
||||||
from transformers import BertTokenizer, BertForSequenceClassification
|
from transformers import BertTokenizer, BertForSequenceClassification
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
# load tokenizer and model weights
|
# Load tokenizer and model weights
|
||||||
tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
||||||
model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
||||||
|
|
||||||
|
|
||||||
def detector(text):
|
def detector(text):
|
||||||
# prepare the input
|
# Prepare the input
|
||||||
batch = tokenizer.encode(text, return_tensors='pt')
|
batch = tokenizer.encode(text, return_tensors='pt')
|
||||||
|
|
||||||
# inference
|
# Inference
|
||||||
model(batch)
|
with torch.no_grad():
|
||||||
|
result = model(batch)
|
||||||
|
|
||||||
|
# Get logits
|
||||||
|
logits = result.logits
|
||||||
|
|
||||||
|
# Convert logits to probabilities using softmax
|
||||||
|
probabilities = F.softmax(logits, dim=1)
|
||||||
|
|
||||||
|
return probabilities[0][1].item()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
p = detector(sys.argv[1])
|
||||||
|
toxicity_percentage = p * 100 # Assuming index 1 is for toxic class
|
||||||
|
print(f"Toxicity Probability: {toxicity_percentage:.2f}%")
|
|
@ -1,4 +1,4 @@
|
||||||
aiohttp
|
aiohttp
|
||||||
redis[hiredis]
|
redis[hiredis]
|
||||||
tensorflow
|
torch
|
||||||
transformers
|
transformers
|
Loading…
Reference in New Issue
Block a user