LangGraph Advanced Patterns: Production Implementation Guide

Cover Image for LangGraph Advanced Patterns: Production Implementation Guide
AI & Machine Learning · 5 min read

Take your LangGraph skills to the next level with advanced patterns for production environments. Learn about human-in-the-loop workflows, distributed execution, robust checkpointing, and performance optimization.

Human-in-the-Loop Workflows

Interactive Decision Points

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from typing import TypedDict, Optional
import asyncio

class HumanInLoopGraph:
    """LangGraph workflow with explicit human intervention points.

    The graph analyzes a task, generates a proposal, and — when required —
    parks execution in ``human_review`` until a human approves, rejects, or
    modifies the proposal (or a timeout elapses). State is persisted via a
    SQLite checkpointer so runs survive process restarts.
    """

    class State(TypedDict):
        # The task description supplied by the caller.
        task: str
        # Proposal produced by the AI before human review.
        ai_proposal: Optional[str]
        # Free-text feedback captured from the reviewer, if any.
        human_feedback: Optional[str]
        # True/False once a reviewer decides; None while undecided.
        human_approval: Optional[bool]
        # Final result after execution and finalization.
        final_output: Optional[str]
        # Routing flag: whether this task must pass human review.
        requires_human_input: bool

    def __init__(self, checkpoint_dir: str = "./checkpoints"):
        """Build the graph and its checkpointing backend.

        NOTE(review): ``SqliteSaver.from_conn_string`` expects a connection
        string / database path, not a directory — confirm the argument here.
        """
        self.graph = StateGraph(self.State)
        self.checkpointer = SqliteSaver.from_conn_string(checkpoint_dir)
        # In-memory registry of reviews awaiting a human decision, keyed by
        # review id. An external channel (API/UI) is expected to flip the
        # entry's "status" and attach feedback/approval fields.
        self.pending_approvals = {}
        self._build_graph()

    def _build_graph(self):
        """Wire nodes and edges, then compile with checkpointing.

        The node handler methods (``analyze_task``, ``generate_proposal``,
        ``execute_task``, ``finalize_output``, routing predicates) are not
        shown in this snippet — presumably defined elsewhere on the class.
        """

        # Add nodes
        self.graph.add_node("analyze", self.analyze_task)
        self.graph.add_node("generate_proposal", self.generate_proposal)
        self.graph.add_node("human_review", self.human_review)
        self.graph.add_node("execute", self.execute_task)
        self.graph.add_node("finalize", self.finalize_output)

        # Set entry point
        self.graph.set_entry_point("analyze")

        # Define flow
        self.graph.add_edge("analyze", "generate_proposal")

        # Conditional routing after proposal: route to a human only when the
        # task demands it; otherwise execute directly.
        self.graph.add_conditional_edges(
            "generate_proposal",
            self.check_human_required,
            {
                "needs_human": "human_review",
                "auto_approve": "execute"
            }
        )

        # Human review outcomes: rejection loops back to regenerate the
        # proposal; approval or modification proceeds to execution.
        self.graph.add_conditional_edges(
            "human_review",
            self.process_human_feedback,
            {
                "approved": "execute",
                "rejected": "generate_proposal",
                "modified": "execute"
            }
        )

        self.graph.add_edge("execute", "finalize")
        self.graph.add_edge("finalize", END)

        # Compile with checkpointing
        self.app = self.graph.compile(checkpointer=self.checkpointer)

    async def human_review(self, state: dict) -> dict:
        """Pause for human review, polling until a decision or timeout.

        Registers a pending-approval record and polls it every 5 seconds.
        Returns the state augmented with ``human_feedback`` and
        ``human_approval``; on timeout the review is auto-rejected.
        """

        # Generate review request
        review_id = str(uuid.uuid4())
        self.pending_approvals[review_id] = {
            "state": state,
            "timestamp": datetime.now(),
            "status": "pending"
        }

        timeout = 3600  # 1 hour timeout
        poll_interval = 5  # seconds between status checks
        start_time = time.time()

        try:
            while time.time() - start_time < timeout:
                approval_data = self.pending_approvals.get(review_id)

                # An external reviewer flips "status" away from "pending" and
                # attaches "feedback"/"approved" to the record.
                if approval_data is not None and approval_data["status"] != "pending":
                    return {
                        **state,
                        "human_feedback": approval_data.get("feedback"),
                        "human_approval": approval_data.get("approved", False)
                    }

                await asyncio.sleep(poll_interval)
        finally:
            # Remove the record whether resolved or timed out — the original
            # left entries behind forever, leaking memory across reviews.
            self.pending_approvals.pop(review_id, None)

        # Timeout - use default action
        return {
            **state,
            "human_feedback": "Timeout - auto-rejected",
            "human_approval": False
        }

Checkpointing and Recovery

Robust State Persistence

from langgraph.checkpoint.postgres import PostgresSaver
import psycopg2

class ProductionCheckpointer:
    """Production-grade checkpointing system.

    Writes every checkpoint to Postgres first; if that fails, falls back to
    a local SQLite database, then replicates to S3 asynchronously.
    """

    def __init__(self, postgres_url: str):
        # Primary store lives in Postgres; the SQLite file is a local
        # fallback used only when the primary write raises.
        self.checkpointer = PostgresSaver.from_conn_string(postgres_url)
        self.backup_checkpointer = SqliteSaver.from_conn_string("./backup.db")

    async def save_checkpoint(
        self,
        thread_id: str,
        checkpoint: Checkpoint,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Persist a checkpoint, falling back to the backup store on failure."""

        # Both savers take the same keyword arguments, so build them once.
        put_kwargs = dict(
            thread_id=thread_id,
            checkpoint=checkpoint,
            metadata=metadata or {},
        )

        try:
            await self.checkpointer.aput(**put_kwargs)
        except Exception as e:
            # Primary write failed — record it and try the local backup.
            logger.error(f"Primary checkpoint failed: {e}")
            await self.backup_checkpointer.aput(**put_kwargs)

        # Replicate to S3 regardless of which store took the write.
        # NOTE(review): _replicate_to_s3 is not defined in this snippet —
        # presumably implemented elsewhere on the class.
        await self._replicate_to_s3(thread_id, checkpoint)

class RecoverableGraph:
    """Graph wrapper that resumes execution from a saved checkpoint.

    On ``run_with_recovery`` it looks up an existing checkpoint for the
    thread and, if found, resumes from it; otherwise it starts a fresh run.
    """

    def __init__(
        self,
        graph: StateGraph,
        postgres_url: str = "postgresql://user:pass@localhost/checkpoints",
    ):
        """Wrap *graph* with checkpoint-backed recovery.

        The Postgres URL is now a parameter (defaulting to the previous
        hard-coded value for backward compatibility) — real deployments
        should pass credentials from configuration, not source code.
        """
        self.graph = graph
        self.checkpointer = ProductionCheckpointer(postgres_url)
        # NOTE(review): compile() normally expects a BaseCheckpointSaver;
        # ProductionCheckpointer wraps two savers — confirm it satisfies the
        # interface LangGraph expects here.
        self.app = graph.compile(
            checkpointer=self.checkpointer
        )

    async def run_with_recovery(
        self,
        input_data: dict,
        thread_id: str
    ) -> dict:
        """Run the graph, resuming from *thread_id*'s checkpoint if one exists.

        Returns the final state dict produced by the graph.
        """

        # NOTE(review): load_checkpoint is not defined on the visible
        # ProductionCheckpointer — presumably implemented elsewhere; verify.
        checkpoint = await self.checkpointer.load_checkpoint(thread_id)

        if checkpoint:
            logger.info(f"Resuming from checkpoint: {checkpoint.id}")

            # Resume: pin both the thread and the specific checkpoint id so
            # execution continues from the persisted state.
            result = await self.app.ainvoke(
                input_data,
                config={
                    "configurable": {
                        "thread_id": thread_id,
                        "checkpoint_id": checkpoint.id
                    }
                }
            )
        else:
            # No prior state — start a fresh run under this thread id.
            result = await self.app.ainvoke(
                input_data,
                config={"configurable": {"thread_id": thread_id}}
            )

        return result

Distributed Execution

Scaling LangGraph Across Workers

from celery import Celery
import ray
from typing import List, Dict, Any

class DistributedGraphExecutor:
    """Execute graph nodes across distributed workers.

    Combines two backends: Celery (queue-based routing of node tasks to
    worker pools) and Ray (parallel fan-out of multiple nodes over one
    state). Requires a running Redis broker and a Ray runtime.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        # Initialize Celery
        self.celery_app = Celery('langgraph', broker=redis_url)

        # Configure queues for different node types so that e.g. GPU-heavy
        # tasks land only on GPU-equipped workers.
        self.celery_app.conf.task_routes = {
            'heavy_compute_node': {'queue': 'gpu_workers'},
            'llm_node': {'queue': 'llm_workers'},
            'io_node': {'queue': 'io_workers'},
        }

        # Initialize Ray
        # NOTE(review): ray.init() in a constructor raises if Ray is already
        # initialized in the process — consider ray.init(ignore_reinit_error=True).
        ray.init()

    def create_distributed_node(self, node_func: callable, node_type: str):
        """Wrap node for distributed execution.

        Registers *node_func* as a Celery task named ``{node_type}_node``;
        the task-route table above maps that name to a worker queue. The
        returned callable submits the state dict to Celery.
        """

        @self.celery_app.task(name=f"{node_type}_node")
        def distributed_node(state: dict) -> dict:
            return node_func(state)

        return distributed_node

    @ray.remote
    class ParallelNodeExecutor:
        """Execute multiple nodes in parallel using Ray"""

        async def execute_parallel_nodes(
            self,
            state: dict,
            nodes: List[callable]
        ) -> List[dict]:
            """Execute nodes in parallel.

            Fans each node out as a Ray task over the same state and
            collects the per-node result dicts.
            """

            # Create Ray tasks
            # NOTE(review): self.execute_node is not defined in this snippet —
            # presumably a remote method declared elsewhere on the actor; verify.
            futures = []
            for node in nodes:
                future = self.execute_node.remote(state, node)
                futures.append(future)

            # Wait for all to complete
            # NOTE(review): ray.get() blocks and returns a plain list, so
            # awaiting its result looks wrong — inside an async actor,
            # `await asyncio.gather(*futures)` is the usual form. Confirm.
            results = await ray.get(futures)

            return results

Performance Optimization

Optimizing Graph Execution

class GraphOptimizer:
    """Optimize graph performance.

    Offers two tools: per-node timing/memory instrumentation
    (``add_performance_monitoring``) and a Redis-backed result cache
    decorator factory (``cache_expensive_operations``).
    """

    def __init__(self):
        # Per-node metric samples: node name -> list of run records.
        self.metrics = defaultdict(list)

    def add_performance_monitoring(self, graph: StateGraph):
        """Replace every node in *graph* with a monitored wrapper.

        NOTE(review): this mutates ``graph.nodes`` directly — confirm the
        StateGraph version in use exposes nodes as a writable mapping.
        """

        # Copy first so we don't mutate the mapping while iterating it.
        original_nodes = graph.nodes.copy()

        for node_name, node_func in original_nodes.items():
            monitored_node = self._create_monitored_node(node_name, node_func)
            graph.nodes[node_name] = monitored_node

    def _create_monitored_node(self, name: str, func: callable):
        """Create a wrapper around async node *func* that records metrics.

        Each call appends a record (execution time, RSS memory delta in MiB,
        timestamp, success flag or error string) to ``self.metrics[name]``.
        Failures are recorded and re-raised, never swallowed.
        """

        @wraps(func)
        async def monitored(state: dict) -> dict:
            start_time = time.perf_counter()
            # RSS in MiB; the delta below is approximate — other threads and
            # the GC also move RSS between the two samples.
            start_memory = psutil.Process().memory_info().rss / 1024 / 1024

            try:
                result = await func(state)

                elapsed = time.perf_counter() - start_time
                memory_used = psutil.Process().memory_info().rss / 1024 / 1024 - start_memory

                self.metrics[name].append({
                    "execution_time": elapsed,
                    "memory_delta": memory_used,
                    "timestamp": datetime.now(),
                    "success": True
                })

                return result

            except Exception as e:
                # Record the failure, then propagate — monitoring must never
                # hide a node error from the graph.
                self.metrics[name].append({
                    "execution_time": time.perf_counter() - start_time,
                    "error": str(e),
                    "timestamp": datetime.now(),
                    "success": False
                })
                raise

        return monitored

    def cache_expensive_operations(self):
        """Return a decorator factory that caches node results in Redis.

        Usage: ``@cached_node(key_func)`` where *key_func* maps a state dict
        to a cache key. Results are cached for one hour.
        """

        from aiocache import Cache

        cache = Cache(Cache.REDIS)

        def cached_node(key_func: callable):
            """Decorator for caching node results"""

            def decorator(node_func: callable):
                @wraps(node_func)
                async def wrapper(state: dict) -> dict:
                    cache_key = key_func(state)

                    # Compare against None explicitly: the original used a
                    # truthiness test, so falsy-but-valid cached results
                    # (e.g. an empty dict) were treated as misses and
                    # recomputed on every call.
                    cached_result = await cache.get(cache_key)
                    if cached_result is not None:
                        return cached_result

                    # Miss: execute the node and cache for one hour.
                    result = await node_func(state)
                    await cache.set(cache_key, result, ttl=3600)

                    return result

                return wrapper

            return decorator

        return cached_node

Testing and Debugging

Comprehensive Testing Framework

import pytest
from unittest.mock import Mock, AsyncMock

class GraphTestFramework:
    """Testing framework for LangGraph.

    Demonstrates a reusable compiled-graph fixture plus async tests for
    linear execution and conditional routing.
    """

    @pytest.fixture
    def test_graph(self):
        """Compile a minimal two-node linear pipeline for the tests below."""

        g = StateGraph(GraphState)

        # Each node merges one marker key into the flowing state.
        g.add_node("node1", lambda s: dict(s, processed=True))
        g.add_node("node2", lambda s: dict(s, result="test"))

        g.set_entry_point("node1")
        g.add_edge("node1", "node2")
        g.add_edge("node2", END)

        return g.compile()

    @pytest.mark.asyncio
    async def test_graph_execution(self, test_graph):
        """Both nodes should run and merge their keys into the final state."""

        final_state = await test_graph.ainvoke({"input": "test"})

        assert final_state["processed"] is True
        assert final_state["result"] == "test"

    @pytest.mark.asyncio
    async def test_conditional_routing(self):
        """A truthy 'condition' key routes to path_a; anything else to path_b."""

        g = StateGraph(GraphState)

        g.add_node("router", lambda s: s)
        g.add_node("path_a", lambda s: dict(s, path="a"))
        g.add_node("path_b", lambda s: dict(s, path="b"))

        g.add_conditional_edges(
            "router",
            lambda s: "a" if s.get("condition") else "b",
            {"a": "path_a", "b": "path_b"},
        )

        app = g.compile()

        # Exercise both branches of the router.
        result_a = await app.ainvoke({"condition": True})
        assert result_a["path"] == "a"

        result_b = await app.ainvoke({"condition": False})
        assert result_b["path"] == "b"

class GraphDebugger:
    """Debug and trace graph execution.

    Collects per-node trace events ("enter"/"exit"/"error", each carrying at
    least a "node" and — on exit — a "state" snapshot) and summarizes them
    into a report.
    """

    def __init__(self, graph: StateGraph):
        self.graph = graph
        # Ordered list of trace-event dicts appended during execution.
        self.execution_trace = []

    def trace_execution(self):
        """Wrap every node of the graph with trace instrumentation.

        NOTE(review): _wrap_node_with_trace is not defined in this snippet —
        presumably implemented elsewhere on the class; it is expected to
        append "enter"/"exit"/"error" events to self.execution_trace. Verify.
        """

        for node_name in self.graph.nodes:
            self._wrap_node_with_trace(node_name)

    def generate_debug_report(self) -> dict:
        """Summarize the collected trace into a debug report.

        Returns a dict with: the count of node entries, all error events,
        the ordered execution path, and the state snapshot after each node.
        """

        return {
            "total_nodes_executed": len(
                [e for e in self.execution_trace if e["event"] == "enter"]
            ),
            "errors": [
                e for e in self.execution_trace if e["event"] == "error"
            ],
            "execution_path": [
                e["node"] for e in self.execution_trace if e["event"] == "enter"
            ],
            "state_evolution": [
                {"node": e["node"], "state": e["state"]}
                for e in self.execution_trace if e["event"] == "exit"
            ]
        }

Real-World Implementation

Complete Customer Support System

class CustomerSupportGraph:
    """Production customer support system.

    Pipeline: classify intent → analyze sentiment → search the knowledge
    base → generate a response → quality check, which either resolves the
    ticket, escalates to a human, or regenerates the response.
    """

    def __init__(self):
        self.graph = self._build_graph()
        # NOTE(review): _load_knowledge_base is not defined in this snippet —
        # presumably implemented elsewhere on the class.
        self.knowledge_base = self._load_knowledge_base()

    def _build_graph(self) -> StateGraph:
        """Build and compile the customer-support workflow graph."""

        class SupportState(TypedDict):
            # Raw customer message.
            customer_query: str
            # Classified intent label, once known.
            intent: Optional[str]
            # Sentiment label for routing/tone, once known.
            sentiment: Optional[str]
            # Knowledge-base hits used to ground the response.
            knowledge_results: List[dict]
            # Draft answer awaiting the quality check.
            suggested_response: Optional[str]
            # Whether a human agent must take over.
            requires_escalation: bool
            # Final resolution text, once the ticket is closed.
            resolution: Optional[str]

        graph = StateGraph(SupportState)

        # Add nodes (handlers are defined elsewhere on the class — not
        # shown in this snippet).
        graph.add_node("classify_intent", self.classify_intent)
        graph.add_node("analyze_sentiment", self.analyze_sentiment)
        graph.add_node("search_knowledge", self.search_knowledge_base)
        graph.add_node("generate_response", self.generate_response)
        graph.add_node("quality_check", self.quality_check)
        graph.add_node("escalate", self.escalate_to_human)
        graph.add_node("resolve", self.resolve_ticket)

        # Entry point
        graph.set_entry_point("classify_intent")

        # Linear pipeline edges — the original built the nodes but never
        # connected them, leaving everything after the entry point
        # unreachable and no path to END.
        graph.add_edge("classify_intent", "analyze_sentiment")
        graph.add_edge("analyze_sentiment", "search_knowledge")
        graph.add_edge("search_knowledge", "generate_response")
        graph.add_edge("generate_response", "quality_check")

        # Quality check routing: approve, hand off to a human, or loop back
        # and regenerate the response.
        graph.add_conditional_edges(
            "quality_check",
            self.route_quality_check,
            {
                "approve": "resolve",
                "escalate": "escalate",
                "regenerate": "generate_response"
            }
        )

        # Both terminal outcomes end the run.
        graph.add_edge("escalate", END)
        graph.add_edge("resolve", END)

        return graph.compile()

    async def process_ticket(self, customer_query: str) -> dict:
        """Process one support ticket and return the final state dict.

        Each ticket runs under a fresh thread id so concurrent tickets
        don't share checkpointed state.
        """

        result = await self.graph.ainvoke(
            {"customer_query": customer_query},
            config={"configurable": {"thread_id": str(uuid.uuid4())}}
        )

        return result

Summary

This advanced guide covered:

  • Human-in-the-loop patterns with asynchronous approval polling and timeouts
  • Robust checkpointing with multi-tier storage
  • Distributed execution using Celery and Ray
  • Performance optimization techniques
  • Comprehensive testing and debugging
  • Real-world customer support implementation

LangGraph's advanced features enable building production-grade AI systems that are scalable, maintainable, and reliable.


Series Navigation

This is Part 4 of the LangChain Series.

Previous: ← Part 3 - LangGraph: Building Stateful Applications
Next: Part 5 - LangSmith: Production Monitoring →

Complete Series:


Tags: #LangGraph #Production #HumanInLoop #Distributed #Testing