This commit is contained in:
Zeev Diukman 2025-03-02 18:16:57 +00:00
parent fbaf9393ea
commit 5a6eed8c57
5 changed files with 162 additions and 88 deletions

96
cmd/server/handlers.go Normal file
View file

@ -0,0 +1,96 @@
package main
import (
"net/http"
"net/url"
"github.com/gookit/goutil/dump"
"zeevdiukman.com/zprox/internal/config"
"zeevdiukman.com/zprox/pkg/helper"
)
func CallbackHandler(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
// ctx := context.Background()
query := r.URL.Query()
code := query.Get("code")
state := query.Get("state")
verifier := app.SessionManager.GetString(r.Context(), "code_verifier")
if verifier == "" {
http.Error(w, "Code verifier not found in session", http.StatusBadRequest)
return
}
expectedState := app.SessionManager.GetString(r.Context(), "state")
if state != expectedState {
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
return
}
// originalURL, err := decodeState(state)
// if err != nil {
// dump.P(err.Error())
// http.Error(w, "Invalid state", http.StatusBadRequest)
// return
// }
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
authName := c.GetAuthNameByDomain(r.Host)
token, fullResponse, e := exchangeCode(code, verifier, c, authName)
if e != nil {
dump.Println("exchangeCode: " + e.Error())
}
app.SessionManager.Put(r.Context(), "access_token", token.AccessToken)
app.SessionManager.Put(r.Context(), "full_token", fullResponse)
// SetAuthHeader(w, token.AccessToken)
http.Redirect(w, r, originalPath, http.StatusFound)
// http.Redirect(w, r, originalURL, http.StatusFound)
})
}
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
//TODO: only after returninig, delete the session!
app.SessionManager.Remove(r.Context(), "access_token")
app.SessionManager.Remove(r.Context(), "full_token")
authName := c.DataMaps.DomainToAuth[r.Host]
a := c.Auth[authName]
u := a.OpenID.EndPoints.Logout
http.Redirect(w, r, u, http.StatusFound)
})
}
func LoginHandler(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
authName := c.DataMaps.DomainToAuth[r.Host]
// state := helper.RandStringByBits(64)
nonce := helper.RandStringByBits(64)
authURL, _ := url.Parse(c.Auth[authName].OpenID.EndPoints.Auth)
query := authURL.Query()
codeVerifier, _ := generateCodeVerifier()
codeChallenge := generateCodeChallenge(codeVerifier)
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
state := generateState(url.QueryEscape(originalPath))
query.Set("client_id", c.Auth[authName].OpenID.ClientID)
query.Set("response_type", "code")
query.Set("scope", "openid")
query.Set("redirect_uri", c.Auth[authName].OpenID.RedirectURI)
query.Set("code_challenge", codeChallenge)
query.Set("code_challenge_method", "S256")
query.Set("state", state)
query.Set("nonce", nonce)
authURL.RawQuery = query.Encode()
app.SessionManager.Put(r.Context(), "state", state)
app.SessionManager.Put(r.Context(), "code_verifier", codeVerifier)
http.Redirect(w, r, authURL.String(), http.StatusFound)
})
}

View file

@ -45,7 +45,7 @@ func main() {
mainRouter := router.New()
groups.ForEach(func(k string, g *logic.Group) {
groupSubRouter := mainRouter.Mux.NewRoute().Subrouter()
// groupSubRouter.Use(Domain)
groupSubRouter.Use(Domain_Middleware)
for k := range g.ReverseProxies {
rpConfig := c.ReverseProxies[k]
domain := rpConfig.Domain
@ -53,7 +53,6 @@ func main() {
proxy.Name = domain
newRoute := groupSubRouter.NewRoute()
subRouter := newRoute.Host(domain).Subrouter()
if rpConfig.Auth != "" {
if _, ok := c.Auth[rpConfig.Auth]; !ok {
err := errors.New("Error: Auth " + rpConfig.Auth + " not exist!")
@ -70,6 +69,8 @@ func main() {
}
subRouter.PathPrefix("/").Handler(proxy.Httputil)
// Filter out static file requests first
}
if len(g.ReverseProxies) > 0 {
@ -129,18 +130,14 @@ func main() {
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
currentPath := r.URL.Path
authName := c.GetAuthNameByDomain(r.Host)
// authName := c.DataMaps.DomainToAuth[r.Host]
loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
// loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
// logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
// callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
// TODO: mark auth reverse proxy in yaml
// AuthHostUrl, _ := url.Parse(c.Auth.Default.OpenID.Host)
if r.Host == "keycloak.z.com" {
next.ServeHTTP(w, r)
@ -159,11 +156,13 @@ func authMiddleware(next http.Handler) http.Handler {
}
case callbackPath:
{
next.ServeHTTP(w, r)
// return
}
default:
{
accessToken := app.SessionManager.GetString(r.Context(), "access_token")
if accessToken == "" {
authName := c.DataMaps.DomainToAuth[r.Host]
@ -183,12 +182,11 @@ func authMiddleware(next http.Handler) http.Handler {
// return
}
// tokenOk := IsAuthorizedJWT(accessToken, c, "default")
// if tokenOk {
// } else {
// // p := a.OpenID
// // Redirect to login
// }
tokenOk := IsAuthorizedJWT(accessToken, c, "default")
if !tokenOk {
http.Redirect(w, r, loginPath, http.StatusFound)
return
}
next.ServeHTTP(w, r)
}
}
@ -197,21 +195,25 @@ func authMiddleware(next http.Handler) http.Handler {
}
func Domain(next http.Handler) http.Handler {
func Domain_Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// c := config.New().Auth.Default.Paths
// c := config.Get()
// requestedPath := r.URL.Path
// a := c
// authName := c.GetAuthNameByDomain(r.Host)
// auth := c.Auth[authName]
// excludedPaths := []string{
// a.Prefix + a.Login,
// a.Prefix + a.Callback,
// a.Prefix + a.Logout,
// auth.Paths.Prefix + auth.Paths.Login,
// auth.Paths.Prefix + auth.Paths.Callback,
// auth.Paths.Prefix + auth.Paths.Logout,
// }
// contains := helper.IsSliceContains(excludedPaths, requestedPath)
// contains := slices.Contains(excludedPaths, requestedPath)
// if !contains {
// app.SessionManager.Put(r.Context(), "original_path", requestedPath)
// app.SessionManager.Put(r.Context(), "original_path", requestedPath)
// }
// dump.P(requestedPath)
next.ServeHTTP(w, r)
})
}
@ -276,75 +278,31 @@ type HandlerFuncConfigWrapper func(*config.Config, http.ResponseWriter, *http.Re
// HANDLERS
// ////////////
func CallbackHandler(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
query := r.URL.Query()
code := query.Get("code")
state := query.Get("state")
verifier := app.SessionManager.GetString(r.Context(), "code_verifier")
if verifier == "" {
http.Error(w, "Code verifier not found in session", http.StatusBadRequest)
return
}
expectedState := app.SessionManager.GetString(r.Context(), "state")
if state != expectedState {
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
return
}
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
authName := c.GetAuthNameByDomain(r.Host)
token, fullResponse, e := exchangeCode(code, verifier, c, authName)
if e != nil {
dump.Println("exchangeCode: " + e.Error())
}
app.SessionManager.Put(r.Context(), "access_token", token.AccessToken)
app.SessionManager.Put(r.Context(), "full_token", fullResponse)
// SetAuthHeader(w, token.AccessToken)
http.Redirect(w, r, originalPath, http.StatusFound)
})
}
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
app.SessionManager.Remove(r.Context(), "access_token")
app.SessionManager.Remove(r.Context(), "full_token")
authName := c.DataMaps.DomainToAuth[r.Host]
a := c.Auth[authName]
u := a.OpenID.EndPoints.Logout
http.Redirect(w, r, u, http.StatusFound)
})
}
func LoginHandler(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
authName := c.DataMaps.DomainToAuth[r.Host]
codeVerifier, _ := generateCodeVerifier()
codeChallenge := generateCodeChallenge(codeVerifier)
state := helper.RandStringByBits(128)
nonce := helper.RandStringByBits(128)
authURL, _ := url.Parse(c.Auth[authName].OpenID.EndPoints.Auth)
query := authURL.Query()
query.Set("client_id", c.Auth[authName].OpenID.ClientID)
query.Set("response_type", "code")
query.Set("scope", "openid")
query.Set("redirect_uri", c.Auth[authName].OpenID.RedirectURI)
query.Set("code_challenge", codeChallenge)
query.Set("code_challenge_method", "S256")
query.Set("state", state)
query.Set("nonce", nonce)
authURL.RawQuery = query.Encode()
app.SessionManager.Put(r.Context(), "state", state)
app.SessionManager.Put(r.Context(), "code_verifier", codeVerifier)
http.Redirect(w, r, authURL.String(), http.StatusFound)
})
}
// AUTH FUNCTIONS
////////////////////
// //////////////////
func generateState(redirectURL string) string {
b := make([]byte, 32)
rand.Read(b)
stateData := map[string]string{"redirect_uri": redirectURL, "random": base64.StdEncoding.EncodeToString(b)}
jsonData, _ := json.Marshal(stateData)
return base64.StdEncoding.EncodeToString(jsonData)
}
func decodeState(encodedState string) (string, error) {
input := encodedState
b64data := input[strings.IndexByte(input, ',')+1:]
decoded, err := base64.StdEncoding.DecodeString(b64data)
if err != nil {
return "", err
}
var stateData map[string]string
err = json.Unmarshal(decoded, &stateData)
if err != nil {
return "", err
}
return stateData["redirect_uri"], nil
}
func exchangeCode(code string, verifier string, c *config.Config, authName string) (*TokenResponse, string, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
@ -435,3 +393,23 @@ func IsAuthorizedJWT(rawAccessToken string, c *config.Config, authName string) b
///////////////////////////////////
///////////////////////////////////
func isStaticFileRequest(path string) bool {
// Check for common static file prefixes
if strings.HasPrefix(path, "/static/") || strings.HasPrefix(path, "/assets/") {
return true
}
// Check for common static file extensions
staticExtensions := []string{
".css", ".js", ".jpg", ".jpeg", ".png", ".gif", ".svg", ".ico",
}
for _, ext := range staticExtensions {
if strings.HasSuffix(path, ext) {
return true
}
}
return false
}

View file

@ -7,7 +7,7 @@ reverse_proxies:
tls:
enabled: true
certs: default
auth_server: true
app:
domain: app.z.com

File diff suppressed because one or more lines are too long

BIN
tmp/main

Binary file not shown.