package server

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path"
	"strings"
	"time"

	"codelab/ai-agent/internal/agent"
	"codelab/ai-agent/internal/log"
	"github.com/gin-gonic/gin"
	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

// Package-level settings used by ChatRequest.Chat when talking to the
// Google generative AI backend.
var (
	// GoogleAIKey is the API key handed to genai.NewClient.
	GoogleAIKey string
	// GoogleAIModel is the model name handed to client.GenerativeModel.
	GoogleAIModel string
	// CandidateCount is the candidate count set in the generation config.
	CandidateCount = int32(1)
	// Temperature is the temperature set in the generation config.
	Temperature = float32(0.8)
)

// OpenAIMessage is a single chat message in the OpenAI wire format
// ("role" / "content" JSON keys).
type OpenAIMessage struct {
	Role string `json:"role"`
	Msg  string `json:"content"`
}

// ChatRequest carries everything Chat needs to answer one chat request.
type ChatRequest struct {
	// OpenAIMessages is the conversation so far; the last entry is the
	// message that gets sent to the model, the rest become chat history.
	OpenAIMessages []*OpenAIMessage
	// Agent supplies the system prompt, the initial prompt and the agent ID
	// used for the log file name.
	Agent *agent.Agent
	// ClientIP and ClientUserAgent are copied verbatim into the message log.
	ClientIP        string
	ClientUserAgent string
	// ResWriter receives the streamed "data: ..." events.
	ResWriter gin.ResponseWriter
	// LogPath is the directory holding the per-agent .ndjson message logs.
	// An empty value disables message logging.
	LogPath string
}

// MessageLogEntry is one line of the per-agent .ndjson message log written
// at the end of Chat.
type MessageLogEntry struct {
	Timestamp       time.Time            `json:"timestamp"`
	UserMsg         string               `json:"userMsg"`
	ModelMsg        string               `json:"modelMsg"`
	Feedback        genai.PromptFeedback `json:"feedback"`
	Usage           genai.UsageMetadata  `json:"usage"`
	ClientIP        string               `json:"clientIP"`
	ClientUserAgent string               `json:"clientUserAgent"`
}

// openaiToGoogle converts OpenAI-style messages into genai contents.
//
// Messages with role "system" are dropped (the system prompt is supplied
// separately through the model's SystemInstruction), nil entries are skipped,
// and the OpenAI role "assistant" is translated to "model", which is the name
// the Gemini API uses for model turns. Every other role is passed through
// unchanged.
func openaiToGoogle(msgs []*OpenAIMessage) []*genai.Content {
	contents := make([]*genai.Content, 0, len(msgs))
	for _, msg := range msgs {
		if msg == nil || msg.Role == "system" {
			continue
		}
		role := msg.Role
		if role == "assistant" {
			role = "model"
		}
		contents = append(contents, &genai.Content{
			Role:  role,
			Parts: []genai.Part{genai.Text(msg.Msg)},
		})
	}
	return contents
}

// Chat sends the conversation to the configured Google AI model, streams the
// answer to r.ResWriter as "data: <json>\n\n" events, then emits one event
// with the usage metadata and a final "data: [DONE]" event. Afterwards the
// exchange is appended to "<LogPath>/<AgentID>.ndjson".
//
// It returns an error when the request has no agent, when the client cannot
// be created, or when the response stream fails. Failures while writing to
// the response or to the message log are only logged; Chat still returns nil
// in those cases.
func (r *ChatRequest) Chat() error {
	if r.Agent == nil {
		return fmt.Errorf("chat request has no agent")
	}

	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(GoogleAIKey))
	if err != nil {
		return err
	}
	defer client.Close()

	model := client.GenerativeModel(GoogleAIModel)
	model.SafetySettings = []*genai.SafetySetting{
		{
			Category:  genai.HarmCategoryDangerousContent,
			Threshold: genai.HarmBlockOnlyHigh,
		},
		{
			Category:  genai.HarmCategoryHateSpeech,
			Threshold: genai.HarmBlockOnlyHigh,
		},
		{
			Category:  genai.HarmCategorySexuallyExplicit,
			Threshold: genai.HarmBlockOnlyHigh,
		},
		{
			Category:  genai.HarmCategoryHarassment,
			Threshold: genai.HarmBlockNone,
		},
	}
	model.GenerationConfig = genai.GenerationConfig{
		CandidateCount: &CandidateCount,
		Temperature:    &Temperature,
	}
	model.SystemInstruction = genai.NewUserContent(genai.Text(r.Agent.SystemPrompt))

	// The conversation always starts with the agent's initial prompt. The
	// last element is the message to send; everything before it is history.
	chat := model.StartChat()
	msgs := append([]*genai.Content{
		{
			Role:  "user",
			Parts: []genai.Part{genai.Text(r.Agent.InitialPrompt)},
		},
	}, openaiToGoogle(r.OpenAIMessages)...)
	chat.History = msgs[:len(msgs)-1]
	iter := chat.SendMessageStream(ctx, msgs[len(msgs)-1].Parts...)

	iterCount := 0
	for {
		iterCount++
		log.T("server/ai").Dbgf("AI response #%d", iterCount)
		res, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			// Stop here: there is no response chunk to forward, and
			// continuing would dereference a nil response.
			log.T("server/ai").Errf("Failed to get response: %v", err)
			return fmt.Errorf("receiving AI response: %w", err)
		}
		r.writeResp(res)
	}
	log.T("server/ai").Dbgf("AI Response took %d iterations", iterCount)

	// The merged response is nil when the stream ended without any chunk.
	res := iter.MergedResponse()
	if res == nil {
		res = &genai.GenerateContentResponse{}
	}

	var answer strings.Builder
	if len(res.Candidates) > 0 && res.Candidates[0] != nil && res.Candidates[0].Content != nil {
		for _, part := range res.Candidates[0].Content.Parts {
			answer.WriteString(fmt.Sprint(part))
		}
	}

	feedback := res.PromptFeedback
	if feedback == nil {
		log.T("server/ai").Errf("Server response feedback is nil")
		feedback = &genai.PromptFeedback{}
	}
	usage := res.UsageMetadata
	if usage == nil {
		log.T("server/ai").Errf("Server response usage is nil")
		usage = &genai.UsageMetadata{}
	}
	log.T("server/ai").Dbgf("usage: %+v", usage)

	usageJSON, err := json.Marshal(usage)
	if err != nil {
		log.T("server/ai").Errf("Failed to marshal usage data: %v", err)
	} else {
		_, err = fmt.Fprintf(r.ResWriter, "data: %s\n\n", usageJSON)
		r.ResWriter.Flush()
		if err != nil {
			log.T("server/ai").Errf("Failed to write usage data: %v", err)
		}
	}

	// Terminate the event with a blank line like every other event in the
	// stream; an unterminated trailing event is dropped by spec-compliant
	// server-sent-events parsers when the connection closes.
	_, err = r.ResWriter.Write([]byte("data: [DONE]\n\n"))
	r.ResWriter.Flush()
	if err != nil {
		log.T("server/ai").Errf("Failed to write done message: %v", err)
	}

	if !r.EnsureLogPath() {
		log.T("server/ai").Errf("Failed to ensure log path exists")
		return nil
	}

	// With no client messages only the initial prompt was sent, so there is
	// no user message to record.
	userMsg := ""
	if n := len(r.OpenAIMessages); n > 0 && r.OpenAIMessages[n-1] != nil {
		userMsg = r.OpenAIMessages[n-1].Msg
	}
	msgLogEntry := MessageLogEntry{
		Timestamp:       time.Now(),
		UserMsg:         userMsg,
		ModelMsg:        answer.String(),
		Feedback:        *feedback,
		Usage:           *usage,
		ClientIP:        r.ClientIP,
		ClientUserAgent: r.ClientUserAgent,
	}
	logEntryJSON, err := json.Marshal(msgLogEntry)
	if err != nil {
		log.T("server/ai").Errf("Failed to marshal log entry: %v", err)
		return nil
	}

	msgLogPath := path.Join(r.LogPath, fmt.Sprintf("%s.ndjson", r.Agent.AgentID))
	msgLogFile, err := os.OpenFile(msgLogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		log.T("server/ai").Errf("Failed to open log file: %v", err)
		return nil
	}
	defer func() {
		if cerr := msgLogFile.Close(); cerr != nil {
			log.T("server/ai").Errf("Failed to close log file: %v", cerr)
		}
	}()

	if _, err = msgLogFile.Write(append(logEntryJSON, '\n')); err != nil {
		log.T("server/ai").Errf("Failed to write log entry: %v", err)
	}
	return nil
}

// writeResp forwards every part of the first candidate in res to the client
// as its own `data: {"content": ...}` event, flushing after each one.
// Responses without a usable first candidate are ignored. Marshal and write
// failures are logged and do not stop the remaining parts.
func (r *ChatRequest) writeResp(res *genai.GenerateContentResponse) {
	if res == nil || len(res.Candidates) == 0 {
		return
	}
	candidate := res.Candidates[0]
	if candidate == nil || candidate.Content == nil {
		return
	}

	for i, part := range candidate.Content.Parts {
		log.T("server/ai").Dbgf(" - AI response part #%d: %s", i+1, part)

		contentJSON, err := json.Marshal(gin.H{"content": fmt.Sprint(part)})
		if err != nil {
			log.T("server/ai").Errf("Failed to marshal response: %v", err)
			continue
		}

		_, err = fmt.Fprintf(r.ResWriter, "data: %s\n\n", contentJSON)
		r.ResWriter.Flush()
		if err != nil {
			log.T("server/ai").Errf("Failed to write response: %v", err)
		}
	}
}

// EnsureLogPath reports whether r.LogPath is usable as a log directory,
// creating it (including parents) when it does not exist yet. It returns
// false when LogPath is empty, cannot be inspected or created, or exists but
// is not a directory.
func (r *ChatRequest) EnsureLogPath() bool {
	if r.LogPath == "" {
		return false
	}

	info, err := os.Stat(r.LogPath)
	switch {
	case err == nil:
		if !info.IsDir() {
			log.T("server/ai").Errf("Log path is not a directory: %s", r.LogPath)
			return false
		}
		return true
	case os.IsNotExist(err):
		if mkErr := os.MkdirAll(r.LogPath, 0755); mkErr != nil {
			log.T("server/ai").Errf("Failed to create log path: %v", mkErr)
			return false
		}
		return true
	default:
		log.T("server/ai").Errf("Failed to stat log path: %v", err)
		return false
	}
}