package server

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path"
	"time"

	"codelab/ai-agent/internal/agent"
	"codelab/ai-agent/internal/log"

	"github.com/gin-gonic/gin"
	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)
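
// Package-level configuration for the Google Generative AI backend, set at startup.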
var (
	GoogleAIKey    string
	GoogleAIModel  string
	CandidateCount = int32(1)
	Temperature    = float32(0.8)
)
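
// OpenAIMessage is a single chat message in the OpenAI wire format sent by the client.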
type OpenAIMessage struct {
	Role string `json:"role"`
	Msg  string `json:"content"`
}
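
// ChatRequest carries everything needed to serve one streaming chat turn: the
// conversation so far, the agent to run it against, client metadata for
// logging, and the gin response writer used for server-sent events.
//
// A minimal usage sketch from a gin handler, assuming the handler has already
// parsed the OpenAI-style messages (msgs) and looked up the agent (ag); the
// log directory shown here is hypothetical:
//
//	req := &ChatRequest{
//		OpenAIMessages:  msgs,
//		Agent:           ag,
//		ClientIP:        c.ClientIP(),
//		ClientUserAgent: c.Request.UserAgent(),
//		ResWriter:       c.Writer,
//		LogPath:         "/var/log/ai-agent",
//	}
//	if err := req.Chat(); err != nil {
//		log.T("server/ai").Errf("chat failed: %v", err)
//	}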
type ChatRequest struct {
	OpenAIMessages  []*OpenAIMessage
	Agent           *agent.Agent
	ClientIP        string
	ClientUserAgent string
	ResWriter       gin.ResponseWriter
	LogPath         string
}
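
// MessageLogEntry is one NDJSON record appended to the per-agent message log
// after each completed chat turn.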
type MessageLogEntry struct {
	Timestamp       time.Time            `json:"timestamp"`
	UserMsg         string               `json:"userMsg"`
	ModelMsg        string               `json:"modelMsg"`
	Feedback        genai.PromptFeedback `json:"feedback"`
	Usage           genai.UsageMetadata  `json:"usage"`
	ClientIP        string               `json:"clientIP"`
	ClientUserAgent string               `json:"clientUserAgent"`
}
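
// openaiToGoogle converts OpenAI-style messages into genai.Content values,
// dropping "system" messages (the system prompt is delivered separately via
// SystemInstruction) and mapping the "assistant" role to Gemini's "model".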
func openaiToGoogle(msgs []*OpenAIMessage) []*genai.Content {
	g := []*genai.Content{}
	for _, msg := range msgs {
		if msg.Role == "system" {
			continue
		}
		role := msg.Role
		// The Gemini API uses "model" where the OpenAI format uses "assistant".
		if role == "assistant" {
			role = "model"
		}
		g = append(g, &genai.Content{
			Role: role,
			Parts: []genai.Part{
				genai.Text(msg.Msg),
			},
		})
	}
	return g
}
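
// Chat streams a Gemini response for the latest client message as server-sent
// events on ResWriter, then appends a MessageLogEntry to the agent's NDJSON log.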
func (r *ChatRequest) Chat() error {
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(GoogleAIKey))
	if err != nil {
		return err
	}
	defer client.Close()

	model := client.GenerativeModel(GoogleAIModel)
	model.SafetySettings = []*genai.SafetySetting{
		{
			Category:  genai.HarmCategoryDangerousContent,
			Threshold: genai.HarmBlockOnlyHigh,
		},
		{
			Category:  genai.HarmCategoryHateSpeech,
			Threshold: genai.HarmBlockOnlyHigh,
		},
		{
			Category:  genai.HarmCategorySexuallyExplicit,
			Threshold: genai.HarmBlockOnlyHigh,
		},
		{
			Category:  genai.HarmCategoryHarassment,
			Threshold: genai.HarmBlockNone,
		},
	}
	model.GenerationConfig = genai.GenerationConfig{
		CandidateCount: &CandidateCount,
		Temperature:    &Temperature,
	}
	model.SystemInstruction = genai.NewUserContent(genai.Text(r.Agent.SystemPrompt))

	chat := model.StartChat()
	// Prepend the agent's initial prompt, use everything except the newest
	// message as chat history, and stream the newest message to the model.
	msgs := append([]*genai.Content{
		{
			Role: "user",
			Parts: []genai.Part{
				genai.Text(r.Agent.InitialPrompt),
			},
		},
	}, openaiToGoogle(r.OpenAIMessages)...)
	chat.History = msgs[:len(msgs)-1]
	iter := chat.SendMessageStream(ctx, msgs[len(msgs)-1].Parts...)

	iterCount := 0
	for {
		iterCount++
		log.T("server/ai").Dbgf("AI response #%d", iterCount)
		res, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			// res is nil when Next fails, so stop streaming rather than pass it on.
			log.T("server/ai").Errf("Failed to get response: %v", err)
			break
		}
		r.writeResp(res)
	}
	log.T("server/ai").Dbgf("AI Response took %d iterations", iterCount)

	res := iter.MergedResponse()
	answer := ""
	if len(res.Candidates) > 0 {
		for _, part := range res.Candidates[0].Content.Parts {
			answer += fmt.Sprint(part)
		}
	} else {
		answer = "<no response>"
	}

	feedback := res.PromptFeedback
	if feedback == nil {
		log.T("server/ai").Errf("Server response feedback is nil")
		feedback = &genai.PromptFeedback{}
	}
	usage := res.UsageMetadata
	if usage == nil {
		log.T("server/ai").Errf("Server response usage is nil")
		usage = &genai.UsageMetadata{}
	}

log.T("server/ai").Dbgf("usage: %+v", usage)
usageJSON, err := json.Marshal(usage)
if err != nil {
log.T("server/ai").Errf("Failed to marshal usage data: %v", err)
} else {
_, err = r.ResWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", usageJSON)))
r.ResWriter.Flush()
if err != nil {
log.T("server/ai").Errf("Failed to write usage data: %v", err)
}
}
_, err = r.ResWriter.Write([]byte("data: [DONE]"))
r.ResWriter.Flush()
if err != nil {
log.T("server/ai").Errf("Failed to write done message: %v", err)
}
	ok := r.EnsureLogPath()
	if !ok {
		log.T("server/ai").Errf("Failed to ensure log path exists")
		return nil
	}
	msgLogEntry := MessageLogEntry{
		Timestamp:       time.Now(),
		UserMsg:         r.OpenAIMessages[len(r.OpenAIMessages)-1].Msg,
		ModelMsg:        answer,
		Feedback:        *feedback,
		Usage:           *usage,
		ClientIP:        r.ClientIP,
		ClientUserAgent: r.ClientUserAgent,
	}
	logEntryJSON, err := json.Marshal(msgLogEntry)
	if err != nil {
		log.T("server/ai").Errf("Failed to marshal log entry: %v", err)
		return nil
	}
	msgLogPath := path.Join(r.LogPath, fmt.Sprintf("%s.ndjson", r.Agent.AgentID))
	msgLogFile, err := os.OpenFile(msgLogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		log.T("server/ai").Errf("Failed to open log file: %v", err)
		return nil
	}
	defer msgLogFile.Close()
	_, err = msgLogFile.Write([]byte(fmt.Sprintf("%s\n", logEntryJSON)))
	if err != nil {
		log.T("server/ai").Errf("Failed to write log entry: %v", err)
	}
	return nil
}
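
// writeResp streams the parts of one partial model response to the client as
// server-sent events of the form `data: {"content": "..."}`.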
func (req *ChatRequest) writeResp(res *genai.GenerateContentResponse) {
	// Guard against responses without candidates before indexing into them.
	if len(res.Candidates) == 0 || res.Candidates[0] == nil {
		return
	}
	answer := res.Candidates[0]
	for i, part := range answer.Content.Parts {
		log.T("server/ai").Dbgf(" - AI response part #%d: %s", i+1, part)
		content := fmt.Sprint(part)
		contentJSON, err := json.Marshal(gin.H{"content": content})
		if err != nil {
			log.T("server/ai").Errf("Failed to marshal response: %v", err)
			continue
		}
		_, err = req.ResWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", contentJSON)))
		req.ResWriter.Flush()
		if err != nil {
			log.T("server/ai").Errf("Failed to write response: %v", err)
		}
	}
}
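
// EnsureLogPath verifies that LogPath is a directory, creating it if it does
// not exist. It returns false when the path is empty, cannot be created or
// statted, or exists but is not a directory.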
func (r *ChatRequest) EnsureLogPath() bool {
	if r.LogPath == "" {
		return false
	}
	stat, err := os.Stat(r.LogPath)
	if err != nil && os.IsNotExist(err) {
		if err := os.MkdirAll(r.LogPath, 0755); err != nil {
			log.T("server/ai").Errf("Failed to create log path: %v", err)
			return false
		}
	} else if err != nil {
		log.T("server/ai").Errf("Failed to stat log path: %v", err)
		return false
	} else if !stat.IsDir() {
		log.T("server/ai").Errf("Log path is not a directory: %s", r.LogPath)
		return false
	}
	return true
}