ai test.py, sqllite removed
This commit is contained in:
parent
e56b083b7f
commit
cbb64af17f
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -58,8 +58,6 @@ coverage.xml
|
|||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
|
@ -141,7 +139,6 @@ migration/content/**/*.md
|
|||
.obsidian
|
||||
|
||||
*.zip
|
||||
*.sqlite3
|
||||
*.rdb
|
||||
.DS_Store
|
||||
/dump
|
||||
|
|
|
@ -5,5 +5,5 @@ ADD nginx.conf.sigil ./
|
|||
RUN /usr/local/bin/python -m pip install --upgrade pip
|
||||
WORKDIR /usr/src/app
|
||||
COPY requirements.txt ./
|
||||
RUN set -ex && pip install -r requirements.txt
|
||||
RUN pip install -r requirements.txt
|
||||
COPY . .
|
||||
|
|
75
ai/preprocess.py
Normal file
75
ai/preprocess.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import re
|
||||
import nltk
|
||||
from bs4 import BeautifulSoup
|
||||
from nltk.corpus import stopwords
|
||||
from pymystem3 import Mystem
|
||||
from string import punctuation
|
||||
from transformers import BertTokenizer
|
||||
|
||||
nltk.download("stopwords")
|
||||
|
||||
|
||||
def get_clear_text(text):
|
||||
soup = BeautifulSoup(text, 'html.parser')
|
||||
|
||||
# extract the plain text from the HTML document without tags
|
||||
clear_text = ''
|
||||
for tag in soup.find_all():
|
||||
clear_text += tag.string or ''
|
||||
|
||||
clear_text = re.sub(pattern='[\u202F\u00A0\n]+', repl=' ', string=clear_text)
|
||||
|
||||
# only words
|
||||
clear_text = re.sub(pattern='[^A-ZА-ЯЁ -]', repl='', string=clear_text, flags=re.IGNORECASE)
|
||||
|
||||
clear_text = re.sub(pattern='\s+', repl=' ', string=clear_text)
|
||||
|
||||
clear_text = clear_text.lower()
|
||||
|
||||
mystem = Mystem()
|
||||
russian_stopwords = stopwords.words("russian")
|
||||
|
||||
tokens = mystem.lemmatize(clear_text)
|
||||
tokens = [token for token in tokens if token not in russian_stopwords \
|
||||
and token != " " \
|
||||
and token.strip() not in punctuation]
|
||||
|
||||
clear_text = " ".join(tokens)
|
||||
|
||||
return clear_text
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
#
|
||||
# # initialize the tokenizer with the pre-trained BERT model and vocabulary
|
||||
# tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
|
||||
#
|
||||
# # split each text into smaller segments of maximum length 512
|
||||
# max_length = 512
|
||||
# segmented_texts = []
|
||||
# for text in [clear_text1, clear_text2]:
|
||||
# segmented_text = []
|
||||
# for i in range(0, len(text), max_length):
|
||||
# segment = text[i:i+max_length]
|
||||
# segmented_text.append(segment)
|
||||
# segmented_texts.append(segmented_text)
|
||||
#
|
||||
# # tokenize each segment using the BERT tokenizer
|
||||
# tokenized_texts = []
|
||||
# for segmented_text in segmented_texts:
|
||||
# tokenized_text = []
|
||||
# for segment in segmented_text:
|
||||
# segment_tokens = tokenizer.tokenize(segment)
|
||||
# segment_tokens = ['[CLS]'] + segment_tokens + ['[SEP]']
|
||||
# tokenized_text.append(segment_tokens)
|
||||
# tokenized_texts.append(tokenized_text)
|
||||
#
|
||||
# input_ids = []
|
||||
# for tokenized_text in tokenized_texts:
|
||||
# input_id = []
|
||||
# for segment_tokens in tokenized_text:
|
||||
# segment_id = tokenizer.convert_tokens_to_ids(segment_tokens)
|
||||
# input_id.append(segment_id)
|
||||
# input_ids.append(input_id)
|
||||
#
|
||||
# print(input_ids)
|
|
@ -7,12 +7,9 @@ from sqlalchemy.sql.schema import Table
|
|||
|
||||
from settings import DB_URL
|
||||
|
||||
if DB_URL.startswith("sqlite"):
|
||||
engine = create_engine(DB_URL)
|
||||
else:
|
||||
engine = create_engine(
|
||||
DB_URL, echo=False, pool_size=10, max_overflow=20
|
||||
)
|
||||
engine = create_engine(
|
||||
DB_URL, echo=False, pool_size=10, max_overflow=20
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ pyjwt>=2.6.0
|
|||
starlette~=0.23.1
|
||||
sqlalchemy>=1.4.41
|
||||
graphql-core>=3.0.3
|
||||
gql
|
||||
gql~=3.4.0
|
||||
uvicorn>=0.18.3
|
||||
pydantic>=1.10.2
|
||||
passlib~=1.7.4
|
||||
|
@ -29,3 +29,6 @@ lxml
|
|||
sentry-sdk>=1.14.0
|
||||
# sse_starlette
|
||||
graphql-ws
|
||||
nltk~=3.8.1
|
||||
pymystem3~=0.2.0
|
||||
transformers~=4.28.1
|
||||
|
|
|
@ -4,7 +4,7 @@ PORT = 8080
|
|||
|
||||
DB_URL = (
|
||||
environ.get("DATABASE_URL") or environ.get("DB_URL") or
|
||||
"postgresql://postgres@localhost:5432/discoursio" or "sqlite:///db.sqlite3"
|
||||
"postgresql://postgres@localhost:5432/discoursio"
|
||||
)
|
||||
JWT_ALGORITHM = "HS256"
|
||||
JWT_SECRET_KEY = environ.get("JWT_SECRET_KEY") or "8f1bd7696ffb482d8486dfbc6e7d16dd-secret-key"
|
||||
|
|
11
test.py
Normal file
11
test.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from sqlalchemy import select
|
||||
from ai.preprocess import get_clear_text
|
||||
from base.orm import local_session
|
||||
from orm import Shout
|
||||
|
||||
if __name__ == "__main__":
|
||||
with local_session() as session:
|
||||
q = select(Shout)
|
||||
for [shout] in session.execute(q):
|
||||
clear_shout_body = get_clear_text(shout.body)
|
||||
print(clear_shout_body)
|
Loading…
Reference in New Issue
Block a user