129 lines
3.2 KiB
Go
129 lines
3.2 KiB
Go
|
|
package auth
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/hmac"
|
||
|
|
"crypto/sha256"
|
||
|
|
"encoding/base64"
|
||
|
|
"encoding/json"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
||
|
|
"google.golang.org/api/idtoken"
|
||
|
|
)
|
||
|
|
|
||
|
|
type Service struct {
|
||
|
|
pool *pgxpool.Pool
|
||
|
|
googleClientID string
|
||
|
|
jwtSecret string
|
||
|
|
}
|
||
|
|
|
||
|
|
func NewService(pool *pgxpool.Pool, googleClientID, jwtSecret string) *Service {
|
||
|
|
return &Service{
|
||
|
|
pool: pool,
|
||
|
|
googleClientID: googleClientID,
|
||
|
|
jwtSecret: jwtSecret,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *Service) HandleGoogleLogin(ctx context.Context, req GoogleLoginRequest) (SessionResponse, error) {
|
||
|
|
if strings.TrimSpace(req.IDToken) == "" || s.googleClientID == "" {
|
||
|
|
return SessionResponse{}, errors.New("google id token is required")
|
||
|
|
}
|
||
|
|
|
||
|
|
payload, err := idtoken.Validate(ctx, req.IDToken, s.googleClientID)
|
||
|
|
if err != nil {
|
||
|
|
return SessionResponse{}, err
|
||
|
|
}
|
||
|
|
|
||
|
|
email := ""
|
||
|
|
if v, ok := payload.Claims["email"].(string); ok {
|
||
|
|
email = strings.ToLower(strings.TrimSpace(v))
|
||
|
|
}
|
||
|
|
name := ""
|
||
|
|
if v, ok := payload.Claims["name"].(string); ok {
|
||
|
|
name = strings.TrimSpace(v)
|
||
|
|
}
|
||
|
|
subject := payload.Subject
|
||
|
|
|
||
|
|
if email == "" {
|
||
|
|
return SessionResponse{}, errors.New("email not found in token")
|
||
|
|
}
|
||
|
|
if name == "" {
|
||
|
|
name = strings.Split(email, "@")[0]
|
||
|
|
}
|
||
|
|
|
||
|
|
user, err := s.upsertUser(ctx, email, name, subject)
|
||
|
|
if err != nil {
|
||
|
|
return SessionResponse{}, err
|
||
|
|
}
|
||
|
|
|
||
|
|
token, err := SignToken(s.jwtSecret, Claims{
|
||
|
|
UserID: user.ID,
|
||
|
|
Email: user.Email,
|
||
|
|
Exp: time.Now().Add(7 * 24 * time.Hour).Unix(),
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return SessionResponse{}, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return SessionResponse{Token: token, User: user}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *Service) upsertUser(ctx context.Context, email, displayName, subject string) (User, error) {
|
||
|
|
var user User
|
||
|
|
err := s.pool.QueryRow(ctx, `
|
||
|
|
INSERT INTO users (email, display_name, provider, provider_subject)
|
||
|
|
VALUES ($1, $2, 'google', $3)
|
||
|
|
ON CONFLICT (email) DO UPDATE SET
|
||
|
|
display_name = EXCLUDED.display_name,
|
||
|
|
provider_subject = EXCLUDED.provider_subject
|
||
|
|
RETURNING id, email, display_name
|
||
|
|
`, email, displayName, subject).Scan(&user.ID, &user.Email, &user.DisplayName)
|
||
|
|
if err != nil {
|
||
|
|
return User{}, fmt.Errorf("upsert user: %w", err)
|
||
|
|
}
|
||
|
|
return user, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func SignToken(secret string, claims Claims) (string, error) {
|
||
|
|
body, err := json.Marshal(claims)
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
payload := base64.RawURLEncoding.EncodeToString(body)
|
||
|
|
mac := hmac.New(sha256.New, []byte(secret))
|
||
|
|
_, _ = mac.Write([]byte(payload))
|
||
|
|
sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||
|
|
return payload + "." + sig, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func VerifyToken(secret, token string) (Claims, error) {
|
||
|
|
parts := strings.Split(token, ".")
|
||
|
|
if len(parts) != 2 {
|
||
|
|
return Claims{}, errors.New("invalid token")
|
||
|
|
}
|
||
|
|
mac := hmac.New(sha256.New, []byte(secret))
|
||
|
|
_, _ = mac.Write([]byte(parts[0]))
|
||
|
|
expected := mac.Sum(nil)
|
||
|
|
actual, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||
|
|
if err != nil || !hmac.Equal(expected, actual) {
|
||
|
|
return Claims{}, errors.New("invalid signature")
|
||
|
|
}
|
||
|
|
body, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||
|
|
if err != nil {
|
||
|
|
return Claims{}, err
|
||
|
|
}
|
||
|
|
var claims Claims
|
||
|
|
if err := json.Unmarshal(body, &claims); err != nil {
|
||
|
|
return Claims{}, err
|
||
|
|
}
|
||
|
|
if claims.Exp < time.Now().Unix() {
|
||
|
|
return Claims{}, fmt.Errorf("token expired")
|
||
|
|
}
|
||
|
|
return claims, nil
|
||
|
|
}
|