# Source: yanting/report-notebooklm-api/tests/test_public_api.py
# (118 lines, 4.6 KiB, Python — metadata from the repository file viewer)
"""Integration tests for the public report-notebooklm API (``/api/report-notebooklm/v1``).

Requests are sent to the ASGI ``app`` in-process through an httpx client, and
the database is rebuilt and re-seeded before every test.
"""
from __future__ import annotations

import os

# The environment is configured BEFORE importing anything from ``app`` or
# ``scripts``. NOTE(review): this ordering suggests the app reads its settings
# at import time — confirm in the app's config module.
# File-backed SQLite database (async driver aiosqlite); the relative path is
# resolved against the current working directory of the test run.
os.environ["RNB_DATABASE_URL"] = "sqlite+aiosqlite:///./test_seed.db"
# ``.invalid`` is a reserved TLD (RFC 2606) that never resolves — presumably so
# any accidental Redis access fails rather than reaching a real server.
os.environ["RNB_REDIS_URL"] = "redis://test-redis.invalid/0"

import pytest  # noqa: E402
from httpx import ASGITransport, AsyncClient  # noqa: E402
from sqlalchemy import select  # noqa: E402

from app.db import Base, SessionLocal, engine  # noqa: E402
from app.main import app  # noqa: E402
from app.models import AudioAsset, DisplayModule, Institution, Report  # noqa: E402
from scripts.import_seed_content import import_seed  # noqa: E402

# Base path of the public API under test; every request URL starts with it.
PREFIX = "/api/report-notebooklm/v1"
@pytest.fixture(autouse=True)
async def seeded_db():
    """Rebuild the schema and load the seed content before every test.

    Drops and recreates every table registered on ``Base.metadata``, then runs
    ``import_seed`` in a fresh session so each test starts from identical data.

    After the test, the shared async engine is disposed. NOTE(review): this
    assumes the async test plugin (async fixtures on plain ``pytest.fixture``
    suggest pytest-asyncio in auto mode — confirm) gives each test its own
    event loop. SQLAlchemy's ``AsyncEngine`` must not reuse pooled connections
    across event loops, so the pool is emptied between tests.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)
    async with SessionLocal() as session:
        await import_seed(session)
    yield
    # Release pooled connections tied to this test's event loop so the next
    # test opens fresh ones on its own loop.
    await engine.dispose()
@pytest.fixture
async def client():
    """Yield an httpx ``AsyncClient`` that dispatches requests to ``app`` in-process.

    ``ASGITransport`` calls the ASGI application directly, so no network socket
    is opened; ``http://test`` is only a placeholder base URL.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client
async def test_seed_counts_match_phase1_shape():
    """The seed import yields the expected phase-1 row count for each model."""
    async with SessionLocal() as session:

        async def row_count(model) -> int:
            # Load every row of ``model`` and count them.
            result = await session.execute(select(model))
            return len(result.scalars().all())

        assert await row_count(Institution) == 18
        assert await row_count(Report) == 27
        assert await row_count(AudioAsset) == 15
        # Display modules are only bounded from below.
        assert await row_count(DisplayModule) >= 120
async def test_health_and_recommended_feed(client: AsyncClient):
    """Health check answers ok, and the recommended feed leads with the sample report."""
    health_response = await client.get(f"{PREFIX}/health")
    assert health_response.status_code == 200
    assert health_response.json() == {"status": "ok"}

    feed_response = await client.get(f"{PREFIX}/feed/recommended")
    assert feed_response.status_code == 200
    items = feed_response.json()["items"]
    assert items
    top_item = items[0]
    assert top_item["report_id"] == "rep_bis_notebooklm_sample"
    # ``display_version`` must be absent from the payload; ``cache_version`` is exposed.
    assert "display_version" not in top_item
    assert top_item["cache_version"].startswith("rep_")
async def test_report_detail_hides_internal_fields_and_review_modules(client: AsyncClient):
    """The report detail payload hides internal fields and non-public modules.

    Checks that:
    - ``display_version`` is absent.
    - The ``institution``, ``faq`` and ``infographic`` module types (the
      "review" modules in the test name) are not returned.
    - Every module has a detail page, and ``source_compliance`` is last.
    - ``key_data`` is delivered as a card: preview only, no inline content,
      plus a reference to its full content.
    """
    response = await client.get(f"{PREFIX}/reports/rep_ssga_gold")
    assert response.status_code == 200
    body = response.json()
    assert body["report_id"] == "rep_ssga_gold"
    assert "display_version" not in body

    modules = body["modules"]
    module_types = [module["type"] for module in modules]
    assert "study_guide" in module_types
    assert "institution" not in module_types
    assert "faq" not in module_types
    assert "infographic" not in module_types
    assert all(module["has_detail_page"] for module in modules)
    assert module_types[-1] == "source_compliance"

    # A bare next() inside a coroutine turns a missing module into an opaque
    # "RuntimeError: coroutine raised StopIteration"; fail with a clear message.
    key_data = next((module for module in modules if module["type"] == "key_data"), None)
    assert key_data is not None, f"no key_data module among {module_types}"
    assert key_data["render_mode"] == "card_plus_page"
    assert key_data["content"] is None
    assert key_data["preview"]["row_count"] == 10
    assert key_data["content_ref"].startswith("rnb/modules/")
async def test_module_endpoint_returns_full_content(client: AsyncClient):
    """The per-module endpoint returns the full content that the detail card omits.

    Looks up the ``key_data`` module id from the report detail, then fetches that
    module and checks for its rows and the report's cache version.
    """
    detail_response = await client.get(f"{PREFIX}/reports/rep_ssga_gold")
    assert detail_response.status_code == 200
    detail = detail_response.json()

    # Default of None so a missing module fails as an assertion, not as a
    # "coroutine raised StopIteration" RuntimeError.
    key_data = next(
        (module for module in detail["modules"] if module["type"] == "key_data"),
        None,
    )
    assert key_data is not None, "rep_ssga_gold has no key_data module"
    module_id = key_data["module_id"]

    response = await client.get(f"{PREFIX}/reports/rep_ssga_gold/modules/{module_id}")
    assert response.status_code == 200
    body = response.json()
    assert body["module_id"] == module_id
    assert "rows" in body["content"]
    assert body["cache_version"] == "rep_ssga_gold:v1"
async def test_boundary_reports(client: AsyncClient):
    """Edge-case reports: excluded from listen, hidden (404), or gray-listed.

    - ``rep_ing_gold`` and ``rep_pas_silver`` must not appear in the listen feed.
    - ``rep_wisdomtree_outlook`` must not be reachable at all.
    - ``rep_pas_silver`` is served, but its source/compliance module exposes no
      source URL and carries a note containing the gray-release marker.
    """
    listen_response = await client.get(f"{PREFIX}/listen")
    assert listen_response.status_code == 200
    listen_report_ids = {item["report_id"] for item in listen_response.json()["items"]}
    assert "rep_ing_gold" not in listen_report_ids
    assert "rep_pas_silver" not in listen_report_ids

    hidden = await client.get(f"{PREFIX}/reports/rep_wisdomtree_outlook")
    assert hidden.status_code == 404

    gray_response = await client.get(f"{PREFIX}/reports/rep_pas_silver")
    assert gray_response.status_code == 200
    gray = gray_response.json()
    # Default of None so a missing module fails as an assertion, not as a
    # "coroutine raised StopIteration" RuntimeError.
    compliance = next(
        (module for module in gray["modules"] if module["type"] == "source_compliance"),
        None,
    )
    assert compliance is not None, "rep_pas_silver has no source_compliance module"
    assert compliance["content"]["source_url"] is None
    # The note must contain the Chinese term for "gray release".
    assert "灰度" in compliance["content"]["source_note"]
async def test_institutions_and_listen(client: AsyncClient):
    """Institution list/detail and the listen feed expose the seeded data.

    Checks all 18 institutions are listed, SSGA's latest report is
    ``rep_ssga_gold``, and the first listen item is an audio asset (``aud_`` id).
    """
    institutions = await client.get(f"{PREFIX}/institutions")
    assert institutions.status_code == 200
    assert len(institutions.json()["items"]) == 18

    inst = await client.get(f"{PREFIX}/institutions/inst_ssga")
    assert inst.status_code == 200
    assert inst.json()["latest_report"]["report_id"] == "rep_ssga_gold"

    listen = await client.get(f"{PREFIX}/listen")
    assert listen.status_code == 200
    listen_items = listen.json()["items"]
    # Guard before indexing so an empty feed fails as an assertion with a
    # message rather than an IndexError.
    assert listen_items, "listen feed is empty"
    assert listen_items[0]["audio_id"].startswith("aud_")