package token

import (
	"crypto/rand"
	"encoding/base64"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
)

// Package-level signing state. InitJWT populates every field; until then they
// hold their zero values.
var (
	// accessKey and refreshKey are the HMAC secrets used to sign and verify
	// access tokens and refresh tokens respectively.
	accessKey  string
	refreshKey string

	// Token lifetimes, in hours, applied when each kind of token is issued.
	accessTokenLiveHours  int
	refreshTokenLiveHours int
)

// Claims is the JWT payload shared by access tokens and refresh tokens.
//
// Username and Version are application fields serialized as "username" and
// "version". The standard registered claims come from the embedded
// jwt.RegisteredClaims; this package's generators only set ExpiresAt.
// NOTE(review): Version is copied verbatim into every token and exposed to
// handlers; presumably it supports invalidating older tokens, so confirm
// against callers. Field order determines JSON order and thus the signed
// bytes, so do not reorder fields casually.
type Claims struct {
	Username string `json:"username"`
	Version  string `json:"version"`
	jwt.RegisteredClaims
}

// InitJWT prepares the package for issuing and validating tokens.
//
// It generates fresh 32-byte random HMAC secrets for access and refresh
// tokens, and it records how many hours each kind of token stays valid. Keys
// exist only in memory. Calling InitJWT again, or restarting the process,
// rotates both keys and invalidates every previously issued token.
//
// InitJWT panics if the system random source fails. Continuing with an empty
// signing key would let anyone forge tokens, and the function has no error
// return to report the failure through.
// NOTE(review): the package variables are written without locking, so this
// is presumably called once at startup before any handler runs.
func InitJWT(accessTokenHours int, refreshTokenHours int) {
	newAccessKey, err := generateRandomKey(32)
	if err != nil {
		panic("token: generating access token key: " + err.Error())
	}
	newRefreshKey, err := generateRandomKey(32)
	if err != nil {
		panic("token: generating refresh token key: " + err.Error())
	}

	// Publish keys only after both were generated successfully.
	accessKey, refreshKey = newAccessKey, newRefreshKey
	accessTokenLiveHours = accessTokenHours
	refreshTokenLiveHours = refreshTokenHours
}

// generateRandomKey draws length bytes from crypto/rand and returns them
// encoded as padded, URL-safe base64 text. On a read failure it returns ""
// together with the error from crypto/rand.
func generateRandomKey(length int) (string, error) {
	buf := make([]byte, length)
	_, err := rand.Read(buf)
	if err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(buf)
	return encoded, nil
}

func GenerateAccessToken(username string, version string) (string, error) {
	expirationTime := time.Now().Add(time.Duration(accessTokenLiveHours) * time.Hour)
	claims := &Claims{
		Username: username,
		Version:  version,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(expirationTime),
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte(accessKey))
	if err != nil {
		return "", err
	}
	return tokenString, nil
}

func GenerateRefreshToken(username string, version string) (string, error) {
	expirationTime := time.Now().Add(time.Duration(refreshTokenLiveHours) * time.Hour)
	claims := &Claims{
		Username: username,
		Version:  version,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(expirationTime),
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte(refreshKey))
	if err != nil {
		return "", err
	}
	return tokenString, nil
}

// ValidateAccessToken returns gin middleware that authenticates a request
// using an access token from the "Authorization: Bearer <token>" header.
//
// Tokens must be HS256-signed with the access key and unexpired. On success
// the middleware stores the "username" and "version" claims in the gin
// context and continues the handler chain. On failure it records an error via
// ctx.Error and aborts the chain so protected handlers never run.
// NOTE(review): no status code is written here, so turning the recorded error
// into a 401 response is left to an error-handling middleware; confirm one is
// installed.
func ValidateAccessToken() gin.HandlerFunc {
	const bearerPrefix = "Bearer "
	return func(ctx *gin.Context) {
		header := ctx.GetHeader("Authorization")
		// The auth scheme is case-insensitive (RFC 7235), so "bearer <t>" is accepted too.
		if len(header) < len(bearerPrefix) || !strings.EqualFold(header[:len(bearerPrefix)], bearerPrefix) {
			_ = ctx.Error(errors.New("authorization token required"))
			// Abort is essential: a gin middleware that merely returns lets the
			// remaining handlers (the protected route) run anyway.
			ctx.Abort()
			return
		}
		tokenString := header[len(bearerPrefix):]

		claims := &Claims{}
		jwtToken, err := jwt.ParseWithClaims(tokenString, claims, func(*jwt.Token) (interface{}, error) {
			// An empty key would verify tokens signed with an empty key, i.e. forgeries.
			if accessKey == "" {
				return nil, errors.New("access token key not initialized")
			}
			return []byte(accessKey), nil
		}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
		if err != nil || !jwtToken.Valid {
			_ = ctx.Error(errors.New("invalid token"))
			ctx.Abort()
			return
		}

		ctx.Set("username", claims.Username)
		ctx.Set("version", claims.Version)
		ctx.Next()
	}
}

// RefreshTokenHandler exchanges a refresh token for a new access token.
//
// It reads the refresh token from the "Authorization: Bearer <token>" header.
// The token must be HS256-signed with the refresh key and unexpired. On
// success the handler responds 200 with {"access_token": "<token>"}, where the
// new token carries the same username and version claims.
//
// On failure it records an error via ctx.Error and aborts the chain.
// NOTE(review): no status code is written here, so the response for the
// recorded error is left to an error-handling middleware; confirm one is
// installed.
func RefreshTokenHandler(ctx *gin.Context) {
	const bearerPrefix = "Bearer "
	header := ctx.GetHeader("Authorization")
	// The auth scheme is case-insensitive (RFC 7235), so "bearer <t>" is accepted too.
	if len(header) < len(bearerPrefix) || !strings.EqualFold(header[:len(bearerPrefix)], bearerPrefix) {
		_ = ctx.Error(errors.New("refresh token required"))
		ctx.Abort()
		return
	}
	refreshToken := header[len(bearerPrefix):]

	claims := &Claims{}
	token, err := jwt.ParseWithClaims(refreshToken, claims, func(*jwt.Token) (interface{}, error) {
		// An empty key would verify tokens signed with an empty key, i.e. forgeries.
		if refreshKey == "" {
			return nil, errors.New("refresh token key not initialized")
		}
		return []byte(refreshKey), nil
	}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
	if err != nil || !token.Valid {
		_ = ctx.Error(errors.New("invalid refresh token"))
		ctx.Abort()
		return
	}

	newAccessToken, err := GenerateAccessToken(claims.Username, claims.Version)
	if err != nil {
		_ = ctx.Error(errors.New("could not generate access token"))
		ctx.Abort()
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"access_token": newAccessToken})
}
