diff --git a/resolvers/draft.py b/resolvers/draft.py index 523de3d3..bd7c587b 100644 --- a/resolvers/draft.py +++ b/resolvers/draft.py @@ -1,6 +1,5 @@ import time -from sqlalchemy import select from sqlalchemy.sql import and_ from cache.cache import ( diff --git a/resolvers/editor.py b/resolvers/editor.py index 1175dbe2..c7634b35 100644 --- a/resolvers/editor.py +++ b/resolvers/editor.py @@ -12,16 +12,14 @@ from cache.cache import ( invalidate_shouts_cache, ) from orm.author import Author -from orm.draft import Draft from orm.shout import Shout, ShoutAuthor, ShoutTopic from orm.topic import Topic -from resolvers.draft import create_draft, publish_draft from resolvers.follower import follow, unfollow from resolvers.stat import get_with_stat from services.auth import login_required from services.db import local_session from services.notify import notify_shout -from services.schema import mutation, query +from services.schema import query from services.search import search_service from utils.logger import root_logger as logger @@ -582,6 +580,9 @@ async def update_shout(_, info, shout_id: int, shout_input=None, publish=False): else [] ) + # Add main_topic to the shout dictionary + shout_dict["main_topic"] = get_main_topic_slug(shout_with_relations.topics) + shout_dict["authors"] = ( [ {"id": author.id, "name": author.name, "slug": author.slug} @@ -641,3 +642,19 @@ async def delete_shout(_, info, shout_id: int): return {"error": None} else: return {"error": "access denied"} + + +def get_main_topic_slug(topics): + """Get the slug of the main topic from a list of topics. + + Args: + topics: List of ShoutTopic objects + + Returns: + str: Slug of the main topic, or None if no main topic found + """ + if not topics: + return None + + main_topic = next((t for t in topics.reverse() if t.main), None) + return main_topic.topic.slug if main_topic else { "slug": "notopic", "title": "no topic", "id": 0 } diff --git a/tests/conftest.py b/tests/conftest.py index 7bd7f135..0eec04fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,6 @@ from starlette.testclient import TestClient from main import app from services.db import Base from services.redis import redis -from settings import DB_URL # Use SQLite for testing TEST_DB_URL = "sqlite:///test.db" diff --git a/tests/test_reactions.py b/tests/test_reactions.py index 9b73e001..71622858 100644 --- a/tests/test_reactions.py +++ b/tests/test_reactions.py @@ -3,7 +3,7 @@ from datetime import datetime import pytest from orm.author import Author -from orm.reaction import Reaction, ReactionKind +from orm.reaction import ReactionKind from orm.shout import Shout diff --git a/tests/test_validations.py b/tests/test_validations.py index 39fa7e24..95fd75e4 100644 --- a/tests/test_validations.py +++ b/tests/test_validations.py @@ -6,9 +6,7 @@ from pydantic import ValidationError from auth.validations import ( AuthInput, AuthResponse, - OAuthInput, TokenPayload, - UserLoginInput, UserRegistrationInput, ) diff --git a/utils/logger.py b/utils/logger.py index 3607e84d..b49263d4 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -105,7 +105,7 @@ root_logger.setLevel(logging.DEBUG) root_logger.addHandler(stream) root_logger.addFilter(filter) -ignore_logs = ["_trace", "httpx", "_client", "_trace.atrace", "aiohttp", "_client"] +ignore_logs = ["_trace", "httpx", "_client", "atrace", "aiohttp", "_client"] for lgr in ignore_logs: loggr = logging.getLogger(lgr) loggr.setLevel(logging.INFO)