134 lines
3.7 KiB
Go
134 lines
3.7 KiB
Go
package workflows
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"tutor/internal/llm"
|
|
)
|
|
|
|
type LLMRunner struct {
|
|
client *llm.Client
|
|
}
|
|
|
|
func NewLLMRunner(client *llm.Client) *LLMRunner {
|
|
return &LLMRunner{client: client}
|
|
}
|
|
|
|
func (r *LLMRunner) DiagnoseJobSeeker(ctx context.Context, input DiagnosticInput) (DiagnosticResult, error) {
|
|
raw, err := r.client.ChatJSON(ctx, diagnoseSystemPrompt(), diagnoseUserPrompt(input), true)
|
|
if err != nil {
|
|
return DiagnosticResult{}, fmt.Errorf("diagnose_job_seeker: %w", err)
|
|
}
|
|
|
|
var result DiagnosticResult
|
|
if err := extractJSON(raw, &result); err != nil {
|
|
return DiagnosticResult{}, fmt.Errorf("diagnose_job_seeker parse: %w", err)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (r *LLMRunner) GradeInterviewAnswer(ctx context.Context, input GradeAnswerInput) (GradedAnswer, error) {
|
|
raw, err := r.client.ChatJSON(ctx, gradeAnswerSystemPrompt(), gradeAnswerUserPrompt(input), true)
|
|
if err != nil {
|
|
return GradedAnswer{}, fmt.Errorf("grade_interview_answer: %w", err)
|
|
}
|
|
|
|
var result GradedAnswer
|
|
if err := extractJSON(raw, &result); err != nil {
|
|
return GradedAnswer{}, fmt.Errorf("grade_interview_answer parse: %w", err)
|
|
}
|
|
|
|
result.UserID = input.UserID
|
|
result.AnswerID = input.AnswerID
|
|
result.QuestionID = input.QuestionID
|
|
return result, nil
|
|
}
|
|
|
|
func (r *LLMRunner) ExtractLearningMemory(ctx context.Context, grade GradedAnswer) (MemoryUpdateCandidate, error) {
|
|
raw, err := r.client.ChatJSON(ctx, extractMemorySystemPrompt(), extractMemoryUserPrompt(grade), true)
|
|
if err != nil {
|
|
return MemoryUpdateCandidate{}, fmt.Errorf("extract_learning_memory: %w", err)
|
|
}
|
|
|
|
candidate := MemoryUpdateCandidate{
|
|
UserID: grade.UserID,
|
|
SourceAnswerID: grade.AnswerID,
|
|
}
|
|
if err := extractJSON(raw, &candidate); err != nil {
|
|
return MemoryUpdateCandidate{}, fmt.Errorf("extract_learning_memory parse: %w", err)
|
|
}
|
|
return candidate, nil
|
|
}
|
|
|
|
func (r *LLMRunner) SelectNextChallenge(ctx context.Context, input NextChallengeInput) (NextChallenge, error) {
|
|
raw, err := r.client.ChatJSON(ctx, nextChallengeSystemPrompt(), nextChallengeUserPrompt("", ""), true)
|
|
if err != nil {
|
|
return NextChallenge{}, fmt.Errorf("select_next_challenge: %w", err)
|
|
}
|
|
|
|
var next NextChallenge
|
|
if err := extractJSON(raw, &next); err != nil {
|
|
return NextChallenge{}, fmt.Errorf("select_next_challenge parse: %w", err)
|
|
}
|
|
next.UserID = input.UserID
|
|
next.Track = input.Track
|
|
return next, nil
|
|
}
|
|
|
|
func (r *LLMRunner) UpdateReadinessMap(ctx context.Context, input ReadinessUpdateInput) (ReadinessUpdate, error) {
|
|
raw, err := r.client.ChatJSON(ctx, readinessUpdateSystemPrompt(), readinessUpdateUserPrompt(input), true)
|
|
if err != nil {
|
|
return ReadinessUpdate{}, fmt.Errorf("update_readiness_map: %w", err)
|
|
}
|
|
|
|
var update ReadinessUpdate
|
|
if err := extractJSON(raw, &update); err != nil {
|
|
return ReadinessUpdate{}, fmt.Errorf("update_readiness_map parse: %w", err)
|
|
}
|
|
update.UserID = input.UserID
|
|
update.Track = input.Track
|
|
return update, nil
|
|
}
|
|
|
|
func extractJSON(raw string, target any) error {
|
|
clean := strings.TrimSpace(raw)
|
|
if strings.HasPrefix(clean, "```") {
|
|
clean = stripCodeFences(clean)
|
|
}
|
|
if err := json.Unmarshal([]byte(clean), target); err != nil {
|
|
return fmt.Errorf("%w: %s", err, firstBytes(clean, 200))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var (
	// errCodeFence is a sentinel error whose message is "code fence".
	// NOTE(review): nothing in this file returns or compares against it.
	// Check other files in the package before removing it, and note that it is
	// currently this file's only use of the "errors" import.
	errCodeFence = errors.New("code fence")
)
|
|
|
|
// stripCodeFences returns the lines between the first fence line and the next
// fence line, joined with "\n". A fence line is a line whose whitespace-trimmed
// form starts with "```".
//
// If no closing fence follows, everything after the opening fence is returned.
// If input contains no fence line at all, it is returned unchanged.
func stripCodeFences(input string) string {
	rows := strings.Split(input, "\n")
	from, to := 0, len(rows)
	opened := false

	for idx := 0; idx < len(rows); idx++ {
		if !strings.HasPrefix(strings.TrimSpace(rows[idx]), "```") {
			continue
		}
		if !opened {
			// Opening fence: the body starts on the following line.
			opened = true
			from = idx + 1
			continue
		}
		// Closing fence: the body ends just before this line.
		to = idx
		break
	}
	return strings.Join(rows[from:to], "\n")
}
|
|
|
|
// firstBytes returns input unchanged when it is at most limit bytes long.
//
// Otherwise it returns the longest prefix of at most limit bytes that ends on a
// rune boundary, followed by "...". The snippet therefore never splits a
// multi-byte UTF-8 character (the previous input[:limit] could, leaving invalid
// UTF-8 in error messages).
func firstBytes(input string, limit int) string {
	if len(input) <= limit {
		return input
	}
	// Ranging over a string yields the byte offset at which each rune starts.
	// Keep the last such offset that still fits within limit.
	cut := 0
	for i := range input {
		if i > limit {
			break
		}
		cut = i
	}
	return input[:cut] + "..."
}
|