summaryrefslogtreecommitdiff
path: root/app/internal/utils/middleware.go
blob: 5133916b0f99c4c22932dd6a98461dc8fa332f50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package utils

import (
	"errors"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// LogMiddleware wraps next so that, once each request has been served, the
// request method, URL path, and wall-clock time spent in next (in seconds)
// are written to the standard logger.
func LogMiddleware(next http.Handler) http.Handler {
	timed := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began).Seconds()

		log.Printf("Request: %s %s took %f seconds", r.Method, r.URL.Path, elapsed)
	}
	return http.HandlerFunc(timed)
}

// AuthMiddleware rejects requests that do not carry a valid HMAC-signed JWT.
//
// It expects an "Authorization: Bearer <token>" header. The scheme name is
// matched case-insensitively. The token is verified with jwt.Parse, accepting
// only HMAC signing methods, using the JWT_SECRET environment variable (read
// on every request) as the key.
//
// Each of the following produces a 401 response via RespondWithError, and
// next is not called:
//   - a missing or malformed header;
//   - an unset JWT_SECRET;
//   - a token that fails verification;
//   - claims that are not jwt.MapClaims.
//
// On success the "user_id" claim is logged and the request is passed to next
// unchanged.
func AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			RespondWithError(w, http.StatusUnauthorized, "Missing authorization header")
			return
		}

		// Split "<scheme> <token>" on the first space instead of slicing off a
		// fixed-length prefix, so short headers cannot cause an out-of-range
		// panic and non-Bearer schemes are rejected rather than misparsed.
		scheme, tokenString, found := strings.Cut(authHeader, " ")
		tokenString = strings.TrimSpace(tokenString)
		if !found || !strings.EqualFold(scheme, "Bearer") || tokenString == "" {
			RespondWithError(w, http.StatusUnauthorized, "Invalid authorization header")
			return
		}

		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
			// Restrict to HMAC so the shared secret below is never used to
			// verify a token that claims a different algorithm.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, http.ErrNotSupported
			}
			secret := os.Getenv("JWT_SECRET")
			if secret == "" {
				// Never verify against an empty key; treat a missing secret as
				// a misconfiguration and reject the token.
				log.Print("AuthMiddleware: JWT_SECRET is not set; rejecting token")
				return nil, errors.New("JWT_SECRET is not set")
			}
			return []byte(secret), nil
		})
		if err != nil {
			RespondWithError(w, http.StatusUnauthorized, "Invalid token")
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok || !token.Valid {
			RespondWithError(w, http.StatusUnauthorized, "Invalid token claims")
			return
		}
		log.Printf("User ID: %v", claims["user_id"])

		next.ServeHTTP(w, r)
	})
}

// Chain combines several middlewares into a single middleware.
//
// Middlewares are applied in the order given, each one wrapping the result of
// the previous, so the LAST argument ends up outermost. Chain(a, b)(h) is
// equivalent to b(a(h)), meaning b sees an incoming request before a does.
// With no arguments, the returned middleware hands back its input unchanged.
func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
	return func(final http.Handler) http.Handler {
		wrapped := final
		for i := 0; i < len(middlewares); i++ {
			wrapped = middlewares[i](wrapped)
		}
		return wrapped
	}
}