"""
Progress tracking for the analytic continuation pipeline.
Provides real-time progress updates via Server-Sent Events (SSE)
and progress state management for UI display.
"""
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, AsyncGenerator
import asyncio
import json
from datetime import datetime
from enum import Enum
from .logging_config import TaskStatus, TaskProgress, PipelineLogger, get_logger
# Pipeline stages for the analytic continuation workflow
PIPELINE_STAGES = [
("precheck", "Pre-Check Contour", "Quick validation of curve topology"),
("validate_input", "Validate Input", "Checking curve and function validity"),
("load_curve", "Load Curve", "Loading and preprocessing Jordan curve"),
("fit_laurent", "Fit Laurent Map", "Fitting Laurent series to curve"),
(
"analyze_complexity",
"Analyze Complexity",
"Computing Cesàro/Whewell forms for cost estimation",
),
("check_holomorphic", "Check Holomorphic", "Verifying function is holomorphic on annulus"),
("prepare_render", "Prepare Render", "Extracting coefficients for rendering"),
("render", "Render", "Generating domain coloring visualization"),
]
[docs]
@dataclass
class StageInfo:
"""Information about a pipeline stage."""
id: str
name: str
description: str
status: TaskStatus = TaskStatus.PENDING
progress: float = 0.0
message: str = ""
substeps_total: int = 0
substeps_done: int = 0
started_at: Optional[str] = None
completed_at: Optional[str] = None
error: Optional[str] = None
[docs]
def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"description": self.description,
"status": self.status.value,
"progress": self.progress,
"message": self.message,
"substeps_total": self.substeps_total,
"substeps_done": self.substeps_done,
"started_at": self.started_at,
"completed_at": self.completed_at,
"error": self.error,
}
[docs]
class ProgressTracker:
"""
Tracks progress through the analytic continuation pipeline.
Provides:
- Stage-by-stage progress tracking
- Real-time updates via callbacks or SSE
- Checklist-style UI output
"""
[docs]
def __init__(self, session_id: Optional[str] = None):
self.session_id = session_id
self.logger = get_logger()
self.stages: Dict[str, StageInfo] = {}
self._subscribers: List[asyncio.Queue] = []
self._lock = asyncio.Lock()
# Initialize stages
for stage_id, name, desc in PIPELINE_STAGES:
self.stages[stage_id] = StageInfo(
id=stage_id,
name=name,
description=desc,
)
[docs]
async def subscribe(self) -> AsyncGenerator[str, None]:
"""Subscribe to progress updates via SSE."""
queue: asyncio.Queue = asyncio.Queue()
self._subscribers.append(queue)
try:
# Send initial state
yield self._format_sse_event("init", self.get_state())
while True:
try:
event = await asyncio.wait_for(queue.get(), timeout=30.0)
yield event
except asyncio.TimeoutError:
# Send keepalive
yield ": keepalive\n\n"
finally:
self._subscribers.remove(queue)
def _format_sse_event(self, event_type: str, data: Any) -> str:
"""Format data as SSE event."""
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
async def _broadcast(self, event_type: str, data: Any):
"""Broadcast update to all subscribers."""
event = self._format_sse_event(event_type, data)
for queue in self._subscribers:
await queue.put(event)
[docs]
def get_state(self) -> Dict[str, Any]:
"""Get current progress state for UI."""
stages_list = []
overall_progress = 0.0
total_weight = len(self.stages)
for stage_id, _, _ in PIPELINE_STAGES:
stage = self.stages[stage_id]
stages_list.append(stage.to_dict())
if stage.status == TaskStatus.COMPLETED:
overall_progress += 1.0
elif stage.status == TaskStatus.IN_PROGRESS:
overall_progress += stage.progress
return {
"session_id": self.session_id,
"overall_progress": overall_progress / total_weight if total_weight > 0 else 0.0,
"stages": stages_list,
"current_stage": self._get_current_stage_id(),
}
def _get_current_stage_id(self) -> Optional[str]:
"""Get the ID of the currently running stage."""
for stage_id, _, _ in PIPELINE_STAGES:
stage = self.stages[stage_id]
if stage.status == TaskStatus.IN_PROGRESS:
return stage_id
return None
[docs]
async def start_stage(
self,
stage_id: str,
substeps_total: int = 0,
message: str = "",
):
"""Start a pipeline stage."""
if stage_id not in self.stages:
return
stage = self.stages[stage_id]
stage.status = TaskStatus.IN_PROGRESS
stage.started_at = datetime.utcnow().isoformat()
stage.substeps_total = substeps_total
stage.substeps_done = 0
stage.message = message or stage.description
stage.progress = 0.0
# Log to pipeline logger
if self.session_id:
self.logger.start_task(stage_id, stage.name, self.session_id)
await self._broadcast(
"stage_start",
{
"stage_id": stage_id,
"stage": stage.to_dict(),
"overall": self.get_state(),
},
)
[docs]
async def update_stage(
self,
stage_id: str,
progress: Optional[float] = None,
message: Optional[str] = None,
substeps_done: Optional[int] = None,
):
"""Update stage progress."""
if stage_id not in self.stages:
return
stage = self.stages[stage_id]
if progress is not None:
stage.progress = min(1.0, max(0.0, progress))
if message is not None:
stage.message = message
if substeps_done is not None:
stage.substeps_done = substeps_done
if stage.substeps_total > 0:
stage.progress = substeps_done / stage.substeps_total
# Log update
if self.session_id:
self.logger.update_task(
stage_id,
progress=stage.progress,
message=stage.message,
session_id=self.session_id,
)
await self._broadcast(
"stage_update",
{
"stage_id": stage_id,
"stage": stage.to_dict(),
"overall": self.get_state(),
},
)
[docs]
async def complete_stage(
self,
stage_id: str,
success: bool = True,
error: Optional[str] = None,
message: Optional[str] = None,
):
"""Complete a pipeline stage."""
if stage_id not in self.stages:
return
stage = self.stages[stage_id]
stage.status = TaskStatus.COMPLETED if success else TaskStatus.FAILED
stage.completed_at = datetime.utcnow().isoformat()
stage.progress = 1.0 if success else stage.progress
stage.error = error
if message:
stage.message = message
# Log completion
if self.session_id:
self.logger.complete_task(
stage_id,
success=success,
error=error,
session_id=self.session_id,
)
await self._broadcast(
"stage_complete",
{
"stage_id": stage_id,
"success": success,
"stage": stage.to_dict(),
"overall": self.get_state(),
},
)
[docs]
async def skip_stage(self, stage_id: str, reason: str = ""):
"""Skip a stage (e.g., when using cached results)."""
if stage_id not in self.stages:
return
stage = self.stages[stage_id]
stage.status = TaskStatus.SKIPPED
stage.message = reason or "Skipped"
stage.progress = 1.0
await self._broadcast(
"stage_skip",
{
"stage_id": stage_id,
"reason": reason,
"stage": stage.to_dict(),
"overall": self.get_state(),
},
)
[docs]
def sync_start_stage(self, stage_id: str, substeps_total: int = 0, message: str = ""):
"""Synchronous version of start_stage for non-async contexts."""
if stage_id not in self.stages:
return
stage = self.stages[stage_id]
stage.status = TaskStatus.IN_PROGRESS
stage.started_at = datetime.utcnow().isoformat()
stage.substeps_total = substeps_total
stage.substeps_done = 0
stage.message = message or stage.description
stage.progress = 0.0
if self.session_id:
self.logger.start_task(stage_id, stage.name, self.session_id)
[docs]
def sync_update_stage(
self,
stage_id: str,
progress: Optional[float] = None,
message: Optional[str] = None,
substeps_done: Optional[int] = None,
):
"""Synchronous version of update_stage."""
if stage_id not in self.stages:
return
stage = self.stages[stage_id]
if progress is not None:
stage.progress = min(1.0, max(0.0, progress))
if message is not None:
stage.message = message
if substeps_done is not None:
stage.substeps_done = substeps_done
if stage.substeps_total > 0:
stage.progress = substeps_done / stage.substeps_total
if self.session_id:
self.logger.update_task(
stage_id,
progress=stage.progress,
message=stage.message,
session_id=self.session_id,
)
[docs]
def sync_complete_stage(
self,
stage_id: str,
success: bool = True,
error: Optional[str] = None,
message: Optional[str] = None,
):
"""Synchronous version of complete_stage."""
if stage_id not in self.stages:
return
stage = self.stages[stage_id]
stage.status = TaskStatus.COMPLETED if success else TaskStatus.FAILED
stage.completed_at = datetime.utcnow().isoformat()
stage.progress = 1.0 if success else stage.progress
stage.error = error
if message:
stage.message = message
if self.session_id:
self.logger.complete_task(
stage_id,
success=success,
error=error,
session_id=self.session_id,
)