Router Implementation Details

This document provides detailed insights into the core routing algorithms, classification logic, and implementation specifics of the Semantic Router.

Classification Pipeline

Multi-Stage Classification Architecture

The Semantic Router employs a multi-stage classification pipeline that combines several specialized models: intent classification, semantic response caching, tools auto-selection, and security screening (PII and jailbreak detection). Each stage is described in the sections that follow.
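As a rough orientation, here is a minimal sketch of how these stages might hang together. The Pipeline type and its field names are illustrative assumptions for this sketch, not the router's actual API; the individual stage types are the ones defined later in this document.

type Pipeline struct {
	guard      *JailbreakGuard     // security screening
	pii        *PIIDetector        // PII detection
	cache      *SemanticCache      // semantic response cache
	classifier *CategoryClassifier // intent classification
	tools      *ToolsSelector      // tools auto-selection
}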

Implementation Details

Category Classification Logic

type CategoryClassifier struct {
	model               *ModernBERTModel
	tokenizer           *ModernBERTTokenizer
	labelMapping        map[int]string
	confidenceThreshold float64
}

func (cc *CategoryClassifier) ClassifyIntent(query string) (*Classification, error) {
	// Record the start time for latency reporting
	start := time.Now()

	// Tokenize input
	tokens := cc.tokenizer.Tokenize(query)

	// Run inference
	logits, err := cc.model.Forward(tokens)
	if err != nil {
		return nil, err
	}

	// Apply softmax to get probabilities
	probabilities := softmax(logits)

	// Find best classification
	maxIdx, maxProb := argmax(probabilities)
	category := cc.labelMapping[maxIdx]

	return &Classification{
		Category:       category,
		Confidence:     maxProb,
		Probabilities:  probabilities,
		ProcessingTime: time.Since(start),
	}, nil
}
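The softmax and argmax helpers used above are not defined in this document. A minimal sketch of what they might look like, assuming logits and probabilities are plain []float64 slices (the actual tensor types may differ) and using only the standard math package:

// softmax converts raw logits into a probability distribution.
func softmax(logits []float64) []float64 {
	// Subtract the max logit before exponentiating for numerical stability.
	maxLogit := logits[0]
	for _, l := range logits[1:] {
		if l > maxLogit {
			maxLogit = l
		}
	}
	probs := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		probs[i] = math.Exp(l - maxLogit)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

// argmax returns the index and value of the largest probability.
func argmax(probs []float64) (int, float64) {
	maxIdx, maxProb := 0, probs[0]
	for i, p := range probs[1:] {
		if p > maxProb {
			maxIdx, maxProb = i+1, p
		}
	}
	return maxIdx, maxProb
}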

// Routing decision logic
func (r *OpenAIRouter) makeRoutingDecision(classification *Classification) *RoutingDecision {
	// High confidence - use specialized model
	if classification.Confidence > 0.85 {
		return &RoutingDecision{
			SelectedModel: r.getSpecializedModel(classification.Category),
			Reason:        "High confidence specialized routing",
			Confidence:    classification.Confidence,
		}
	}

	// Medium confidence - use category-appropriate model with fallback
	if classification.Confidence > 0.6 {
		return &RoutingDecision{
			SelectedModel: r.getCategoryModel(classification.Category),
			FallbackModel: r.Config.DefaultModel,
			Reason:        "Medium confidence routing with fallback",
			Confidence:    classification.Confidence,
		}
	}

	// Low confidence - use general model
	return &RoutingDecision{
		SelectedModel: r.Config.DefaultModel,
		Reason:        "Low confidence, using general model",
		Confidence:    classification.Confidence,
	}
}
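Tying the two steps together, a hedged usage sketch (handleQuery is an illustrative wrapper, not part of the router's actual API):

// handleQuery is a hypothetical caller that classifies a query and
// turns the result into a routing decision.
func handleQuery(r *OpenAIRouter, cc *CategoryClassifier, query string) (*RoutingDecision, error) {
	classification, err := cc.ClassifyIntent(query)
	if err != nil {
		return nil, err
	}
	return r.makeRoutingDecision(classification), nil
}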

Semantic Caching Implementation

Cache Architecture

type SemanticCache struct {
	entries             []CacheEntry
	mu                  sync.RWMutex
	similarityThreshold float32
	maxEntries          int
	ttlSeconds          int
	enabled             bool
}

type CacheEntry struct {
	RequestBody  []byte
	ResponseBody []byte
	Model        string
	Query        string
	Embedding    []float32
	Timestamp    time.Time
}

// FindSimilar looks for a similar request in the cache
func (c *SemanticCache) FindSimilar(model string, query string) ([]byte, bool, error) {
	if !c.enabled {
		return nil, false, nil
	}

	// Generate embedding for the query
	queryEmbedding, err := candle_binding.GetEmbedding(query, 512)
	if err != nil {
		return nil, false, fmt.Errorf("failed to generate embedding: %w", err)
	}

	c.mu.RLock()
	defer c.mu.RUnlock()

	// Cleanup expired entries
	c.cleanupExpiredEntriesReadOnly()

	type SimilarityResult struct {
		Entry      CacheEntry
		Similarity float32
	}

	// Only compare with entries that have responses
	results := make([]SimilarityResult, 0, len(c.entries))
	for _, entry := range c.entries {
		if entry.ResponseBody == nil {
			continue // Skip entries without responses
		}

		// Only compare with entries with the same model
		if entry.Model != model {
			continue
		}

		// Calculate similarity using dot product
		var dotProduct float32
		for i := 0; i < len(queryEmbedding) && i < len(entry.Embedding); i++ {
			dotProduct += queryEmbedding[i] * entry.Embedding[i]
		}

		results = append(results, SimilarityResult{
			Entry:      entry,
			Similarity: dotProduct,
		})
	}

	// No results found
	if len(results) == 0 {
		return nil, false, nil
	}

	// Sort by similarity (highest first)
	sort.Slice(results, func(i, j int) bool {
		return results[i].Similarity > results[j].Similarity
	})

	// Check if the best match exceeds the threshold
	if results[0].Similarity >= c.similarityThreshold {
		return results[0].Entry.ResponseBody, true, nil
	}

	return nil, false, nil
}
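Note that the dot product above equals cosine similarity only when embeddings are L2-normalized, so the store path should insert normalized vectors. The write side is not shown in this document; a minimal sketch of what it might look like (the AddEntry method and its FIFO eviction policy are illustrative assumptions):

// AddEntry caches a completed request/response pair together with the
// query embedding used later for similarity search. Illustrative sketch;
// the actual store method may differ.
func (c *SemanticCache) AddEntry(model, query string, request, response []byte) error {
	if !c.enabled {
		return nil
	}

	embedding, err := candle_binding.GetEmbedding(query, 512)
	if err != nil {
		return fmt.Errorf("failed to generate embedding: %w", err)
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// Evict the oldest entry once the cache is full (simple FIFO policy).
	if c.maxEntries > 0 && len(c.entries) >= c.maxEntries {
		c.entries = c.entries[1:]
	}

	c.entries = append(c.entries, CacheEntry{
		RequestBody:  request,
		ResponseBody: response,
		Model:        model,
		Query:        query,
		Embedding:    embedding,
		Timestamp:    time.Now(),
	})
	return nil
}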

Tools Auto-Selection

Tool Relevance Algorithm

type ToolsSelector struct {
	toolsDB             *tools.ToolsDatabase
	relevanceModel      *RelevanceModel
	maxTools            int
	confidenceThreshold float64
}

func (ts *ToolsSelector) SelectRelevantTools(
	query string,
	availableTools []Tool,
) []Tool {
	var selectedTools []Tool

	// Score each tool for relevance
	for _, tool := range availableTools {
		relevanceScore := ts.calculateRelevance(query, tool)

		if relevanceScore > ts.confidenceThreshold {
			tool.RelevanceScore = relevanceScore
			selectedTools = append(selectedTools, tool)
		}
	}

	// Sort by relevance score
	sort.Slice(selectedTools, func(i, j int) bool {
		return selectedTools[i].RelevanceScore > selectedTools[j].RelevanceScore
	})

	// Limit number of tools
	if len(selectedTools) > ts.maxTools {
		selectedTools = selectedTools[:ts.maxTools]
	}

	return selectedTools
}

func (ts *ToolsSelector) calculateRelevance(query string, tool Tool) float64 {
	// Combine multiple relevance signals
	keywordScore := ts.calculateKeywordRelevance(query, tool)
	semanticScore := ts.calculateSemanticRelevance(query, tool)
	categoryScore := ts.calculateCategoryRelevance(query, tool)

	// Weighted combination
	return 0.4*keywordScore + 0.4*semanticScore + 0.2*categoryScore
}
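The three signal functions are left undefined here. As one hedged possibility, keyword relevance could be a simple token-overlap score; the sketch below assumes Tool exposes a Description field and uses the standard strings package:

// calculateKeywordRelevance scores a tool by the fraction of query tokens
// that appear in its description. Illustrative sketch only.
func (ts *ToolsSelector) calculateKeywordRelevance(query string, tool Tool) float64 {
	queryTokens := strings.Fields(strings.ToLower(query))
	if len(queryTokens) == 0 {
		return 0
	}

	description := strings.ToLower(tool.Description)
	matches := 0
	for _, token := range queryTokens {
		if strings.Contains(description, token) {
			matches++
		}
	}
	return float64(matches) / float64(len(queryTokens))
}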

Security Implementation

PII Detection Pipeline

type PIIDetector struct {
	tokenClassifier *ModernBERTTokenClassifier
	piiPatterns     map[string]*regexp.Regexp
	confidence      float64
}

func (pd *PIIDetector) DetectPII(text string) (*PIIDetectionResult, error) {
	result := &PIIDetectionResult{
		HasPII:   false,
		Entities: []PIIEntity{},
	}

	// Token-level classification with ModernBERT
	tokens := pd.tokenClassifier.Tokenize(text)
	predictions, err := pd.tokenClassifier.Predict(tokens)
	if err != nil {
		return nil, err
	}

	// Extract PII entities
	entities := pd.extractEntities(tokens, predictions)

	// Additional pattern-based detection for high precision
	patternEntities := pd.detectWithPatterns(text)

	// Combine results
	allEntities := append(entities, patternEntities...)

	if len(allEntities) > 0 {
		result.HasPII = true
		result.Entities = allEntities
	}

	return result, nil
}
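detectWithPatterns is not shown above. A hedged sketch of the pattern pass, assuming piiPatterns maps an entity label to a compiled regex and that PIIEntity carries a type, the matched value, and character offsets (the field names here are assumptions):

// detectWithPatterns runs the high-precision regexes over the raw text
// and converts each match into a PII entity. Illustrative sketch.
func (pd *PIIDetector) detectWithPatterns(text string) []PIIEntity {
	var entities []PIIEntity
	for label, pattern := range pd.piiPatterns {
		for _, loc := range pattern.FindAllStringIndex(text, -1) {
			entities = append(entities, PIIEntity{
				Type:  label,
				Value: text[loc[0]:loc[1]],
				Start: loc[0],
				End:   loc[1],
			})
		}
	}
	return entities
}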

Jailbreak Detection

type JailbreakGuard struct {
	classifier    *ModernBERTBinaryClassifier
	patterns      []JailbreakPattern
	riskThreshold float64
}

func (jg *JailbreakGuard) AssessRisk(query string) (*SecurityAssessment, error) {
	// ML-based detection
	mlScore, err := jg.classifier.PredictRisk(query)
	if err != nil {
		return nil, err
	}

	// Pattern-based detection
	patternScore := jg.calculatePatternScore(query)

	// Combined risk score
	overallRisk := 0.7*mlScore + 0.3*patternScore

	return &SecurityAssessment{
		RiskScore:    overallRisk,
		IsJailbreak:  overallRisk > jg.riskThreshold,
		MLScore:      mlScore,
		PatternScore: patternScore,
		Reasoning:    jg.explainDecision(overallRisk, mlScore, patternScore),
	}, nil
}
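calculatePatternScore is also left undefined. One hedged possibility is to take the highest weight among matching patterns; the sketch assumes JailbreakPattern pairs a compiled regex with a weight in [0, 1] (both field names are assumptions):

// calculatePatternScore returns the largest weight among the patterns
// that match the query, or 0 if none match. Illustrative sketch.
func (jg *JailbreakGuard) calculatePatternScore(query string) float64 {
	var score float64
	for _, p := range jg.patterns {
		if p.Regex.MatchString(query) && p.Weight > score {
			score = p.Weight
		}
	}
	return score
}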

Performance Optimizations

Model Loading and Caching

type ModelManager struct {
	models     map[string]*LoadedModel
	modelLock  sync.RWMutex
	warmupPool sync.Pool
}

// Lazy loading with warming
func (mm *ModelManager) GetModel(modelName string) (*LoadedModel, error) {
	mm.modelLock.RLock()
	if model, exists := mm.models[modelName]; exists {
		mm.modelLock.RUnlock()
		return model, nil
	}
	mm.modelLock.RUnlock()

	// Upgrade to write lock
	mm.modelLock.Lock()
	defer mm.modelLock.Unlock()

	// Double-check pattern
	if model, exists := mm.models[modelName]; exists {
		return model, nil
	}

	// Load model
	model, err := mm.loadModel(modelName)
	if err != nil {
		return nil, err
	}

	// Warm up model
	go mm.warmupModel(model)

	mm.models[modelName] = model
	return model, nil
}
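warmupModel is referenced but not defined in this document. A minimal sketch under the assumption that LoadedModel exposes an Infer method (the method name and dummy input are illustrative):

// warmupModel runs one throwaway inference so the first real request
// does not pay one-time initialization costs. Illustrative sketch.
func (mm *ModelManager) warmupModel(model *LoadedModel) {
	// A short dummy prompt is enough to trigger weight loading and any
	// lazy kernel compilation.
	if _, err := model.Infer("warmup"); err != nil {
		log.WithError(err).Warn("model warmup failed")
	}
}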

Batch Processing

// BatchClassifier abstracts the model call used by flushBatch; the
// original snippet referenced bp.classifier without declaring the field,
// so this interface and the ClassificationResult type are assumptions.
type BatchClassifier interface {
	ProcessBatch(batch []ProcessingRequest) []ClassificationResult
}

type BatchProcessor struct {
	classifier   BatchClassifier
	batchSize    int
	batchTimeout time.Duration
	pendingBatch []ProcessingRequest
	batchMutex   sync.Mutex
	flushTimer   *time.Timer
}

func (bp *BatchProcessor) ProcessRequest(req ProcessingRequest) {
	bp.batchMutex.Lock()
	defer bp.batchMutex.Unlock()

	bp.pendingBatch = append(bp.pendingBatch, req)

	// Flush if batch is full
	if len(bp.pendingBatch) >= bp.batchSize {
		bp.flushBatch()
		return
	}

	// Set timer for timeout-based flushing; the callback re-acquires the
	// mutex because flushBatch expects it to be held
	if bp.flushTimer == nil {
		bp.flushTimer = time.AfterFunc(bp.batchTimeout, func() {
			bp.batchMutex.Lock()
			defer bp.batchMutex.Unlock()
			bp.flushTimer = nil
			bp.flushBatch()
		})
	}
}

// flushBatch must be called with batchMutex held
func (bp *BatchProcessor) flushBatch() {
	if len(bp.pendingBatch) == 0 {
		return
	}

	// Process entire batch together for better GPU utilization
	results := bp.classifier.ProcessBatch(bp.pendingBatch)

	// Distribute results back to individual requests
	for i, result := range results {
		bp.pendingBatch[i].ResultChannel <- result
	}

	// Reset batch
	bp.pendingBatch = bp.pendingBatch[:0]
	if bp.flushTimer != nil {
		bp.flushTimer.Stop()
		bp.flushTimer = nil
	}
}
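On the caller side, each ProcessingRequest carries a channel on which its batched result is delivered. A hedged usage sketch (ResultChannel appears in flushBatch above; the Query field and ClassificationResult type are assumptions):

// classifyViaBatch submits one request to the batch processor and blocks
// until the batch containing it has been processed. Illustrative sketch.
func classifyViaBatch(bp *BatchProcessor, query string) ClassificationResult {
	req := ProcessingRequest{
		Query:         query,
		ResultChannel: make(chan ClassificationResult, 1),
	}
	bp.ProcessRequest(req)
	return <-req.ResultChannel
}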

Monitoring and Observability

Request Tracing

type RequestTracer struct {
	spans map[string]*Span
	mutex sync.RWMutex
}

func (rt *RequestTracer) StartSpan(requestID, operation string) *Span {
	span := &Span{
		RequestID: requestID,
		Operation: operation,
		StartTime: time.Now(),
		Tags:      make(map[string]interface{}),
	}

	rt.mutex.Lock()
	rt.spans[requestID+":"+operation] = span
	rt.mutex.Unlock()

	return span
}

func (rt *RequestTracer) FinishSpan(span *Span) {
	span.EndTime = time.Now()
	span.Duration = span.EndTime.Sub(span.StartTime)

	// Log detailed timing information
	log.WithFields(log.Fields{
		"request_id": span.RequestID,
		"operation":  span.Operation,
		"duration":   span.Duration.Milliseconds(),
		"tags":       span.Tags,
	}).Info("Operation completed")

	rt.mutex.Lock()
	delete(rt.spans, span.RequestID+":"+span.Operation)
	rt.mutex.Unlock()
}
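A hedged usage sketch wrapping one pipeline stage in a span (tracedClassify is an illustrative caller, not part of the router's API):

// tracedClassify traces the classification stage of a single request.
func tracedClassify(rt *RequestTracer, cc *CategoryClassifier, requestID, query string) (*Classification, error) {
	span := rt.StartSpan(requestID, "classification")
	defer rt.FinishSpan(span)

	result, err := cc.ClassifyIntent(query)
	if err == nil {
		// Record the predicted category as a span tag.
		span.Tags["category"] = result.Category
	}
	return result, err
}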

Performance Metrics

// Detailed performance tracking
type PerformanceTracker struct {
	classificationLatency prometheus.Histogram
	cacheHitRatio         prometheus.Gauge
	securityCheckLatency  prometheus.Histogram
	routingAccuracy       *prometheus.GaugeVec // labeled by category
}

func (pt *PerformanceTracker) RecordClassification(
	category string,
	confidence float64,
	duration time.Duration,
) {
	pt.classificationLatency.Observe(duration.Seconds())

	// Track the latest classification confidence per category as a proxy
	// for routing accuracy
	accuracyMetric := pt.routingAccuracy.WithLabelValues(category)
	accuracyMetric.Set(confidence)
}
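These metrics must be constructed and registered before use. A minimal sketch using the prometheus client's promauto helpers (the metric names below are illustrative, not the router's actual metric names):

// newPerformanceTracker registers the tracker's metrics with the default
// Prometheus registry. Metric names here are illustrative.
func newPerformanceTracker() *PerformanceTracker {
	return &PerformanceTracker{
		classificationLatency: promauto.NewHistogram(prometheus.HistogramOpts{
			Name: "router_classification_latency_seconds",
			Help: "Time spent in category classification.",
		}),
		cacheHitRatio: promauto.NewGauge(prometheus.GaugeOpts{
			Name: "router_cache_hit_ratio",
			Help: "Fraction of requests served from the semantic cache.",
		}),
		securityCheckLatency: promauto.NewHistogram(prometheus.HistogramOpts{
			Name: "router_security_check_latency_seconds",
			Help: "Time spent in PII and jailbreak checks.",
		}),
		routingAccuracy: promauto.NewGaugeVec(prometheus.GaugeOpts{
			Name: "router_routing_confidence",
			Help: "Latest classification confidence per category.",
		}, []string{"category"}),
	}
}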

This implementation provides the foundation for intelligent, secure, and performant LLM routing. The next section covers Model Training, detailing how the classification models are developed and optimized.