248 lines
6.1 KiB
Go
248 lines
6.1 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path"
|
|
"time"
|
|
|
|
"codelab/ai-agent/internal/agent"
|
|
"codelab/ai-agent/internal/log"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/generative-ai-go/genai"
|
|
"google.golang.org/api/iterator"
|
|
"google.golang.org/api/option"
|
|
)
|
|
|
|
// Google AI settings shared by every ChatRequest. GoogleAIKey and
// GoogleAIModel have no default here; presumably they are assigned during
// startup (e.g. from configuration) — confirm against the caller.
var (
	// GoogleAIKey is the API key handed to genai.NewClient.
	GoogleAIKey string
	// GoogleAIModel is the name of the generative model used for chat.
	GoogleAIModel string
	// CandidateCount is the number of response candidates requested.
	CandidateCount int32 = 1
	// Temperature is the sampling temperature sent with each request.
	Temperature float32 = 0.8
)
|
|
|
|
// OpenAIMessage is a single chat message in the OpenAI chat format as sent by
// the client. The text is carried in Msg but serialized as "content" to match
// that wire format.
type OpenAIMessage struct {
	// Role identifies the speaker. "system" messages are recognized by
	// openaiToGoogle; other values are presumably "user"/"assistant".
	Role string `json:"role"`
	// Msg is the message text.
	Msg string `json:"content"`
}
|
|
|
|
type ChatRequest struct {
|
|
OpenAIMessages []*OpenAIMessage
|
|
Agent *agent.Agent
|
|
ClientIP string
|
|
ClientUserAgent string
|
|
ResWriter gin.ResponseWriter
|
|
LogPath string
|
|
}
|
|
|
|
type MessageLogEntry struct {
|
|
Timestamp time.Time `json:"timestamp"`
|
|
UserMsg string `json:"userMsg"`
|
|
ModelMsg string `json:"modelMsg"`
|
|
Feedback genai.PromptFeedback `json:"feedback"`
|
|
Usage genai.UsageMetadata `json:"usage"`
|
|
ClientIP string `json:"clientIP"`
|
|
ClientUserAgent string `json:"clientUserAgent"`
|
|
}
|
|
|
|
func openaiToGoogle(msgs []*OpenAIMessage) []*genai.Content {
|
|
g := []*genai.Content{}
|
|
for _, msg := range msgs {
|
|
if msg.Role == "system" {
|
|
continue
|
|
}
|
|
g = append(g, &genai.Content{
|
|
Role: msg.Role,
|
|
Parts: []genai.Part{
|
|
genai.Text(msg.Msg),
|
|
},
|
|
})
|
|
}
|
|
return g
|
|
}
|
|
|
|
func (r *ChatRequest) Chat() error {
|
|
ctx := context.Background()
|
|
client, err := genai.NewClient(ctx, option.WithAPIKey(GoogleAIKey))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer client.Close()
|
|
model := client.GenerativeModel(GoogleAIModel)
|
|
model.SafetySettings = []*genai.SafetySetting{
|
|
{
|
|
Category: genai.HarmCategoryDangerousContent,
|
|
Threshold: genai.HarmBlockOnlyHigh,
|
|
},
|
|
{
|
|
Category: genai.HarmCategoryHateSpeech,
|
|
Threshold: genai.HarmBlockOnlyHigh,
|
|
},
|
|
{
|
|
Category: genai.HarmCategorySexuallyExplicit,
|
|
Threshold: genai.HarmBlockOnlyHigh,
|
|
},
|
|
{
|
|
Category: genai.HarmCategoryHarassment,
|
|
Threshold: genai.HarmBlockNone,
|
|
},
|
|
}
|
|
model.GenerationConfig = genai.GenerationConfig{
|
|
CandidateCount: &CandidateCount,
|
|
Temperature: &Temperature,
|
|
}
|
|
model.SystemInstruction = genai.NewUserContent(genai.Text(r.Agent.SystemPrompt))
|
|
model.GenerateContentStream(ctx)
|
|
|
|
chat := model.StartChat()
|
|
msgs := append([]*genai.Content{
|
|
{
|
|
Role: "user",
|
|
Parts: []genai.Part{
|
|
genai.Text(r.Agent.InitialPrompt),
|
|
},
|
|
},
|
|
}, openaiToGoogle(r.OpenAIMessages)...)
|
|
chat.History = msgs[:len(msgs)-1]
|
|
|
|
iter := chat.SendMessageStream(ctx, msgs[len(msgs)-1].Parts...)
|
|
iterConut := 0
|
|
for {
|
|
iterConut++
|
|
log.T("server/ai").Dbgf("AI response #%d", iterConut)
|
|
res, err := iter.Next()
|
|
if err == iterator.Done {
|
|
break
|
|
}
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to get response: %v", err)
|
|
}
|
|
r.writeResp(res)
|
|
}
|
|
log.T("server/ai").Dbgf("AI Response took %d iterations", iterConut)
|
|
|
|
res := iter.MergedResponse()
|
|
answer := ""
|
|
if len(res.Candidates) > 0 {
|
|
for _, part := range res.Candidates[0].Content.Parts {
|
|
answer += fmt.Sprint(part)
|
|
}
|
|
} else {
|
|
answer = "<no response>"
|
|
}
|
|
|
|
feedback := res.PromptFeedback
|
|
if feedback == nil {
|
|
log.T("server/ai").Errf("Server response feedback is nil")
|
|
feedback = &genai.PromptFeedback{}
|
|
}
|
|
|
|
usage := res.UsageMetadata
|
|
if usage == nil {
|
|
log.T("server/ai").Errf("Server response usage is nil")
|
|
usage = &genai.UsageMetadata{}
|
|
}
|
|
|
|
log.T("server/ai").Dbgf("usage: %+v", usage)
|
|
usageJSON, err := json.Marshal(usage)
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to marshal usage data: %v", err)
|
|
} else {
|
|
_, err = r.ResWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", usageJSON)))
|
|
r.ResWriter.Flush()
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to write usage data: %v", err)
|
|
}
|
|
}
|
|
_, err = r.ResWriter.Write([]byte("data: [DONE]"))
|
|
r.ResWriter.Flush()
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to write done message: %v", err)
|
|
}
|
|
|
|
ok := r.EnsureLogPath()
|
|
if !ok {
|
|
log.T("server/ai").Errf("Failed to ensure log path exists")
|
|
return nil
|
|
}
|
|
|
|
msgLogEntry := MessageLogEntry{
|
|
Timestamp: time.Now(),
|
|
UserMsg: r.OpenAIMessages[len(r.OpenAIMessages)-1].Msg,
|
|
ModelMsg: answer,
|
|
Feedback: *feedback,
|
|
Usage: *usage,
|
|
ClientIP: r.ClientIP,
|
|
ClientUserAgent: r.ClientUserAgent,
|
|
}
|
|
logEntryJSON, err := json.Marshal(msgLogEntry)
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to marshal log entry: %v", err)
|
|
return nil
|
|
}
|
|
msgLogPath := path.Join(r.LogPath, fmt.Sprintf("%s.ndjson", r.Agent.AgentID))
|
|
msgLogFile, err := os.OpenFile(msgLogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to open log file: %v", err)
|
|
return nil
|
|
}
|
|
defer msgLogFile.Close()
|
|
_, err = msgLogFile.Write([]byte(fmt.Sprintf("%s\n", logEntryJSON)))
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to write log entry: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (req *ChatRequest) writeResp(res *genai.GenerateContentResponse) {
|
|
fullAnswerContent := ""
|
|
answer := res.Candidates[0]
|
|
if answer == nil {
|
|
return
|
|
}
|
|
for i, part := range answer.Content.Parts {
|
|
log.T("server/ai").Dbgf(" - AI response part #%d: %s", i+1, part)
|
|
content := fmt.Sprint(part)
|
|
fullAnswerContent += content
|
|
contentJSON, err := json.Marshal(gin.H{"content": content})
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to marshal response: %v", err)
|
|
continue
|
|
}
|
|
_, err = req.ResWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", contentJSON)))
|
|
req.ResWriter.Flush()
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to write response: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *ChatRequest) EnsureLogPath() bool {
|
|
if r.LogPath == "" {
|
|
return false
|
|
}
|
|
stat, err := os.Stat(r.LogPath)
|
|
if err != nil && os.IsNotExist(err) {
|
|
err = os.MkdirAll(r.LogPath, 0755)
|
|
if err != nil {
|
|
log.T("server/ai").Errf("Failed to create log path: %v", err)
|
|
return false
|
|
}
|
|
} else if err != nil {
|
|
log.T("server/ai").Errf("Failed to stat log path: %v", err)
|
|
return false
|
|
} else {
|
|
if !stat.IsDir() {
|
|
log.T("server/ai").Errf("Log path is not a directory: %s", r.LogPath)
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|