205 lines
5.1 KiB
Go
205 lines
5.1 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/gob"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/gookit/goutil/dump"
|
|
"github.com/zeevdiukman/zprox/internal/config"
|
|
"github.com/zeevdiukman/zprox/internal/session"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
func init() {
	// gob can only encode a concrete type carried inside an interface value
	// once that type is registered; RegisterName pins the wire identifier to
	// an explicit string instead of one derived from the package path.
	// NOTE(review): presumably these are stored in gob-encoded session data —
	// confirm against the session package.
	registrations := []struct {
		name  string
		value interface{}
	}{
		{"oauth2_token_pointer", &oauth2.Token{}},
		{"rsa_public_key_pointer", &rsa.PublicKey{}},
	}
	for _, reg := range registrations {
		gob.RegisterName(reg.name, reg.value)
	}
}
|
|
|
|
type (
	// TokenResponse holds the JSON body returned by a token endpoint.
	// NOTE(review): the field set (not-before-policy, session_state,
	// refresh_expires_in) looks like a Keycloak token response — confirm.
	// ExpiresIn/RefreshExpiresIn are presumably seconds — confirm.
	TokenResponse struct {
		AccessToken      string `json:"access_token"`
		ExpiresIn        int    `json:"expires_in"`
		RefreshToken     string `json:"refresh_token"`
		RefreshExpiresIn int    `json:"refresh_expires_in"`
		TokenType        string `json:"token_type"`
		NotBeforePolicy  int    `json:"not-before-policy"`
		SessionState     string `json:"session_state"`
		Scope            string `json:"scope"`
	}

	// JWKS is a JSON Web Key Set document of the form {"keys": [...]}.
	JWKS struct {
		Keys []JSONWebKeys `json:"keys"`
	}

	// JSONWebKeys describes one key entry of a JWKS (despite the plural name
	// it is a single key). N and E are the base64url-encoded RSA modulus and
	// public exponent; X5c is the raw "x5c" certificate-chain member.
	JSONWebKeys struct {
		Kty string   `json:"kty"`
		Kid string   `json:"kid"`
		Use string   `json:"use"`
		N   string   `json:"n"`
		E   string   `json:"e"`
		X5c []string `json:"x5c"`
	}
)
|
|
|
|
func GetPublicKey(kid string, jwksURL string) (*rsa.PublicKey, error) {
|
|
|
|
resp, err := http.Get(jwksURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var jwks JWKS
|
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
|
dump.P("json.NewDecoder")
|
|
|
|
return nil, err
|
|
}
|
|
|
|
key, err := ValidateJWTissuerSignatrue(kid, jwks)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
n := new(big.Int).SetBytes(nBytes)
|
|
e := new(big.Int).SetBytes(eBytes)
|
|
|
|
publicKey := &rsa.PublicKey{
|
|
N: n,
|
|
E: int(e.Int64()),
|
|
}
|
|
return publicKey, nil
|
|
|
|
}
|
|
// ValidateJWTissuerSignatrue returns the first key in jwks whose Kid equals
// kid. Despite its name it performs no signature validation; it is a plain
// lookup by key ID. When nothing matches it returns the zero JSONWebKeys and
// an error reading "public key not found for kid: <kid>".
func ValidateJWTissuerSignatrue(kid string, jwks JWKS) (JSONWebKeys, error) {
	for i := range jwks.Keys {
		if candidate := jwks.Keys[i]; candidate.Kid == kid {
			return candidate, nil
		}
	}
	return JSONWebKeys{}, errors.New("public key not found for kid: " + kid)
}
|
|
// FetchKeycloakPublicKey returns the RSA public key referenced by the "kid"
// header of oauth2Token's access token, looked up in the JWKS document served
// at jwksURL (via GetKidFromJWT and GetPublicKey).
//
// A nil oauth2Token yields an error instead of a nil-pointer panic. Failures
// of the two lookup steps are printed and then returned unchanged.
func FetchKeycloakPublicKey(oauth2Token *oauth2.Token, jwksURL string) (*rsa.PublicKey, error) {
	if oauth2Token == nil {
		return nil, errors.New("nil oauth2 token")
	}

	jwtKID, err := GetKidFromJWT(oauth2Token.AccessToken)
	if err != nil {
		dump.Println(err.Error())
		return nil, err
	}

	publicKey, err := GetPublicKey(jwtKID, jwksURL)
	if err != nil {
		log.Println("Error fetching public key:", err)
		return nil, err
	}
	return publicKey, nil
}
|
|
func GetJwtClaims(r *http.Request, publicKey *rsa.PublicKey, oauth2Token *oauth2.Token) (jwt.MapClaims, error) {
|
|
|
|
// authProvider := routerData.Auth.Provider
|
|
// jwksURL := config.Data.AuthMap[authProvider].OpenID.JwksURI
|
|
|
|
// Extract the username from the token
|
|
tokenString := oauth2Token.AccessToken
|
|
claims := jwt.MapClaims{}
|
|
_, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
|
return publicKey, nil
|
|
})
|
|
if err != nil {
|
|
err = errors.New("error parsing token: " + err.Error())
|
|
return nil, err
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
// GetKidFromJWT returns the "kid" (key ID) header parameter of a compact JWT
// without verifying the token. Only the first dot-separated segment is
// decoded, as unpadded base64url JSON, and the token must contain at least
// one ".".
//
// An error is returned when the header cannot be decoded or parsed (the
// underlying error is wrapped with %w), or when "kid" is missing, not a
// string, or empty. An empty kid is rejected because it would match any
// JWKS entry that lacks a "kid".
func GetKidFromJWT(tokenString string) (string, error) {
	parts := strings.Split(tokenString, ".")
	if len(parts) < 2 {
		return "", fmt.Errorf("invalid JWT format")
	}

	headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("failed to decode header: %w", err)
	}

	var header map[string]interface{}
	if err := json.Unmarshal(headerJSON, &header); err != nil {
		return "", fmt.Errorf("failed to unmarshal header: %w", err)
	}

	kid, ok := header["kid"].(string)
	if !ok {
		return "", fmt.Errorf("kid not found or not a string")
	}
	if kid == "" {
		return "", fmt.Errorf("kid is empty")
	}

	return kid, nil
}
|
|
// RedirectToLogin starts a new login flow for r. The sequence is:
//  1. Call Clear and RenewToken on session.Manager.
//  2. Build the authorization URL with GetAuthCodeURL, which may write
//     "state" and "nonce" into the session.
//  3. Record r.URL.Path in the session under "original_path".
//  4. Answer with a 307 Temporary Redirect to the authorization URL.
//
// Return values of the session calls are not checked here.
func RedirectToLogin(authConfig *config.Auth, w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Reset the session before any new login state is written into it.
	session.Manager.Clear(ctx)
	session.Manager.RenewToken(ctx)

	loginURL := GetAuthCodeURL(authConfig, r)
	session.Manager.Put(ctx, "original_path", r.URL.Path)

	http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)
}
|
|
func GetAuthCodeURL(authConfig *config.Auth, r *http.Request) string {
|
|
conf := &oauth2.Config{
|
|
ClientID: authConfig.OpenID.ClientID,
|
|
ClientSecret: authConfig.OpenID.ClientSecret,
|
|
RedirectURL: authConfig.OpenID.EndPoints.RedirectURI,
|
|
Scopes: []string{"openid", "email", "profile"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: authConfig.OpenID.EndPoints.AuthURL,
|
|
TokenURL: authConfig.OpenID.EndPoints.TokenURL,
|
|
},
|
|
}
|
|
state, err := generateState()
|
|
if err != nil {
|
|
return "/"
|
|
}
|
|
nonce, err := generateState()
|
|
if err != nil {
|
|
return "/"
|
|
}
|
|
// Adding options to the AuthCodeURL method
|
|
session.Manager.Put(r.Context(), "nonce", nonce)
|
|
session.Manager.Put(r.Context(), "state", state)
|
|
|
|
return conf.AuthCodeURL(state, oauth2.SetAuthURLParam("nonce", nonce))
|
|
}
|
|
|
|
func generateState() (string, error) {
|
|
b := make([]byte, 32)
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64.URLEncoding.EncodeToString(b), nil
|
|
}
|