package auth import ( "crypto/rand" "crypto/rsa" "encoding/base64" "encoding/gob" "encoding/json" "errors" "fmt" "log" "math/big" "net/http" "strings" "github.com/golang-jwt/jwt" "github.com/gookit/goutil/dump" "github.com/zeevdiukman/zprox/internal/config" "github.com/zeevdiukman/zprox/internal/session" "golang.org/x/oauth2" ) func init() { gob.RegisterName("oauth2_token_pointer", &oauth2.Token{}) gob.RegisterName("rsa_public_key_pointer", &rsa.PublicKey{}) } type TokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token"` RefreshExpiresIn int `json:"refresh_expires_in"` TokenType string `json:"token_type"` NotBeforePolicy int `json:"not-before-policy"` SessionState string `json:"session_state"` Scope string `json:"scope"` } type JWKS struct { Keys []JSONWebKeys `json:"keys"` } type JSONWebKeys struct { Kty string `json:"kty"` Kid string `json:"kid"` Use string `json:"use"` N string `json:"n"` E string `json:"e"` X5c []string `json:"x5c"` } func GetPublicKey(kid string, jwksURL string) (*rsa.PublicKey, error) { resp, err := http.Get(jwksURL) if err != nil { return nil, err } defer resp.Body.Close() var jwks JWKS if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { dump.P("json.NewDecoder") return nil, err } key, err := ValidateJWTissuerSignatrue(kid, jwks) if err != nil { return nil, err } nBytes, err := base64.RawURLEncoding.DecodeString(key.N) if err != nil { return nil, err } eBytes, err := base64.RawURLEncoding.DecodeString(key.E) if err != nil { return nil, err } n := new(big.Int).SetBytes(nBytes) e := new(big.Int).SetBytes(eBytes) publicKey := &rsa.PublicKey{ N: n, E: int(e.Int64()), } return publicKey, nil } func ValidateJWTissuerSignatrue(kid string, jwks JWKS) (JSONWebKeys, error) { for _, key := range jwks.Keys { if key.Kid == kid { return key, nil } } err := errors.New("public key not found for kid: " + kid) return JSONWebKeys{}, err } func FetchKeycloakPublicKey(oauth2Token *oauth2.Token, jwksURL string) (*rsa.PublicKey, error) { jwtKID, err := GetKidFromJWT(oauth2Token.AccessToken) if err != nil { dump.Println(err.Error()) return nil, err } publicKey, err := GetPublicKey(jwtKID, jwksURL) if err != nil { log.Println("Error fetching public key:", err) return nil, err } return publicKey, nil } func GetJwtClaims(r *http.Request, publicKey *rsa.PublicKey, oauth2Token *oauth2.Token) (jwt.MapClaims, error) { // authProvider := routerData.Auth.Provider // jwksURL := config.Data.AuthMap[authProvider].OpenID.JwksURI // Extract the username from the token tokenString := oauth2Token.AccessToken claims := jwt.MapClaims{} _, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { return publicKey, nil }) if err != nil { err = errors.New("error parsing token: " + err.Error()) return nil, err } return claims, nil } func GetKidFromJWT(tokenString string) (string, error) { parts := strings.Split(tokenString, ".") if len(parts) < 2 { return "", fmt.Errorf("invalid JWT format") } headerBase64 := parts[0] headerJSON, err := base64.RawURLEncoding.DecodeString(headerBase64) if err != nil { return "", fmt.Errorf("failed to decode header: %v", err) } var header map[string]interface{} if err := json.Unmarshal(headerJSON, &header); err != nil { return "", fmt.Errorf("failed to unmarshal header: %v", err) } kid, ok := header["kid"].(string) if !ok { return "", fmt.Errorf("kid not found or not a string") } return kid, nil } func RedirectToLogin(authConfig *config.Auth, w http.ResponseWriter, r *http.Request) { session.Manager.Clear(r.Context()) session.Manager.RenewToken(r.Context()) u := GetAuthCodeURL(authConfig, r) // sessToken := session.Manager.Token(r.Context()) // sessCtx, err := session.Manager.Load(r.Context(), sessToken) // if err != nil { // log.Println(err.Error()) // } // session.Manager.Put(sessCtx, "original_path", r.URL.Path) session.Manager.Put(r.Context(), "original_path", r.URL.Path) http.Redirect(w, r, u, http.StatusTemporaryRedirect) } func GetAuthCodeURL(authConfig *config.Auth, r *http.Request) string { conf := &oauth2.Config{ ClientID: authConfig.OpenID.ClientID, ClientSecret: authConfig.OpenID.ClientSecret, RedirectURL: authConfig.OpenID.EndPoints.RedirectURI, Scopes: []string{"openid", "email", "profile"}, Endpoint: oauth2.Endpoint{ AuthURL: authConfig.OpenID.EndPoints.AuthURL, TokenURL: authConfig.OpenID.EndPoints.TokenURL, }, } state, err := generateState() if err != nil { return "/" } nonce, err := generateState() if err != nil { return "/" } // Adding options to the AuthCodeURL method session.Manager.Put(r.Context(), "nonce", nonce) session.Manager.Put(r.Context(), "state", state) return conf.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce)) } func generateState() (string, error) { b := make([]byte, 32) _, err := rand.Read(b) if err != nil { return "", err } return base64.URLEncoding.EncodeToString(b), nil }