go-dev-zprox-0.01/internal/auth/handlers.go
2025-03-22 08:57:23 +00:00

255 lines
7.7 KiB
Go

package auth
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gookit/goutil/dump"
"github.com/zeevdiukman/go-helper"
"github.com/zeevdiukman/zprox/internal/config"
"github.com/zeevdiukman/zprox/internal/session"
"golang.org/x/oauth2"
)
// CallbackHandler returns the handler for the OpenID Connect authorization-code
// callback.
//
// For each request the handler:
//  1. requires non-empty "code" and "state" form values and checks "state"
//     against the "state" value stored in the session;
//  2. runs provider discovery against the configured issuer;
//  3. exchanges the code for an OAuth2 token;
//  4. fetches the signing key through FetchKeycloakPublicKey;
//  5. verifies the "id_token" and compares its nonce with the session "nonce";
//  6. renews the session token, stores the OAuth2 token and the public key in
//     the session, and redirects to the session's "original_path".
//
// Every failure path writes an explicit error status; none of them panics or
// falls through to an implicit 200 response.
//
// routerData is not used by this handler.
func CallbackHandler(authConfig *config.Auth, routerData *config.Router) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Validate the request parameters first so that a malformed callback
		// does not trigger any outbound call to the identity provider.
		code := r.FormValue("code")
		if code == "" {
			dump.P("No code provided")
			http.Error(w, "No code provided", http.StatusBadRequest)
			return
		}
		state := r.FormValue("state")
		if state == "" {
			dump.P("No state provided")
			http.Error(w, "No state provided", http.StatusBadRequest)
			return
		}
		if expectedState := session.Manager.GetString(r.Context(), "state"); state != expectedState {
			dump.P("State does not match the session state")
			http.Error(w, "Invalid state", http.StatusBadRequest)
			return
		}

		// NOTE(review): certificate verification is disabled for provider
		// discovery and key retrieval. Presumably this is for a development
		// setup with self-signed certificates — confirm before production use.
		tr := &http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		}
		insecureClient := &http.Client{Transport: tr}
		ctx := context.WithValue(config.Data.Context, oauth2.HTTPClient, insecureClient)

		provider, err := oidc.NewProvider(ctx, authConfig.OpenID.EndPoints.Issuer)
		if err != nil {
			// Discovery failure is an upstream problem; report it instead of
			// panicking inside the handler.
			dump.Println("Error creating OIDC provider:", err)
			http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
			return
		}
		verifier := provider.Verifier(&oidc.Config{
			ClientID: authConfig.OpenID.ClientID,
		})

		conf := &oauth2.Config{
			ClientID:     authConfig.OpenID.ClientID,
			ClientSecret: authConfig.OpenID.ClientSecret,
			RedirectURL:  authConfig.OpenID.EndPoints.RedirectURI,
			Scopes:       []string{"openid"},
			Endpoint: oauth2.Endpoint{
				AuthURL:  authConfig.OpenID.EndPoints.AuthURL,
				TokenURL: authConfig.OpenID.EndPoints.TokenURL,
			},
		}
		// NOTE(review): the exchange runs on the request context, so it does
		// not use the TLS-skipping client configured above — confirm intended.
		oauth2Token, err := conf.Exchange(r.Context(), code)
		if err != nil {
			// Without a token nothing below can proceed; continuing would
			// dereference a nil token.
			dump.Println(err.Error())
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}

		rsaPublicKey, err := FetchKeycloakPublicKey(oauth2Token, authConfig.OpenID.EndPoints.JwksURI)
		if err != nil {
			dump.Println("Error fetching public key:", err)
			status := http.StatusUnauthorized
			w.WriteHeader(status)
			w.Write([]byte(http.StatusText(status)))
			return
		}

		rawIDToken, ok := oauth2Token.Extra("id_token").(string)
		if !ok {
			http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
			return
		}
		idToken, err := verifier.Verify(r.Context(), rawIDToken)
		if err != nil {
			http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
			return
		}
		if nonce := session.Manager.GetString(r.Context(), "nonce"); nonce != idToken.Nonce {
			dump.P("ID token nonce is not as expected")
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}

		var claimsRaw json.RawMessage
		if err := idToken.Claims(&claimsRaw); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// Debug output of the decoded claims JSON.
		dump.P(string(claimsRaw))

		// Rotate the session token before storing credentials in the session.
		if err := session.Manager.RenewToken(r.Context()); err != nil {
			dump.Println("Error renewing session token:", err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}
		session.Manager.Put(r.Context(), "oauth2_token", oauth2Token)
		session.Manager.Put(r.Context(), "jwks_public_key", rsaPublicKey)

		originalPath := session.Manager.GetString(r.Context(), "original_path")
		if originalPath == "" {
			// An empty target would redirect back to this callback URL.
			originalPath = "/"
		}
		http.Redirect(w, r, originalPath, http.StatusTemporaryRedirect)
	}
}
// LoginHandler returns a handler that reads a JSON body carrying "username"
// and "password" and posts them to the configured token endpoint as a
// password grant (grant_type=password) together with the client credentials.
//
// Failure responses:
//   - 400 when the request body cannot be read or is not valid JSON;
//   - 500 when the outbound request cannot be constructed;
//   - 502 when the token endpoint is unreachable or its reply is unreadable
//     or not valid JSON;
//   - 401 when the token endpoint answers with a non-200 status.
//
// On success the handler writes no body, so the client receives an implicit
// 200. NOTE(review): the decoded token response is neither returned to the
// caller nor stored — confirm what the intended follow-up is.
//
// routerData is not used by this handler.
func LoginHandler(authConfig *config.Auth, routerData *config.Router) http.HandlerFunc {
	// Upper bound for the credentials payload; larger bodies are rejected.
	const maxBodyBytes = 1 << 20

	return func(w http.ResponseWriter, r *http.Request) {
		// NOTE(review): browsers reject a wildcard Allow-Origin combined with
		// Allow-Credentials: true, and "type" is not a standard header
		// (possibly meant to be Content-Type). Left unchanged — confirm.
		w.Header().Add("Access-Control-Allow-Origin", "*")
		w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
		w.Header().Add("Access-Control-Allow-Headers", "*")
		w.Header().Add("Access-Control-Allow-Credentials", "true")
		w.Header().Add("type", "application/json")

		defer r.Body.Close()
		b, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
		if err != nil {
			// A bad client request must not terminate the server process.
			log.Printf("Error reading body: %v", err)
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
			return
		}

		// Only the two fields that are read are decoded; other keys are ignored.
		var creds struct {
			Username string `json:"username"`
			Password string `json:"password"`
		}
		if err := json.Unmarshal(b, &creds); err != nil {
			fmt.Println("Error parsing request JSON:", err)
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
			return
		}

		kcData := url.Values{}
		kcData.Set("grant_type", "password")
		kcData.Set("client_id", authConfig.OpenID.ClientID)
		kcData.Set("client_secret", authConfig.OpenID.ClientSecret)
		kcData.Set("username", creds.Username)
		kcData.Set("password", creds.Password)

		client := helper.HttpClientWithSkipVerify()
		// Bind the outbound call to the request context so it is cancelled
		// when the client disconnects.
		clientReq, err := http.NewRequestWithContext(
			r.Context(),
			http.MethodPost,
			authConfig.OpenID.EndPoints.TokenURL,
			strings.NewReader(kcData.Encode()),
		)
		if err != nil {
			fmt.Println("Error creating request:", err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}
		clientReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")

		resp, err := client.Do(clientReq)
		if err != nil {
			fmt.Println("Error sending request:", err)
			http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			fmt.Println("Error reading response:", err)
			http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
			return
		}
		if resp.StatusCode != http.StatusOK {
			fmt.Printf("Error: Status code %d, Response: %s\n", resp.StatusCode, string(respBody))
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}

		var tokenResponse TokenResponse
		if err := json.Unmarshal(respBody, &tokenResponse); err != nil {
			fmt.Println("Error unmarshaling JSON:", err)
			http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
			return
		}
	}
}
// LogoutHandler returns a handler that ends the local session and redirects
// the client to the configured logout URL.
//
// A request whose session does not hold an "oauth2_token" value of type
// *oauth2.Token is answered with 401. Otherwise the logout URL is parsed, the
// session data is cleared, the session token is renewed, and the client is
// sent to the logout URL with a 307 redirect. A failure in any of those steps
// is logged and answered with 500.
//
// NOTE(review): no query parameters (client_id, post_logout_redirect_uri, ...)
// are appended to the logout URL — confirm whether the provider needs them.
//
// routerData is not used by this handler.
func LogoutHandler(authConfig *config.Auth, routerData *config.Router) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if _, ok := session.Manager.Get(r.Context(), "oauth2_token").(*oauth2.Token); !ok {
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte(http.StatusText(http.StatusUnauthorized)))
			return
		}

		// Parse before touching the session: a bad URL would otherwise leave
		// a nil *url.URL to be dereferenced at redirect time.
		u, err := url.Parse(authConfig.OpenID.EndPoints.LogoutURL)
		if err != nil {
			log.Println("Error parsing logout URL:", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		if err := session.Manager.Clear(r.Context()); err != nil {
			log.Println("Error Clearing session:", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		if err := session.Manager.RenewToken(r.Context()); err != nil {
			log.Println("Error RenewToken session:", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
	}
}
// ContextKey is a named string type.
// NOTE(review): presumably intended as the key type for context.WithValue so
// that keys do not collide with plain strings — confirm against callers.
type ContextKey string
// PostLogoutHandler returns a handler that clears the session data and then
// sends the client to the configured logout URL with a 307 redirect.
//
// If clearing the session fails, the error is logged and the request is
// answered with 500 instead of a redirect.
//
// routerData is not used by this handler.
func PostLogoutHandler(authConfig *config.Auth, routerData *config.Router) http.HandlerFunc {
	handle := func(w http.ResponseWriter, r *http.Request) {
		if clearErr := session.Manager.Clear(r.Context()); clearErr != nil {
			log.Println("Error Clearing session:", clearErr)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		// The target is read per request, after the session has been cleared.
		target := authConfig.OpenID.EndPoints.LogoutURL
		http.Redirect(w, r, target, http.StatusTemporaryRedirect)
	}
	return http.HandlerFunc(handle)
}