claudish/src/handlers/zai-handler.ts

305 lines
13 KiB
TypeScript

import type { Context } from "hono";
import { writeFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import type { ModelHandler } from "./types.js";
import { AdapterManager } from "../adapters/adapter-manager.js";
import { transformOpenAIToClaude, removeUriFormat } from "../transform.js";
import { log, logStructured, isLoggingEnabled } from "../logger.js";
import { ZAI_API_URL } from "../config.js";
/**
 * Handler for Z.ai API requests.
 *
 * Z.ai uses an OpenAI-compatible API format, so an incoming Claude-style
 * request is converted into a chat-completions payload, sent upstream with
 * streaming enabled, and the upstream SSE stream is re-emitted as
 * Claude-style streaming events.
 */
export class ZaiHandler implements ModelHandler {
  private targetModel: string;
  private apiKey: string;
  private adapterManager: AdapterManager;
  private port: number;
  private sessionTotalCost = 0;
  private CONTEXT_WINDOW = 128000; // GLM-4 context window

  /**
   * @param targetModel Model identifier as requested by the client (may carry a `z-ai/` prefix).
   * @param apiKey Bearer token sent to the Z.ai API.
   * @param port Port number used to name the per-instance token file.
   */
  constructor(targetModel: string, apiKey: string, port: number) {
    this.targetModel = targetModel;
    this.apiKey = apiKey;
    this.port = port;
    this.adapterManager = new AdapterManager(targetModel);
  }

  /**
   * Convert z-ai/model-name to model-name for Z.ai API
   */
  private getZaiModelId(model: string): string {
    // Remove z-ai/ prefix if present
    const prefix = "z-ai/";
    return model.startsWith(prefix) ? model.slice(prefix.length) : model;
  }

  /**
   * Write the latest token usage to `claudish-tokens-<port>.json` in the OS
   * temp directory. Best effort: a failed write is logged and never thrown.
   */
  private writeTokenFile(input: number, output: number): void {
    try {
      const total = input + output;
      const leftPct = Math.max(
        0,
        Math.min(100, Math.round(((this.CONTEXT_WINDOW - total) / this.CONTEXT_WINDOW) * 100))
      );
      const data = {
        input_tokens: input,
        output_tokens: output,
        total_tokens: total,
        total_cost: this.sessionTotalCost,
        context_window: this.CONTEXT_WINDOW,
        context_left_percent: leftPct,
        updated_at: Date.now()
      };
      writeFileSync(join(tmpdir(), `claudish-tokens-${this.port}.json`), JSON.stringify(data), "utf-8");
    } catch (e) {
      // The token file is advisory only; record the failure instead of hiding it.
      logStructured(`Z.ai token file write failed`, { error: String(e) });
    }
  }

  /**
   * Translate the incoming payload, forward it to Z.ai's `/chat/completions`
   * endpoint with streaming enabled, and return the translated stream.
   *
   * A non-OK upstream response is returned as JSON `{ error }` with the
   * upstream status code.
   */
  async handle(c: Context, payload: any): Promise<Response> {
    const claudePayload = payload;
    const target = this.targetModel;
    const zaiModelId = this.getZaiModelId(target);
    logStructured(`Z.ai Request`, { targetModel: target, zaiModelId, originalModel: claudePayload.model });

    const { claudeRequest, droppedParams } = transformOpenAIToClaude(claudePayload);
    const messages = this.convertMessages(claudeRequest);
    const tools = this.convertTools(claudeRequest);

    const zaiPayload: any = {
      model: zaiModelId,
      messages,
      temperature: claudeRequest.temperature ?? 1,
      stream: true,
      max_tokens: claudeRequest.max_tokens,
      tools: tools.length > 0 ? tools : undefined,
      stream_options: { include_usage: true }
    };

    if (claudeRequest.tool_choice) {
      const { type, name } = claudeRequest.tool_choice;
      if (type === 'tool' && name) zaiPayload.tool_choice = { type: 'function', function: { name } };
      else if (type === 'auto' || type === 'none') zaiPayload.tool_choice = type;
    }

    const adapter = this.adapterManager.getAdapter();
    if (typeof adapter.reset === 'function') adapter.reset();
    adapter.prepareRequest(zaiPayload, claudeRequest);

    const response = await fetch(`${ZAI_API_URL}/chat/completions`, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${this.apiKey}`,
      },
      body: JSON.stringify(zaiPayload)
    });

    if (!response.ok) return c.json({ error: await response.text() }, response.status as any);
    if (droppedParams.length > 0) c.header("X-Dropped-Params", droppedParams.join(", "));
    return this.handleStreamingResponse(c, response, adapter, target, claudeRequest);
  }

  /**
   * Build the chat-completions `messages` array from a Claude-style request:
   * an optional system message (identity-filtered) followed by the converted
   * user and assistant turns. Messages with any other role are skipped.
   */
  private convertMessages(req: any): any[] {
    const messages: any[] = [];
    if (req.system) {
      let content = Array.isArray(req.system)
        ? req.system.map((i: any) => i.text || i).join("\n\n")
        : req.system;
      content = this.filterIdentity(content);
      messages.push({ role: "system", content });
    }
    if (req.messages) {
      for (const msg of req.messages) {
        if (msg.role === "user") this.processUserMessage(msg, messages);
        else if (msg.role === "assistant") this.processAssistantMessage(msg, messages);
      }
    }
    return messages;
  }

  /**
   * Append the chat-completions equivalent of one user turn to `messages`.
   *
   * Tool results become separate `role: "tool"` messages (deduplicated by
   * `tool_use_id`) and are emitted before the remaining text/image parts.
   */
  private processUserMessage(msg: any, messages: any[]): void {
    if (!Array.isArray(msg.content)) {
      messages.push({ role: "user", content: msg.content });
      return;
    }

    const contentParts: any[] = [];
    const toolResults: any[] = [];
    const seen = new Set<string>();

    for (const block of msg.content) {
      if (block.type === "text") {
        contentParts.push({ type: "text", text: block.text });
      } else if (block.type === "image") {
        contentParts.push({
          type: "image_url",
          image_url: { url: `data:${block.source.media_type};base64,${block.source.data}` }
        });
      } else if (block.type === "tool_result") {
        if (seen.has(block.tool_use_id)) continue;
        seen.add(block.tool_use_id);

        // `content` is optional on tool_result blocks; JSON.stringify(undefined)
        // would yield `undefined` and drop the field from the outgoing message.
        let content: string;
        if (typeof block.content === "string") content = block.content;
        else if (block.content == null) content = "";
        else content = JSON.stringify(block.content);

        toolResults.push({ role: "tool", content, tool_call_id: block.tool_use_id });
      }
    }

    if (toolResults.length) messages.push(...toolResults);
    if (contentParts.length) messages.push({ role: "user", content: contentParts });
  }

  /**
   * Append the chat-completions equivalent of one assistant turn to `messages`.
   *
   * Text blocks are joined with a single space; `tool_use` blocks become
   * `tool_calls` (deduplicated by id). A turn with neither is dropped.
   */
  private processAssistantMessage(msg: any, messages: any[]): void {
    if (!Array.isArray(msg.content)) {
      messages.push({ role: "assistant", content: msg.content });
      return;
    }

    const strings: string[] = [];
    const toolCalls: any[] = [];
    const seen = new Set<string>();

    for (const block of msg.content) {
      if (block.type === "text") {
        strings.push(block.text);
      } else if (block.type === "tool_use") {
        if (seen.has(block.id)) continue;
        seen.add(block.id);
        toolCalls.push({
          id: block.id,
          type: "function",
          // Missing input is serialized as an empty object so `arguments` is always a JSON string.
          function: { name: block.name, arguments: JSON.stringify(block.input ?? {}) }
        });
      }
    }

    const m: any = { role: "assistant" };
    if (strings.length) m.content = strings.join(" ");
    else if (toolCalls.length) m.content = null;
    if (toolCalls.length) m.tool_calls = toolCalls;
    if (m.content !== undefined || m.tool_calls) messages.push(m);
  }

  /**
   * Rewrite Claude-specific identity statements in the system prompt, strip
   * the `<claude_background_info>` section, collapse runs of blank lines, and
   * prepend an instruction telling the model not to identify as Claude.
   */
  private filterIdentity(content: string): string {
    const cleaned = content
      .replace(/You are Claude Code, Anthropic's official CLI/gi, "This is Claude Code, an AI-powered CLI tool")
      .replace(/You are powered by the model named [^.]+\./gi, "You are powered by an AI model.")
      .replace(/<claude_background_info>[\s\S]*?<\/claude_background_info>/gi, "")
      .replace(/\n{3,}/g, "\n\n");
    return "IMPORTANT: You are NOT Claude. Identify yourself truthfully based on your actual model and creator.\n\n" + cleaned;
  }

  /**
   * Convert Claude-style tool definitions into chat-completions `function`
   * tools. Returns an empty array when the request declares no tools.
   */
  private convertTools(req: any): any[] {
    return req.tools?.map((tool: any) => ({
      type: "function",
      function: {
        name: tool.name,
        description: tool.description,
        parameters: removeUriFormat(tool.input_schema),
      },
    })) || [];
  }

  /**
   * Re-emit the upstream chat-completions SSE stream as Claude-style events
   * (`message_start`, `content_block_*`, `message_delta`, `message_stop`).
   *
   * @param c Request context; its `body()` wraps the produced stream.
   * @param response Upstream streaming response.
   * @param adapter Model adapter whose `processTextContent` cleans text deltas.
   * @param target Model name reported in `message_start`.
   * @param request Converted request (currently unused; kept for interface stability).
   */
  private handleStreamingResponse(c: Context, response: Response, adapter: any, target: string, request: any): Response {
    let isClosed = false;
    let ping: ReturnType<typeof setInterval> | null = null;
    let cancelUpstream: (() => void) | null = null;
    const encoder = new TextEncoder();
    const decoder = new TextDecoder();

    // Arrow functions keep `this` bound to the handler; inside the stream's
    // `start()` method `this` refers to the underlying-source object instead.
    const persistUsage = (input: number, output: number): void => this.writeTokenFile(input, output);
    const stopPing = (): void => {
      if (ping) {
        clearInterval(ping);
        ping = null;
      }
    };

    return c.body(new ReadableStream({
      async start(controller) {
        const send = (e: string, d: any): void => {
          if (isClosed) return;
          controller.enqueue(encoder.encode(`event: ${e}\ndata: ${JSON.stringify(d)}\n\n`));
        };
        const msgId = `msg_${Date.now()}_${Math.random().toString(36).slice(2)}`;

        // State
        let usage: any = null;
        let finishReason: string | null = null;
        let finalized = false;
        let textStarted = false;
        let textIdx = -1;
        let curIdx = 0;
        const tools = new Map<number, any>();
        let lastActivity = Date.now();

        const closeTextBlock = (): void => {
          if (!textStarted) return;
          send("content_block_stop", { type: "content_block_stop", index: textIdx });
          textStarted = false;
        };
        const closeToolBlocks = (): void => {
          for (const t of tools.values()) {
            if (t.started && !t.closed) {
              send("content_block_stop", { type: "content_block_stop", index: t.blockIndex });
              t.closed = true;
            }
          }
        };

        send("message_start", {
          type: "message_start",
          message: {
            id: msgId,
            type: "message",
            role: "assistant",
            content: [],
            model: target,
            stop_reason: null,
            stop_sequence: null,
            // Placeholder values: real usage is only known once upstream reports it.
            usage: { input_tokens: 100, output_tokens: 1 }
          }
        });
        send("ping", { type: "ping" });
        ping = setInterval(() => {
          if (!isClosed && Date.now() - lastActivity > 1000) send("ping", { type: "ping" });
        }, 1000);

        // Map the upstream finish reason onto Claude stop reasons.
        const resolveStopReason = (): string => {
          if (finishReason === "length") return "max_tokens";
          if (finishReason === "tool_calls" || tools.size > 0) return "tool_use";
          return "end_turn";
        };

        const finalize = (reason: string, err?: string): void => {
          if (finalized) return;
          finalized = true;
          closeTextBlock();
          closeToolBlocks();
          if (reason === "error") {
            send("error", { type: "error", error: { type: "api_error", message: err } });
          } else {
            send("message_delta", {
              type: "message_delta",
              delta: { stop_reason: resolveStopReason(), stop_sequence: null },
              usage: { output_tokens: usage?.completion_tokens || 0 }
            });
            send("message_stop", { type: "message_stop" });
          }
          if (!isClosed) {
            try { controller.enqueue(encoder.encode('data: [DONE]\n\n\n')); } catch (e) {}
            controller.close();
            isClosed = true;
          }
          stopPing();
          // Persist the upstream-reported usage so the token file reflects this turn.
          if (usage) persistUsage(usage.prompt_tokens || 0, usage.completion_tokens || 0);
        };

        const handleChunk = (chunk: any): void => {
          if (chunk.usage) usage = chunk.usage;
          const choice = chunk.choices?.[0];
          const delta = choice?.delta;
          if (delta) {
            // Z.ai uses reasoning_content for GLM models, fallback to content
            const txt = delta.content || delta.reasoning_content || "";
            if (txt) {
              lastActivity = Date.now();
              if (!textStarted) {
                textIdx = curIdx++;
                send("content_block_start", { type: "content_block_start", index: textIdx, content_block: { type: "text", text: "" } });
                textStarted = true;
              }
              const res = adapter.processTextContent(txt, "");
              if (res.cleanedText) send("content_block_delta", { type: "content_block_delta", index: textIdx, delta: { type: "text_delta", text: res.cleanedText } });
            }
            if (delta.tool_calls) {
              for (const tc of delta.tool_calls) {
                // Fall back to slot 0 when the tool-call delta carries no index.
                const idx: number = tc.index ?? 0;
                let t = tools.get(idx);
                if (tc.function?.name) {
                  if (!t) {
                    // A tool block starts a new content block, so close any open text block first.
                    closeTextBlock();
                    t = { id: tc.id || `tool_${Date.now()}_${idx}`, name: tc.function.name, blockIndex: curIdx++, started: false, closed: false };
                    tools.set(idx, t);
                  }
                  if (!t.started) {
                    send("content_block_start", { type: "content_block_start", index: t.blockIndex, content_block: { type: "tool_use", id: t.id, name: t.name } });
                    t.started = true;
                  }
                }
                if (tc.function?.arguments && t) {
                  send("content_block_delta", { type: "content_block_delta", index: t.blockIndex, delta: { type: "input_json_delta", partial_json: tc.function.arguments } });
                }
              }
            }
          }
          if (choice?.finish_reason) finishReason = choice.finish_reason;
          if (choice?.finish_reason === "tool_calls") closeToolBlocks();
        };

        /** Process one SSE line; returns true when the `[DONE]` sentinel is seen. */
        const processLine = (rawLine: string): boolean => {
          // trim() also strips the trailing "\r" of CRLF-terminated lines.
          const line = rawLine.trim();
          if (!line.startsWith("data:")) return false;
          const dataStr = line.slice(5).trim();
          if (!dataStr) return false;
          if (dataStr === "[DONE]") return true;
          try {
            handleChunk(JSON.parse(dataStr));
          } catch (e) {
            // Malformed or unprocessable chunk: skip it and keep streaming.
          }
          return false;
        };

        try {
          if (!response.body) throw new Error("Z.ai response has no body");
          const reader = response.body.getReader();
          cancelUpstream = () => { reader.cancel().catch(() => {}); };

          let buffer = "";
          let sawDone = false;
          while (!sawDone) {
            const { done, value } = await reader.read();
            if (done) break;
            buffer += decoder.decode(value, { stream: true });
            const lines = buffer.split("\n");
            // The last element is a partial line (or ""); keep it for the next read.
            buffer = lines.pop() ?? "";
            for (const line of lines) {
              if (processLine(line)) { sawDone = true; break; }
            }
          }
          if (!sawDone) {
            // Flush the decoder and handle a final line that had no trailing newline.
            buffer += decoder.decode();
            if (buffer && processLine(buffer)) sawDone = true;
          }
          finalize(sawDone ? "done" : "unexpected");
        } catch (e) {
          finalize("error", String(e));
        }
      },
      cancel() {
        // Client went away: stop pinging and stop pulling from upstream.
        isClosed = true;
        stopPing();
        if (cancelUpstream) cancelUpstream();
      }
    }), { headers: { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", "Connection": "keep-alive" } });
  }

  /** No resources to release. */
  async shutdown() {}
}