claudish/src/middleware/manager.ts

180 lines
4.8 KiB
TypeScript

/**
* MiddlewareManager - Orchestrates model-specific middlewares
*
* Responsibilities:
* - Register middlewares
* - Filter active middlewares by model ID
* - Execute middleware chain in order
* - Handle errors gracefully (log and continue)
*/
import { log, isLoggingEnabled, logStructured } from "../logger.js";
import type {
ModelMiddleware,
RequestContext,
NonStreamingResponseContext,
StreamChunkContext,
} from "./types.js";
export class MiddlewareManager {
private middlewares: ModelMiddleware[] = [];
private initialized = false;
/**
* Register a middleware
* Middlewares execute in registration order
*/
register(middleware: ModelMiddleware): void {
this.middlewares.push(middleware);
if (isLoggingEnabled()) {
logStructured("Middleware Registered", {
name: middleware.name,
total: this.middlewares.length,
});
}
}
/**
* Initialize all middlewares (call onInit hooks)
* Should be called once when server starts
*/
async initialize(): Promise<void> {
if (this.initialized) {
log("[Middleware] Already initialized, skipping");
return;
}
log(`[Middleware] Initializing ${this.middlewares.length} middleware(s)...`);
for (const middleware of this.middlewares) {
if (middleware.onInit) {
try {
await middleware.onInit();
log(`[Middleware] ${middleware.name} initialized`);
} catch (error) {
log(`[Middleware] ERROR: ${middleware.name} initialization failed: ${error}`);
// Continue with other middlewares even if one fails
}
}
}
this.initialized = true;
log("[Middleware] Initialization complete");
}
/**
* Get active middlewares for a specific model
*/
private getActiveMiddlewares(modelId: string): ModelMiddleware[] {
return this.middlewares.filter((m) => m.shouldHandle(modelId));
}
/**
* Execute beforeRequest hooks for all active middlewares
*/
async beforeRequest(context: RequestContext): Promise<void> {
const active = this.getActiveMiddlewares(context.modelId);
if (active.length === 0) {
return; // No middlewares for this model
}
if (isLoggingEnabled()) {
logStructured("Middleware Chain (beforeRequest)", {
modelId: context.modelId,
middlewares: active.map((m) => m.name),
messageCount: context.messages.length,
});
}
for (const middleware of active) {
try {
await middleware.beforeRequest(context);
} catch (error) {
log(`[Middleware] ERROR in ${middleware.name}.beforeRequest: ${error}`);
// Continue with next middleware - don't let one failure break the chain
}
}
}
/**
* Execute afterResponse hooks for non-streaming responses
*/
async afterResponse(context: NonStreamingResponseContext): Promise<void> {
const active = this.getActiveMiddlewares(context.modelId);
if (active.length === 0) {
return;
}
if (isLoggingEnabled()) {
logStructured("Middleware Chain (afterResponse)", {
modelId: context.modelId,
middlewares: active.map((m) => m.name),
});
}
for (const middleware of active) {
if (middleware.afterResponse) {
try {
await middleware.afterResponse(context);
} catch (error) {
log(`[Middleware] ERROR in ${middleware.name}.afterResponse: ${error}`);
}
}
}
}
/**
* Execute afterStreamChunk hooks for each streaming chunk
*/
async afterStreamChunk(context: StreamChunkContext): Promise<void> {
const active = this.getActiveMiddlewares(context.modelId);
if (active.length === 0) {
return;
}
// Only log on first chunk to avoid spam
if (isLoggingEnabled() && !context.metadata.has("_middlewareLogged")) {
logStructured("Middleware Chain (afterStreamChunk)", {
modelId: context.modelId,
middlewares: active.map((m) => m.name),
});
context.metadata.set("_middlewareLogged", true);
}
for (const middleware of active) {
if (middleware.afterStreamChunk) {
try {
await middleware.afterStreamChunk(context);
} catch (error) {
log(`[Middleware] ERROR in ${middleware.name}.afterStreamChunk: ${error}`);
}
}
}
}
/**
* Execute afterStreamComplete hooks after streaming finishes
*/
async afterStreamComplete(modelId: string, metadata: Map<string, any>): Promise<void> {
const active = this.getActiveMiddlewares(modelId);
if (active.length === 0) {
return;
}
for (const middleware of active) {
if (middleware.afterStreamComplete) {
try {
await middleware.afterStreamComplete(metadata);
} catch (error) {
log(`[Middleware] ERROR in ${middleware.name}.afterStreamComplete: ${error}`);
}
}
}
}
}