import logging
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from flask import (
    Flask,
    abort,
    jsonify,
    render_template_string,
    request,
    send_from_directory,
)
from werkzeug.utils import secure_filename

from api_services import ApiService, DEFAULT_LLM_PROMPT_TEMPLATE, check_for_updates
from config_manager import ConfigManager
from version import CURRENT_VERSION

DEFAULT_OUTPUT_DIR_NAME = "output_reports"


def _ensure_directory(path: Path) -> Path:
    """Create *path* (including missing parents) if needed and return it."""
    path.mkdir(parents=True, exist_ok=True)
    return path


def _as_bool(value: Any, default: bool) -> bool:
    """Interpret *value* as a boolean.

    ``None`` yields *default*; real booleans pass through; strings are true
    only when they equal ``1``/``true``/``yes``/``on`` (case-insensitive,
    surrounding whitespace ignored); anything else uses ``bool(value)``.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return bool(value)


def _usage_snapshot(raw: Optional[Dict[str, Any]]) -> Dict[str, int]:
    """Return the prompt/completion token counts from *raw* as integers.

    Missing, ``None`` or otherwise falsy entries count as ``0``.
    """
    raw = raw or {}
    return {
        "prompt_tokens": int(raw.get("prompt_tokens", 0) or 0),
        "completion_tokens": int(raw.get("completion_tokens", 0) or 0),
    }


def _create_run_directory(output_root: Path) -> Tuple[str, Path]:
    """Create a fresh, uniquely named run directory under *output_root*.

    The run id is the current timestamp (``YYYYmmdd-HHMMSS``). Because that
    only has one-second resolution, a numeric suffix (``-1``, ``-2``, ...) is
    appended when the directory already exists, so two runs started in the
    same second never share a directory or a run id.

    Returns:
        ``(run_id, run_dir)`` where ``run_dir`` has just been created.
    """
    base_id = datetime.now().strftime("%Y%m%d-%H%M%S")
    candidate = base_id
    counter = 0
    while True:
        run_dir = output_root / candidate
        try:
            # exist_ok=False turns directory creation into the uniqueness check.
            run_dir.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            counter += 1
            candidate = f"{base_id}-{counter}"
        else:
            return candidate, run_dir


def create_app(config_manager: ConfigManager) -> Flask:
    """Create and configure the Flask web application.

    Args:
        config_manager: Settings store; read through ``get`` and written
            through ``set``/``save``/``update_token_usage`` by the routes
            registered here.

    Returns:
        The configured Flask application. A background update check is
        started before returning (when ``AutoUpdateCheck`` is enabled).
    """
    app = Flask(__name__)
    app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024  # 100 MB payload ceiling

    api_service = ApiService(config_manager)
    # Serialises writes to config_manager made by request and worker threads.
    config_lock = threading.Lock()
    update_state: Dict[str, Optional[str]] = {"latest": None, "checked": None}
    update_lock = threading.Lock()
    # run_id -> progress/result state of a batch; guarded by run_states_lock.
    run_states: Dict[str, Dict[str, Any]] = {}
    run_states_lock = threading.Lock()

    def get_output_root() -> Path:
        """Return the output directory as an absolute path, creating it if missing.

        Uses the ``OutputDirectory`` setting when set, otherwise
        ``DEFAULT_OUTPUT_DIR_NAME``; relative paths are resolved against the
        current working directory.
        """
        configured = config_manager.get("OutputDirectory")
        base_path = Path(configured) if configured else Path(DEFAULT_OUTPUT_DIR_NAME)
        if not base_path.is_absolute():
            base_path = Path.cwd() / base_path
        return _ensure_directory(base_path)

    def relative_to_output(path: Path) -> str:
        """Return *path* relative to the output root, in POSIX form.

        Raises:
            ValueError: If *path* does not resolve to a location inside the
                output root.
        """
        root = get_output_root().resolve()
        resolved = path.resolve()
        try:
            relative = resolved.relative_to(root)
        except ValueError as exc:  # pragma: no cover - safety guard
            raise ValueError("Requested path is outside of the output directory") from exc
        return relative.as_posix()

    def start_update_check(force: bool = False) -> None:
        """Look up the latest version on a daemon thread.

        Does nothing when ``AutoUpdateCheck`` is disabled, or when a check has
        already completed and *force* is false.
        """
        if not _as_bool(config_manager.get("AutoUpdateCheck", True), True):
            return
        with update_lock:
            already_checked = update_state["checked"]
        if already_checked and not force:
            return

        def _worker() -> None:
            latest = check_for_updates(CURRENT_VERSION)
            timestamp = datetime.now().isoformat(timespec="seconds")
            with update_lock:
                update_state["latest"] = latest
                update_state["checked"] = timestamp

        threading.Thread(target=_worker, daemon=True).start()

    def _execute_run(
        run_id: str,
        saved_files: List[Dict[str, Any]],
        topic: str,
        run_dir: Path,
        max_workers: int,
        save_markdown: bool,
    ) -> None:
        """Process every uploaded file of one run and record progress.

        Files are processed on a thread pool of *max_workers* threads. After
        each file finishes, the shared ``run_states[run_id]`` entry is updated
        under ``run_states_lock`` so the status endpoint can report progress.
        The final status is ``empty``, ``ok``, ``partial`` or ``failed``.
        """
        aggregate = {"vlm_in": 0, "vlm_out": 0, "llm_in": 0, "llm_out": 0}
        failures = 0

        with run_states_lock:
            state = run_states.get(run_id)
            if state:
                state["status"] = "running"

        def process_single(file_info: Dict[str, Any]) -> Dict[str, Any]:
            """Process one saved upload and return its result record.

            Processing errors are caught, logged, and reported through the
            ``error`` and ``logs`` fields rather than raised.
            """
            saved_path: Path = file_info["path"]
            logs: List[str] = [f"开始处理: {file_info['original']}"]
            markdown_path: Optional[Path] = None
            html_path: Optional[Path] = None
            vlm_usage = {"prompt_tokens": 0, "completion_tokens": 0}
            llm_usage = {"prompt_tokens": 0, "completion_tokens": 0}
            error: Optional[str] = None
            rendered_html_path: Optional[str] = None
            report_markdown_path = saved_path.parent / f"{saved_path.stem}_report.md"

            try:
                (
                    final_report,
                    raw_vlm_usage,
                    raw_llm_usage,
                    rendered_html_path,
                ) = api_service.process_essay_image(
                    str(saved_path),
                    topic,
                )
                vlm_usage = _usage_snapshot(raw_vlm_usage)
                llm_usage = _usage_snapshot(raw_llm_usage)

                if save_markdown:
                    markdown_path = report_markdown_path
                    markdown_path.write_text(final_report, encoding="utf-8")
                    logs.append(f"已生成 Markdown: {markdown_path.name}")

                render_html = _as_bool(config_manager.get("RenderMarkdown", True), True)
                if rendered_html_path:
                    html_path = Path(rendered_html_path)
                    logs.append(f"已生成 HTML: {html_path.name}")
                    # Remove a leftover Markdown report when only HTML was requested.
                    if not save_markdown and render_html and report_markdown_path.exists():
                        report_markdown_path.unlink(missing_ok=True)
                        logs.append("已删除 Markdown(仅保留 HTML)")
                elif not save_markdown and report_markdown_path.exists():
                    report_markdown_path.unlink(missing_ok=True)

                with config_lock:
                    config_manager.update_token_usage(
                        vlm_usage["prompt_tokens"],
                        vlm_usage["completion_tokens"],
                        llm_usage["prompt_tokens"],
                        llm_usage["completion_tokens"],
                    )
                    config_manager.save()
            except Exception as exc:  # pylint: disable=broad-except
                logging.exception("文件处理失败: %s", saved_path)
                error = str(exc)
                logs.append(f"处理失败: {error}")

            saved_rel = relative_to_output(saved_path)
            markdown_rel = relative_to_output(markdown_path) if markdown_path else None
            # Keyed on rendered_html_path (not html_path) so an HTML file that was
            # produced before a later step failed is still reported.
            html_rel = (
                relative_to_output(Path(rendered_html_path)) if rendered_html_path else None
            )

            return {
                "index": file_info["index"],
                "original": file_info["original"],
                "saved": saved_rel,
                "markdown": markdown_rel,
                "html": html_rel,
                "vlm_usage": vlm_usage,
                "llm_usage": llm_usage,
                "logs": logs,
                "error": error,
            }

        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(process_single, info) for info in saved_files]
                for future in as_completed(futures):
                    result = future.result()
                    if not result["error"]:
                        aggregate["vlm_in"] += result["vlm_usage"]["prompt_tokens"]
                        aggregate["vlm_out"] += result["vlm_usage"]["completion_tokens"]
                        aggregate["llm_in"] += result["llm_usage"]["prompt_tokens"]
                        aggregate["llm_out"] += result["llm_usage"]["completion_tokens"]
                    else:
                        failures += 1

                    with run_states_lock:
                        state = run_states.get(run_id)
                        if not state:
                            continue
                        state["completed"] = state.get("completed", 0) + 1
                        state.setdefault("results", {})[result["index"]] = result
                        # Store a copy so readers never see later in-place updates.
                        state["aggregate"] = aggregate.copy()
                        if result["error"]:
                            state.setdefault("errors", []).append(
                                {"index": result["index"], "message": result["error"]}
                            )
        except Exception as exc:  # pylint: disable=broad-except
            logging.exception("批处理任务失败: %s", run_id)
            with run_states_lock:
                state = run_states.get(run_id)
                if state:
                    state["status"] = "failed"
                    state["error"] = str(exc)
                    state["aggregate"] = aggregate
                    state["finished_at"] = datetime.now().isoformat(timespec="seconds")
            return

        total = len(saved_files)
        if total == 0:
            status = "empty"
        elif failures == 0:
            status = "ok"
        elif failures == total:
            status = "failed"
        else:
            status = "partial"

        with run_states_lock:
            state = run_states.get(run_id)
            if state:
                state["status"] = status
                state["aggregate"] = aggregate
                state["completed"] = total
                state["finished_at"] = datetime.now().isoformat(timespec="seconds")

    @app.get("/api/config")
    def read_config():
        """Return the current settings, token usage and version info as JSON.

        API keys are never echoed back: the key fields are always empty and
        ``HasVlmApiKey``/``HasLlmApiKey`` report whether a key is stored.
        """
        usage = {
            "vlm_input": int(config_manager.get("UsageVlmInput", 0) or 0),
            "vlm_output": int(config_manager.get("UsageVlmOutput", 0) or 0),
            "llm_input": int(config_manager.get("UsageLlmInput", 0) or 0),
            "llm_output": int(config_manager.get("UsageLlmOutput", 0) or 0),
        }
        with update_lock:
            latest_version = update_state["latest"]
            checked_at = update_state["checked"]

        has_vlm_key = bool(config_manager.get("VlmApiKey"))
        has_llm_key = bool(config_manager.get("LlmApiKey"))

        data = {
            "VlmUrl": config_manager.get("VlmUrl", ""),
            "VlmApiKey": "",
            "HasVlmApiKey": has_vlm_key,
            "VlmModel": config_manager.get("VlmModel", ""),
            "LlmUrl": config_manager.get("LlmUrl", ""),
            "LlmApiKey": "",
            "HasLlmApiKey": has_llm_key,
            "LlmModel": config_manager.get("LlmModel", ""),
            "SensitivityFactor": config_manager.get("SensitivityFactor", "1.0"),
            "MaxWorkers": config_manager.get("MaxWorkers", 4),
            "MaxRetries": config_manager.get("MaxRetries", 3),
            "RetryDelay": config_manager.get("RetryDelay", 5),
            "SaveMarkdown": _as_bool(config_manager.get("SaveMarkdown", True), True),
            "RenderMarkdown": _as_bool(config_manager.get("RenderMarkdown", True), True),
            "AutoUpdateCheck": _as_bool(config_manager.get("AutoUpdateCheck", True), True),
            "LlmPromptTemplate": config_manager.get("LlmPromptTemplate")
            or DEFAULT_LLM_PROMPT_TEMPLATE,
            "OutputDirectory": str(config_manager.get("OutputDirectory", DEFAULT_OUTPUT_DIR_NAME)),
            "Usage": usage,
            "CurrentVersion": CURRENT_VERSION,
            "LatestVersion": latest_version,
            "CheckedAt": checked_at,
        }
        return jsonify(data)

    @app.post("/api/config")
    def update_config():
        """Validate and persist settings sent as a JSON object.

        Only keys present in the payload are changed. Returns HTTP 400 with an
        ``error`` message when the body is not a JSON object or a value fails
        validation.
        """
        payload = request.get_json(silent=True) or {}
        if not isinstance(payload, dict):
            # A JSON array or scalar would otherwise fail on .get() with a 500.
            return jsonify({"error": "请求体必须是 JSON 对象"}), 400

        string_fields = [
            "VlmUrl",
            "VlmModel",
            "LlmUrl",
            "LlmModel",
            "OutputDirectory",
        ]
        sensitive_fields = [
            "VlmApiKey",
            "LlmApiKey",
        ]
        int_fields = ["MaxWorkers", "MaxRetries", "RetryDelay"]
        bool_fields = ["SaveMarkdown", "RenderMarkdown", "AutoUpdateCheck"]

        updates: Dict[str, Any] = {}

        for key in string_fields:
            if key in payload:
                # str() keeps non-string JSON values (e.g. numbers) from crashing .strip().
                value = str(payload.get(key) or "").strip()
                if key == "OutputDirectory" and not value:
                    return jsonify({"error": "输出目录不能为空"}), 400
                updates[key] = value

        # An empty API key means "leave unchanged"; removal needs the Clear* flags.
        for key in sensitive_fields:
            if key in payload:
                value = str(payload.get(key) or "").strip()
                if value:
                    updates[key] = value

        if payload.get("ClearVlmApiKey"):
            updates["VlmApiKey"] = ""
        if payload.get("ClearLlmApiKey"):
            updates["LlmApiKey"] = ""

        for key in int_fields:
            if key in payload and payload[key] not in (None, ""):
                try:
                    updates[key] = int(payload[key])
                except (TypeError, ValueError):
                    return jsonify({"error": f"{key} 需要是整数"}), 400

        if "SensitivityFactor" in payload and payload["SensitivityFactor"] not in (None, ""):
            try:
                updates["SensitivityFactor"] = float(payload["SensitivityFactor"])
            except (TypeError, ValueError):
                return jsonify({"error": "SensitivityFactor 需要是数字"}), 400

        for key in bool_fields:
            if key in payload:
                # _as_bool parses strings, so "false"/"off"/"0" are not treated as true.
                updates[key] = _as_bool(payload[key], False)

        prompt_template = payload.get("LlmPromptTemplate")
        if prompt_template is not None:
            normalized = str(prompt_template).strip()
            # An empty or default template is stored as "unset" (None).
            if not normalized or normalized == DEFAULT_LLM_PROMPT_TEMPLATE.strip():
                updates["LlmPromptTemplate"] = None
            else:
                updates["LlmPromptTemplate"] = normalized

        with config_lock:
            for key, value in updates.items():
                if key == "LlmPromptTemplate" and value is None:
                    config_manager.config.pop(key, None)
                elif key in ("VlmApiKey", "LlmApiKey") and value == "":
                    config_manager.config.pop(key, None)
                else:
                    config_manager.set(key, value)
            config_manager.save()

        if "OutputDirectory" in updates and updates["OutputDirectory"]:
            # Create the newly configured directory right away.
            get_output_root()

        start_update_check(force=True)
        return jsonify({"status": "ok"})

    @app.post("/api/process")
    def process_files():
        """Save the uploaded images and start a background grading run.

        Expects form field ``topic`` and one or more ``files`` uploads.
        Returns the ``run_id`` used to poll the run-status endpoint, or HTTP
        400 when the topic or the files are missing.
        """
        topic = (request.form.get("topic") or "").strip()
        if not topic:
            return jsonify({"error": "请输入作文题目"}), 400

        uploads = request.files.getlist("files")
        if not uploads:
            return jsonify({"error": "请至少选择一张图片"}), 400

        output_root = get_output_root()
        # Claims a unique directory so same-second requests cannot collide.
        run_id, run_dir = _create_run_directory(output_root)

        saved_files: List[Dict[str, Any]] = []
        used_names = set()
        for index, upload in enumerate(uploads):
            original_name = upload.filename or f"upload_{index + 1}.png"
            safe_name = secure_filename(original_name) or f"upload_{index + 1}.png"
            if safe_name in used_names:
                # De-duplicate sanitised names with a numeric suffix.
                stem = Path(safe_name).stem
                suffix = Path(safe_name).suffix or ".png"
                counter = 1
                candidate = f"{stem}_{counter}{suffix}"
                while candidate in used_names:
                    counter += 1
                    candidate = f"{stem}_{counter}{suffix}"
                safe_name = candidate
            used_names.add(safe_name)
            saved_path = run_dir / safe_name
            upload.save(saved_path)
            saved_files.append(
                {
                    "index": index,
                    "original": original_name,
                    "name": safe_name,
                    "path": saved_path,
                }
            )

        try:
            # Clamp to >= 1: ThreadPoolExecutor rejects zero or negative sizes.
            max_workers = max(1, int(config_manager.get("MaxWorkers", 4)))
        except (TypeError, ValueError):
            max_workers = 4
        save_markdown = _as_bool(config_manager.get("SaveMarkdown", True), True)

        run_path = relative_to_output(run_dir)
        run_state = {
            "run_id": run_id,
            "status": "queued",
            "total": len(saved_files),
            "completed": 0,
            "aggregate": {"vlm_in": 0, "vlm_out": 0, "llm_in": 0, "llm_out": 0},
            "results": {},
            "errors": [],
            "run_path": run_path,
            "created_at": datetime.now().isoformat(timespec="seconds"),
        }
        with run_states_lock:
            run_states[run_id] = run_state

        worker = threading.Thread(
            target=_execute_run,
            args=(run_id, saved_files, topic, run_dir, max_workers, save_markdown),
            daemon=True,
        )
        worker.start()

        return jsonify(
            {
                "status": "queued",
                "run_id": run_id,
                "total": len(saved_files),
                "run_path": run_path,
            }
        )

    # The URL variable is required: the view takes ``run_id`` as an argument.
    @app.get("/api/run-status/<run_id>")
    def run_status(run_id: str):
        """Return progress and per-file results for *run_id* (404 if unknown)."""
        response: Optional[Dict[str, Any]] = None
        # Snapshot under the lock; the worker thread mutates this state concurrently.
        with run_states_lock:
            state = run_states.get(run_id)
            if state:
                results_dict = state.get("results", {})
                response = {
                    "run_id": run_id,
                    "status": state.get("status", "unknown"),
                    "total": state.get("total", 0),
                    "completed": state.get("completed", 0),
                    "aggregate": dict(
                        state.get(
                            "aggregate",
                            {"vlm_in": 0, "vlm_out": 0, "llm_in": 0, "llm_out": 0},
                        )
                    ),
                    "results": [results_dict[index] for index in sorted(results_dict)],
                    "run_path": state.get("run_path"),
                    "error": state.get("error"),
                    "errors": list(state.get("errors", [])),
                }
        if response is None:
            abort(404)
        return jsonify(response)

    # ``path:`` lets the variable contain slashes (e.g. "<run_id>/<file name>").
    @app.get("/outputs/<path:requested_path>")
    def serve_outputs(requested_path: str):
        """Serve a file from the output directory.

        Responds with 404 when the resolved path leaves the output root, does
        not exist, or is a directory.
        """
        output_root = get_output_root().resolve()
        target_path = (output_root / requested_path).resolve()
        try:
            relative = target_path.relative_to(output_root).as_posix()
        except ValueError:
            abort(404)
        if not target_path.exists() or target_path.is_dir():
            abort(404)
        return send_from_directory(str(output_root), relative)

    @app.get("/api/update-status")
    def update_status():
        """Return the current version and the most recent update-check result."""
        with update_lock:
            return jsonify(
                {
                    "current": CURRENT_VERSION,
                    "latest": update_state["latest"],
                    "checked": update_state["checked"],
                }
            )

    @app.post("/api/update-check")
    def trigger_update_check():
        """Start a forced update check in the background."""
        start_update_check(force=True)
        return jsonify({"status": "checking"})

    @app.get("/")
    def index():
        """Render the single-page web UI."""
        # NOTE(review): the template below is plain text with no HTML markup —
        # confirm the tags were not stripped from this file.
        html = ''' AI 作文批改助手 · Web

AI 作文批改助手 · Web

版本 {{ current_version }}

加载用量中...

批改任务

批改结果

服务设置

关于与更新

当前版本:{{ current_version }}
正在获取最新版本信息...
使用提示
  • 默认使用 output_reports/时间戳 保存批改文件,可在设置中修改。
  • 可单独保存 Markdown 或 HTML,也可保留二者。
  • Prompt 模板支持完全自定义,请保留参数占位符以确保正常传值。
'''
        return render_template_string(
            html,
            current_version=CURRENT_VERSION,
            default_output_dir=DEFAULT_OUTPUT_DIR_NAME,
        )

    start_update_check()
    return app