2
This commit is contained in:
parent
fbaf9393ea
commit
5a6eed8c57
5 changed files with 162 additions and 88 deletions
96
cmd/server/handlers.go
Normal file
96
cmd/server/handlers.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gookit/goutil/dump"
|
||||
"zeevdiukman.com/zprox/internal/config"
|
||||
"zeevdiukman.com/zprox/pkg/helper"
|
||||
)
|
||||
|
||||
func CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
config.Wrapper(func(c *config.Config) {
|
||||
// ctx := context.Background()
|
||||
query := r.URL.Query()
|
||||
|
||||
code := query.Get("code")
|
||||
state := query.Get("state")
|
||||
|
||||
verifier := app.SessionManager.GetString(r.Context(), "code_verifier")
|
||||
if verifier == "" {
|
||||
http.Error(w, "Code verifier not found in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
expectedState := app.SessionManager.GetString(r.Context(), "state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// originalURL, err := decodeState(state)
|
||||
// if err != nil {
|
||||
// dump.P(err.Error())
|
||||
// http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||
// return
|
||||
// }
|
||||
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
|
||||
|
||||
authName := c.GetAuthNameByDomain(r.Host)
|
||||
token, fullResponse, e := exchangeCode(code, verifier, c, authName)
|
||||
if e != nil {
|
||||
dump.Println("exchangeCode: " + e.Error())
|
||||
}
|
||||
|
||||
app.SessionManager.Put(r.Context(), "access_token", token.AccessToken)
|
||||
app.SessionManager.Put(r.Context(), "full_token", fullResponse)
|
||||
|
||||
// SetAuthHeader(w, token.AccessToken)
|
||||
http.Redirect(w, r, originalPath, http.StatusFound)
|
||||
// http.Redirect(w, r, originalURL, http.StatusFound)
|
||||
})
|
||||
}
|
||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
config.Wrapper(func(c *config.Config) {
|
||||
|
||||
//TODO: only after returninig, delete the session!
|
||||
app.SessionManager.Remove(r.Context(), "access_token")
|
||||
app.SessionManager.Remove(r.Context(), "full_token")
|
||||
|
||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||
a := c.Auth[authName]
|
||||
u := a.OpenID.EndPoints.Logout
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
config.Wrapper(func(c *config.Config) {
|
||||
|
||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||
|
||||
// state := helper.RandStringByBits(64)
|
||||
nonce := helper.RandStringByBits(64)
|
||||
authURL, _ := url.Parse(c.Auth[authName].OpenID.EndPoints.Auth)
|
||||
query := authURL.Query()
|
||||
|
||||
codeVerifier, _ := generateCodeVerifier()
|
||||
codeChallenge := generateCodeChallenge(codeVerifier)
|
||||
|
||||
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
|
||||
state := generateState(url.QueryEscape(originalPath))
|
||||
query.Set("client_id", c.Auth[authName].OpenID.ClientID)
|
||||
query.Set("response_type", "code")
|
||||
query.Set("scope", "openid")
|
||||
query.Set("redirect_uri", c.Auth[authName].OpenID.RedirectURI)
|
||||
query.Set("code_challenge", codeChallenge)
|
||||
query.Set("code_challenge_method", "S256")
|
||||
query.Set("state", state)
|
||||
query.Set("nonce", nonce)
|
||||
authURL.RawQuery = query.Encode()
|
||||
app.SessionManager.Put(r.Context(), "state", state)
|
||||
app.SessionManager.Put(r.Context(), "code_verifier", codeVerifier)
|
||||
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
||||
})
|
||||
|
||||
}
|
||||
|
|
@ -45,7 +45,7 @@ func main() {
|
|||
mainRouter := router.New()
|
||||
groups.ForEach(func(k string, g *logic.Group) {
|
||||
groupSubRouter := mainRouter.Mux.NewRoute().Subrouter()
|
||||
// groupSubRouter.Use(Domain)
|
||||
groupSubRouter.Use(Domain_Middleware)
|
||||
for k := range g.ReverseProxies {
|
||||
rpConfig := c.ReverseProxies[k]
|
||||
domain := rpConfig.Domain
|
||||
|
|
@ -53,7 +53,6 @@ func main() {
|
|||
proxy.Name = domain
|
||||
newRoute := groupSubRouter.NewRoute()
|
||||
subRouter := newRoute.Host(domain).Subrouter()
|
||||
|
||||
if rpConfig.Auth != "" {
|
||||
if _, ok := c.Auth[rpConfig.Auth]; !ok {
|
||||
err := errors.New("Error: Auth " + rpConfig.Auth + " not exist!")
|
||||
|
|
@ -70,6 +69,8 @@ func main() {
|
|||
}
|
||||
subRouter.PathPrefix("/").Handler(proxy.Httputil)
|
||||
|
||||
// Filter out static file requests first
|
||||
|
||||
}
|
||||
|
||||
if len(g.ReverseProxies) > 0 {
|
||||
|
|
@ -129,18 +130,14 @@ func main() {
|
|||
|
||||
func authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
config.Wrapper(func(c *config.Config) {
|
||||
currentPath := r.URL.Path
|
||||
authName := c.GetAuthNameByDomain(r.Host)
|
||||
// authName := c.DataMaps.DomainToAuth[r.Host]
|
||||
loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
|
||||
logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
|
||||
callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
|
||||
// loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
|
||||
// logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
|
||||
// callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
|
||||
// TODO: mark auth reverse proxy in yaml
|
||||
// AuthHostUrl, _ := url.Parse(c.Auth.Default.OpenID.Host)
|
||||
|
||||
if r.Host == "keycloak.z.com" {
|
||||
next.ServeHTTP(w, r)
|
||||
|
|
@ -159,11 +156,13 @@ func authMiddleware(next http.Handler) http.Handler {
|
|||
}
|
||||
case callbackPath:
|
||||
{
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
// return
|
||||
}
|
||||
default:
|
||||
{
|
||||
|
||||
accessToken := app.SessionManager.GetString(r.Context(), "access_token")
|
||||
if accessToken == "" {
|
||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||
|
|
@ -183,12 +182,11 @@ func authMiddleware(next http.Handler) http.Handler {
|
|||
// return
|
||||
}
|
||||
|
||||
// tokenOk := IsAuthorizedJWT(accessToken, c, "default")
|
||||
// if tokenOk {
|
||||
// } else {
|
||||
// // p := a.OpenID
|
||||
// // Redirect to login
|
||||
// }
|
||||
tokenOk := IsAuthorizedJWT(accessToken, c, "default")
|
||||
if !tokenOk {
|
||||
http.Redirect(w, r, loginPath, http.StatusFound)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
|
@ -197,21 +195,25 @@ func authMiddleware(next http.Handler) http.Handler {
|
|||
|
||||
}
|
||||
|
||||
func Domain(next http.Handler) http.Handler {
|
||||
func Domain_Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// c := config.New().Auth.Default.Paths
|
||||
// c := config.Get()
|
||||
|
||||
// requestedPath := r.URL.Path
|
||||
// a := c
|
||||
// authName := c.GetAuthNameByDomain(r.Host)
|
||||
// auth := c.Auth[authName]
|
||||
// excludedPaths := []string{
|
||||
// a.Prefix + a.Login,
|
||||
// a.Prefix + a.Callback,
|
||||
// a.Prefix + a.Logout,
|
||||
// auth.Paths.Prefix + auth.Paths.Login,
|
||||
// auth.Paths.Prefix + auth.Paths.Callback,
|
||||
// auth.Paths.Prefix + auth.Paths.Logout,
|
||||
// }
|
||||
// contains := helper.IsSliceContains(excludedPaths, requestedPath)
|
||||
// contains := slices.Contains(excludedPaths, requestedPath)
|
||||
// if !contains {
|
||||
// app.SessionManager.Put(r.Context(), "original_path", requestedPath)
|
||||
// }
|
||||
// dump.P(requestedPath)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
@ -276,75 +278,31 @@ type HandlerFuncConfigWrapper func(*config.Config, http.ResponseWriter, *http.Re
|
|||
|
||||
// HANDLERS
|
||||
// ////////////
|
||||
func CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
config.Wrapper(func(c *config.Config) {
|
||||
query := r.URL.Query()
|
||||
code := query.Get("code")
|
||||
state := query.Get("state")
|
||||
verifier := app.SessionManager.GetString(r.Context(), "code_verifier")
|
||||
if verifier == "" {
|
||||
http.Error(w, "Code verifier not found in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
expectedState := app.SessionManager.GetString(r.Context(), "state")
|
||||
if state != expectedState {
|
||||
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
|
||||
authName := c.GetAuthNameByDomain(r.Host)
|
||||
token, fullResponse, e := exchangeCode(code, verifier, c, authName)
|
||||
if e != nil {
|
||||
dump.Println("exchangeCode: " + e.Error())
|
||||
}
|
||||
|
||||
app.SessionManager.Put(r.Context(), "access_token", token.AccessToken)
|
||||
app.SessionManager.Put(r.Context(), "full_token", fullResponse)
|
||||
|
||||
// SetAuthHeader(w, token.AccessToken)
|
||||
http.Redirect(w, r, originalPath, http.StatusFound)
|
||||
})
|
||||
}
|
||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
config.Wrapper(func(c *config.Config) {
|
||||
app.SessionManager.Remove(r.Context(), "access_token")
|
||||
app.SessionManager.Remove(r.Context(), "full_token")
|
||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||
a := c.Auth[authName]
|
||||
u := a.OpenID.EndPoints.Logout
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
config.Wrapper(func(c *config.Config) {
|
||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||
codeVerifier, _ := generateCodeVerifier()
|
||||
codeChallenge := generateCodeChallenge(codeVerifier)
|
||||
state := helper.RandStringByBits(128)
|
||||
nonce := helper.RandStringByBits(128)
|
||||
authURL, _ := url.Parse(c.Auth[authName].OpenID.EndPoints.Auth)
|
||||
query := authURL.Query()
|
||||
query.Set("client_id", c.Auth[authName].OpenID.ClientID)
|
||||
query.Set("response_type", "code")
|
||||
query.Set("scope", "openid")
|
||||
query.Set("redirect_uri", c.Auth[authName].OpenID.RedirectURI)
|
||||
query.Set("code_challenge", codeChallenge)
|
||||
query.Set("code_challenge_method", "S256")
|
||||
query.Set("state", state)
|
||||
query.Set("nonce", nonce)
|
||||
authURL.RawQuery = query.Encode()
|
||||
app.SessionManager.Put(r.Context(), "state", state)
|
||||
app.SessionManager.Put(r.Context(), "code_verifier", codeVerifier)
|
||||
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// AUTH FUNCTIONS
|
||||
////////////////////
|
||||
// //////////////////
|
||||
func generateState(redirectURL string) string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
stateData := map[string]string{"redirect_uri": redirectURL, "random": base64.StdEncoding.EncodeToString(b)}
|
||||
jsonData, _ := json.Marshal(stateData)
|
||||
return base64.StdEncoding.EncodeToString(jsonData)
|
||||
}
|
||||
func decodeState(encodedState string) (string, error) {
|
||||
input := encodedState
|
||||
b64data := input[strings.IndexByte(input, ',')+1:]
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(b64data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var stateData map[string]string
|
||||
err = json.Unmarshal(decoded, &stateData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stateData["redirect_uri"], nil
|
||||
}
|
||||
func exchangeCode(code string, verifier string, c *config.Config, authName string) (*TokenResponse, string, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
|
|
@ -435,3 +393,23 @@ func IsAuthorizedJWT(rawAccessToken string, c *config.Config, authName string) b
|
|||
|
||||
///////////////////////////////////
|
||||
///////////////////////////////////
|
||||
|
||||
func isStaticFileRequest(path string) bool {
|
||||
// Check for common static file prefixes
|
||||
if strings.HasPrefix(path, "/static/") || strings.HasPrefix(path, "/assets/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for common static file extensions
|
||||
staticExtensions := []string{
|
||||
".css", ".js", ".jpg", ".jpeg", ".png", ".gif", ".svg", ".ico",
|
||||
}
|
||||
|
||||
for _, ext := range staticExtensions {
|
||||
if strings.HasSuffix(path, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ reverse_proxies:
|
|||
tls:
|
||||
enabled: true
|
||||
certs: default
|
||||
|
||||
auth_server: true
|
||||
|
||||
app:
|
||||
domain: app.z.com
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
BIN
tmp/main
BIN
tmp/main
Binary file not shown.
Loading…
Reference in a new issue