go-dev-zprox-0.01/internal/auth/auth.go
2025-03-22 08:57:23 +00:00

205 lines
5.1 KiB
Go

package auth
import (
	"crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"encoding/gob"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math/big"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt"
	"github.com/gookit/goutil/dump"
	"github.com/zeevdiukman/zprox/internal/config"
	"github.com/zeevdiukman/zprox/internal/session"
	"golang.org/x/oauth2"
)
func init() {
	// Give the concrete types stable gob names. Presumably this lets them be
	// stored in gob-encoded session data — NOTE(review): confirm against the
	// session package.
	registrations := []struct {
		name  string
		value interface{}
	}{
		{name: "oauth2_token_pointer", value: &oauth2.Token{}},
		{name: "rsa_public_key_pointer", value: &rsa.PublicKey{}},
	}
	for _, reg := range registrations {
		gob.RegisterName(reg.name, reg.value)
	}
}
// TokenResponse is the JSON body returned by an OAuth2 / OpenID Connect
// token endpoint. The field tags give the exact JSON member names. The field
// order is kept as-is because it determines the key order when encoding.
type TokenResponse struct {
	// AccessToken is the issued access token (a JWT in this package's use).
	AccessToken string `json:"access_token"`
	// ExpiresIn is the access token lifetime (RFC 6749 defines it in seconds).
	ExpiresIn int `json:"expires_in"`
	// RefreshToken can be exchanged for a new access token.
	RefreshToken string `json:"refresh_token"`
	// RefreshExpiresIn is the refresh token lifetime as reported by the server.
	RefreshExpiresIn int `json:"refresh_expires_in"`
	// TokenType is the token type reported by the server (e.g. "Bearer").
	TokenType string `json:"token_type"`
	// NotBeforePolicy is the server's "not-before-policy" member.
	NotBeforePolicy int `json:"not-before-policy"`
	// SessionState is the server's "session_state" member.
	SessionState string `json:"session_state"`
	// Scope is the space-separated list of granted scopes.
	Scope string `json:"scope"`
}
// JWKS is a JSON Web Key Set document: the list of keys published under the
// top-level "keys" member (RFC 7517 §5).
type JWKS struct {
	Keys []JSONWebKeys `json:"keys"`
}

// JSONWebKeys is one JSON Web Key from a JWKS. Only the members below are
// decoded, and every value is kept in its raw string form.
type JSONWebKeys struct {
	Kty string   `json:"kty"` // key type, e.g. "RSA"
	Kid string   `json:"kid"` // key ID; matched against the JWT header's "kid"
	Use string   `json:"use"` // intended key use as published by the server
	N   string   `json:"n"`   // RSA modulus, base64url-encoded
	E   string   `json:"e"`   // RSA public exponent, base64url-encoded
	X5c []string `json:"x5c"` // X.509 certificate chain, if published
}
func GetPublicKey(kid string, jwksURL string) (*rsa.PublicKey, error) {
resp, err := http.Get(jwksURL)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var jwks JWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
dump.P("json.NewDecoder")
return nil, err
}
key, err := ValidateJWTissuerSignatrue(kid, jwks)
if err != nil {
return nil, err
}
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
return nil, err
}
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
return nil, err
}
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
publicKey := &rsa.PublicKey{
N: n,
E: int(e.Int64()),
}
return publicKey, nil
}
// ValidateJWTissuerSignatrue returns the first key in jwks whose Kid equals
// kid.
//
// Despite its name, it performs no signature or issuer validation: it only
// selects a key by ID. If no key matches, it returns a zero JSONWebKeys and
// an error naming the kid.
func ValidateJWTissuerSignatrue(kid string, jwks JWKS) (JSONWebKeys, error) {
	for i := range jwks.Keys {
		if jwks.Keys[i].Kid != kid {
			continue
		}
		return jwks.Keys[i], nil
	}
	return JSONWebKeys{}, errors.New("public key not found for kid: " + kid)
}
// FetchKeycloakPublicKey returns the RSA public key that should verify
// oauth2Token's access token. It reads the "kid" from the token's header
// (GetKidFromJWT) and looks that key up in the JWKS served at jwksURL
// (GetPublicKey).
//
// Failures are printed (via dump for the kid lookup, via log for the JWKS
// fetch) and then returned unchanged. A nil token is reported as an error
// instead of causing a nil-pointer panic.
func FetchKeycloakPublicKey(oauth2Token *oauth2.Token, jwksURL string) (*rsa.PublicKey, error) {
	if oauth2Token == nil {
		return nil, errors.New("fetching public key: nil oauth2 token")
	}
	jwtKID, err := GetKidFromJWT(oauth2Token.AccessToken)
	if err != nil {
		dump.Println(err.Error())
		return nil, err
	}
	publicKey, err := GetPublicKey(jwtKID, jwksURL)
	if err != nil {
		log.Println("Error fetching public key:", err)
		return nil, err
	}
	return publicKey, nil
}
func GetJwtClaims(r *http.Request, publicKey *rsa.PublicKey, oauth2Token *oauth2.Token) (jwt.MapClaims, error) {
// authProvider := routerData.Auth.Provider
// jwksURL := config.Data.AuthMap[authProvider].OpenID.JwksURI
// Extract the username from the token
tokenString := oauth2Token.AccessToken
claims := jwt.MapClaims{}
_, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
return publicKey, nil
})
if err != nil {
err = errors.New("error parsing token: " + err.Error())
return nil, err
}
return claims, nil
}
// GetKidFromJWT extracts the "kid" (key ID) member from the header of a
// compact-serialized JWT. The token is NOT verified.
//
// Only the segment before the first '.' is examined; tokenString must contain
// at least one '.'. That segment must be unpadded base64url-encoded JSON whose
// "kid" member is a string. Decoding and unmarshalling errors are wrapped
// (%w), so callers can inspect them with errors.Is / errors.As.
func GetKidFromJWT(tokenString string) (string, error) {
	dot := strings.IndexByte(tokenString, '.')
	if dot < 0 {
		return "", fmt.Errorf("invalid JWT format")
	}
	headerJSON, err := base64.RawURLEncoding.DecodeString(tokenString[:dot])
	if err != nil {
		return "", fmt.Errorf("failed to decode header: %w", err)
	}
	var header map[string]interface{}
	if err := json.Unmarshal(headerJSON, &header); err != nil {
		return "", fmt.Errorf("failed to unmarshal header: %w", err)
	}
	kid, ok := header["kid"].(string)
	if !ok {
		return "", fmt.Errorf("kid not found or not a string")
	}
	return kid, nil
}
// RedirectToLogin starts the OpenID Connect login flow for r.
//
// It does the following, in order:
//   - clears the current session and calls RenewToken on it;
//   - records the requested path under the "original_path" session key;
//   - redirects the client (307) to the authorization URL from
//     GetAuthCodeURL, which also stores "state" and "nonce" in the session,
//     or returns "/" if those cannot be generated.
//
// NOTE(review): any return values of Clear/RenewToken are discarded. A 307
// also makes the browser replay the original method and body to the
// authorization endpoint. Confirm both are intended.
func RedirectToLogin(authConfig *config.Auth, w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Start from an empty session before any login state is stored.
	session.Manager.Clear(ctx)
	session.Manager.RenewToken(ctx)

	session.Manager.Put(ctx, "original_path", r.URL.Path)
	loginURL := GetAuthCodeURL(authConfig, r)

	http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
}
func GetAuthCodeURL(authConfig *config.Auth, r *http.Request) string {
conf := &oauth2.Config{
ClientID: authConfig.OpenID.ClientID,
ClientSecret: authConfig.OpenID.ClientSecret,
RedirectURL: authConfig.OpenID.EndPoints.RedirectURI,
Scopes: []string{"openid", "email", "profile"},
Endpoint: oauth2.Endpoint{
AuthURL: authConfig.OpenID.EndPoints.AuthURL,
TokenURL: authConfig.OpenID.EndPoints.TokenURL,
},
}
state, err := generateState()
if err != nil {
return "/"
}
nonce, err := generateState()
if err != nil {
return "/"
}
// Adding options to the AuthCodeURL method
session.Manager.Put(r.Context(), "nonce", nonce)
session.Manager.Put(r.Context(), "state", state)
return conf.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce))
}
func generateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}