2
This commit is contained in:
parent
fbaf9393ea
commit
5a6eed8c57
5 changed files with 162 additions and 88 deletions
96
cmd/server/handlers.go
Normal file
96
cmd/server/handlers.go
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/gookit/goutil/dump"
|
||||||
|
"zeevdiukman.com/zprox/internal/config"
|
||||||
|
"zeevdiukman.com/zprox/pkg/helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
config.Wrapper(func(c *config.Config) {
|
||||||
|
// ctx := context.Background()
|
||||||
|
query := r.URL.Query()
|
||||||
|
|
||||||
|
code := query.Get("code")
|
||||||
|
state := query.Get("state")
|
||||||
|
|
||||||
|
verifier := app.SessionManager.GetString(r.Context(), "code_verifier")
|
||||||
|
if verifier == "" {
|
||||||
|
http.Error(w, "Code verifier not found in session", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expectedState := app.SessionManager.GetString(r.Context(), "state")
|
||||||
|
if state != expectedState {
|
||||||
|
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// originalURL, err := decodeState(state)
|
||||||
|
// if err != nil {
|
||||||
|
// dump.P(err.Error())
|
||||||
|
// http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
|
||||||
|
|
||||||
|
authName := c.GetAuthNameByDomain(r.Host)
|
||||||
|
token, fullResponse, e := exchangeCode(code, verifier, c, authName)
|
||||||
|
if e != nil {
|
||||||
|
dump.Println("exchangeCode: " + e.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
app.SessionManager.Put(r.Context(), "access_token", token.AccessToken)
|
||||||
|
app.SessionManager.Put(r.Context(), "full_token", fullResponse)
|
||||||
|
|
||||||
|
// SetAuthHeader(w, token.AccessToken)
|
||||||
|
http.Redirect(w, r, originalPath, http.StatusFound)
|
||||||
|
// http.Redirect(w, r, originalURL, http.StatusFound)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
config.Wrapper(func(c *config.Config) {
|
||||||
|
|
||||||
|
//TODO: only after returninig, delete the session!
|
||||||
|
app.SessionManager.Remove(r.Context(), "access_token")
|
||||||
|
app.SessionManager.Remove(r.Context(), "full_token")
|
||||||
|
|
||||||
|
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||||
|
a := c.Auth[authName]
|
||||||
|
u := a.OpenID.EndPoints.Logout
|
||||||
|
http.Redirect(w, r, u, http.StatusFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
config.Wrapper(func(c *config.Config) {
|
||||||
|
|
||||||
|
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||||
|
|
||||||
|
// state := helper.RandStringByBits(64)
|
||||||
|
nonce := helper.RandStringByBits(64)
|
||||||
|
authURL, _ := url.Parse(c.Auth[authName].OpenID.EndPoints.Auth)
|
||||||
|
query := authURL.Query()
|
||||||
|
|
||||||
|
codeVerifier, _ := generateCodeVerifier()
|
||||||
|
codeChallenge := generateCodeChallenge(codeVerifier)
|
||||||
|
|
||||||
|
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
|
||||||
|
state := generateState(url.QueryEscape(originalPath))
|
||||||
|
query.Set("client_id", c.Auth[authName].OpenID.ClientID)
|
||||||
|
query.Set("response_type", "code")
|
||||||
|
query.Set("scope", "openid")
|
||||||
|
query.Set("redirect_uri", c.Auth[authName].OpenID.RedirectURI)
|
||||||
|
query.Set("code_challenge", codeChallenge)
|
||||||
|
query.Set("code_challenge_method", "S256")
|
||||||
|
query.Set("state", state)
|
||||||
|
query.Set("nonce", nonce)
|
||||||
|
authURL.RawQuery = query.Encode()
|
||||||
|
app.SessionManager.Put(r.Context(), "state", state)
|
||||||
|
app.SessionManager.Put(r.Context(), "code_verifier", codeVerifier)
|
||||||
|
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -45,7 +45,7 @@ func main() {
|
||||||
mainRouter := router.New()
|
mainRouter := router.New()
|
||||||
groups.ForEach(func(k string, g *logic.Group) {
|
groups.ForEach(func(k string, g *logic.Group) {
|
||||||
groupSubRouter := mainRouter.Mux.NewRoute().Subrouter()
|
groupSubRouter := mainRouter.Mux.NewRoute().Subrouter()
|
||||||
// groupSubRouter.Use(Domain)
|
groupSubRouter.Use(Domain_Middleware)
|
||||||
for k := range g.ReverseProxies {
|
for k := range g.ReverseProxies {
|
||||||
rpConfig := c.ReverseProxies[k]
|
rpConfig := c.ReverseProxies[k]
|
||||||
domain := rpConfig.Domain
|
domain := rpConfig.Domain
|
||||||
|
|
@ -53,7 +53,6 @@ func main() {
|
||||||
proxy.Name = domain
|
proxy.Name = domain
|
||||||
newRoute := groupSubRouter.NewRoute()
|
newRoute := groupSubRouter.NewRoute()
|
||||||
subRouter := newRoute.Host(domain).Subrouter()
|
subRouter := newRoute.Host(domain).Subrouter()
|
||||||
|
|
||||||
if rpConfig.Auth != "" {
|
if rpConfig.Auth != "" {
|
||||||
if _, ok := c.Auth[rpConfig.Auth]; !ok {
|
if _, ok := c.Auth[rpConfig.Auth]; !ok {
|
||||||
err := errors.New("Error: Auth " + rpConfig.Auth + " not exist!")
|
err := errors.New("Error: Auth " + rpConfig.Auth + " not exist!")
|
||||||
|
|
@ -70,6 +69,8 @@ func main() {
|
||||||
}
|
}
|
||||||
subRouter.PathPrefix("/").Handler(proxy.Httputil)
|
subRouter.PathPrefix("/").Handler(proxy.Httputil)
|
||||||
|
|
||||||
|
// Filter out static file requests first
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(g.ReverseProxies) > 0 {
|
if len(g.ReverseProxies) > 0 {
|
||||||
|
|
@ -129,18 +130,14 @@ func main() {
|
||||||
|
|
||||||
func authMiddleware(next http.Handler) http.Handler {
|
func authMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
config.Wrapper(func(c *config.Config) {
|
config.Wrapper(func(c *config.Config) {
|
||||||
currentPath := r.URL.Path
|
currentPath := r.URL.Path
|
||||||
authName := c.GetAuthNameByDomain(r.Host)
|
authName := c.GetAuthNameByDomain(r.Host)
|
||||||
// authName := c.DataMaps.DomainToAuth[r.Host]
|
|
||||||
loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
|
loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
|
||||||
logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
|
logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
|
||||||
callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
|
callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
|
||||||
// loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
|
|
||||||
// logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
|
|
||||||
// callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
|
|
||||||
// TODO: mark auth reverse proxy in yaml
|
// TODO: mark auth reverse proxy in yaml
|
||||||
// AuthHostUrl, _ := url.Parse(c.Auth.Default.OpenID.Host)
|
|
||||||
|
|
||||||
if r.Host == "keycloak.z.com" {
|
if r.Host == "keycloak.z.com" {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
|
|
@ -159,11 +156,13 @@ func authMiddleware(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
case callbackPath:
|
case callbackPath:
|
||||||
{
|
{
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
// return
|
// return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
|
||||||
accessToken := app.SessionManager.GetString(r.Context(), "access_token")
|
accessToken := app.SessionManager.GetString(r.Context(), "access_token")
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
authName := c.DataMaps.DomainToAuth[r.Host]
|
||||||
|
|
@ -183,12 +182,11 @@ func authMiddleware(next http.Handler) http.Handler {
|
||||||
// return
|
// return
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenOk := IsAuthorizedJWT(accessToken, c, "default")
|
tokenOk := IsAuthorizedJWT(accessToken, c, "default")
|
||||||
// if tokenOk {
|
if !tokenOk {
|
||||||
// } else {
|
http.Redirect(w, r, loginPath, http.StatusFound)
|
||||||
// // p := a.OpenID
|
return
|
||||||
// // Redirect to login
|
}
|
||||||
// }
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -197,21 +195,25 @@ func authMiddleware(next http.Handler) http.Handler {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Domain(next http.Handler) http.Handler {
|
func Domain_Middleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// c := config.New().Auth.Default.Paths
|
// c := config.Get()
|
||||||
|
|
||||||
// requestedPath := r.URL.Path
|
// requestedPath := r.URL.Path
|
||||||
// a := c
|
// authName := c.GetAuthNameByDomain(r.Host)
|
||||||
|
// auth := c.Auth[authName]
|
||||||
// excludedPaths := []string{
|
// excludedPaths := []string{
|
||||||
// a.Prefix + a.Login,
|
// auth.Paths.Prefix + auth.Paths.Login,
|
||||||
// a.Prefix + a.Callback,
|
// auth.Paths.Prefix + auth.Paths.Callback,
|
||||||
// a.Prefix + a.Logout,
|
// auth.Paths.Prefix + auth.Paths.Logout,
|
||||||
// }
|
// }
|
||||||
// contains := helper.IsSliceContains(excludedPaths, requestedPath)
|
// contains := helper.IsSliceContains(excludedPaths, requestedPath)
|
||||||
|
// contains := slices.Contains(excludedPaths, requestedPath)
|
||||||
// if !contains {
|
// if !contains {
|
||||||
// app.SessionManager.Put(r.Context(), "original_path", requestedPath)
|
// app.SessionManager.Put(r.Context(), "original_path", requestedPath)
|
||||||
// }
|
// }
|
||||||
|
// dump.P(requestedPath)
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -276,75 +278,31 @@ type HandlerFuncConfigWrapper func(*config.Config, http.ResponseWriter, *http.Re
|
||||||
|
|
||||||
// HANDLERS
|
// HANDLERS
|
||||||
// ////////////
|
// ////////////
|
||||||
func CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
config.Wrapper(func(c *config.Config) {
|
|
||||||
query := r.URL.Query()
|
|
||||||
code := query.Get("code")
|
|
||||||
state := query.Get("state")
|
|
||||||
verifier := app.SessionManager.GetString(r.Context(), "code_verifier")
|
|
||||||
if verifier == "" {
|
|
||||||
http.Error(w, "Code verifier not found in session", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
expectedState := app.SessionManager.GetString(r.Context(), "state")
|
|
||||||
if state != expectedState {
|
|
||||||
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
originalPath := app.SessionManager.GetString(r.Context(), "original_path")
|
|
||||||
authName := c.GetAuthNameByDomain(r.Host)
|
|
||||||
token, fullResponse, e := exchangeCode(code, verifier, c, authName)
|
|
||||||
if e != nil {
|
|
||||||
dump.Println("exchangeCode: " + e.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
app.SessionManager.Put(r.Context(), "access_token", token.AccessToken)
|
|
||||||
app.SessionManager.Put(r.Context(), "full_token", fullResponse)
|
|
||||||
|
|
||||||
// SetAuthHeader(w, token.AccessToken)
|
|
||||||
http.Redirect(w, r, originalPath, http.StatusFound)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
config.Wrapper(func(c *config.Config) {
|
|
||||||
app.SessionManager.Remove(r.Context(), "access_token")
|
|
||||||
app.SessionManager.Remove(r.Context(), "full_token")
|
|
||||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
|
||||||
a := c.Auth[authName]
|
|
||||||
u := a.OpenID.EndPoints.Logout
|
|
||||||
http.Redirect(w, r, u, http.StatusFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
config.Wrapper(func(c *config.Config) {
|
|
||||||
authName := c.DataMaps.DomainToAuth[r.Host]
|
|
||||||
codeVerifier, _ := generateCodeVerifier()
|
|
||||||
codeChallenge := generateCodeChallenge(codeVerifier)
|
|
||||||
state := helper.RandStringByBits(128)
|
|
||||||
nonce := helper.RandStringByBits(128)
|
|
||||||
authURL, _ := url.Parse(c.Auth[authName].OpenID.EndPoints.Auth)
|
|
||||||
query := authURL.Query()
|
|
||||||
query.Set("client_id", c.Auth[authName].OpenID.ClientID)
|
|
||||||
query.Set("response_type", "code")
|
|
||||||
query.Set("scope", "openid")
|
|
||||||
query.Set("redirect_uri", c.Auth[authName].OpenID.RedirectURI)
|
|
||||||
query.Set("code_challenge", codeChallenge)
|
|
||||||
query.Set("code_challenge_method", "S256")
|
|
||||||
query.Set("state", state)
|
|
||||||
query.Set("nonce", nonce)
|
|
||||||
authURL.RawQuery = query.Encode()
|
|
||||||
app.SessionManager.Put(r.Context(), "state", state)
|
|
||||||
app.SessionManager.Put(r.Context(), "code_verifier", codeVerifier)
|
|
||||||
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// AUTH FUNCTIONS
|
// AUTH FUNCTIONS
|
||||||
////////////////////
|
// //////////////////
|
||||||
|
func generateState(redirectURL string) string {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
rand.Read(b)
|
||||||
|
stateData := map[string]string{"redirect_uri": redirectURL, "random": base64.StdEncoding.EncodeToString(b)}
|
||||||
|
jsonData, _ := json.Marshal(stateData)
|
||||||
|
return base64.StdEncoding.EncodeToString(jsonData)
|
||||||
|
}
|
||||||
|
func decodeState(encodedState string) (string, error) {
|
||||||
|
input := encodedState
|
||||||
|
b64data := input[strings.IndexByte(input, ',')+1:]
|
||||||
|
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(b64data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var stateData map[string]string
|
||||||
|
err = json.Unmarshal(decoded, &stateData)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return stateData["redirect_uri"], nil
|
||||||
|
}
|
||||||
func exchangeCode(code string, verifier string, c *config.Config, authName string) (*TokenResponse, string, error) {
|
func exchangeCode(code string, verifier string, c *config.Config, authName string) (*TokenResponse, string, error) {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("grant_type", "authorization_code")
|
data.Set("grant_type", "authorization_code")
|
||||||
|
|
@ -435,3 +393,23 @@ func IsAuthorizedJWT(rawAccessToken string, c *config.Config, authName string) b
|
||||||
|
|
||||||
///////////////////////////////////
|
///////////////////////////////////
|
||||||
///////////////////////////////////
|
///////////////////////////////////
|
||||||
|
|
||||||
|
func isStaticFileRequest(path string) bool {
|
||||||
|
// Check for common static file prefixes
|
||||||
|
if strings.HasPrefix(path, "/static/") || strings.HasPrefix(path, "/assets/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for common static file extensions
|
||||||
|
staticExtensions := []string{
|
||||||
|
".css", ".js", ".jpg", ".jpeg", ".png", ".gif", ".svg", ".ico",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ext := range staticExtensions {
|
||||||
|
if strings.HasSuffix(path, ext) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ reverse_proxies:
|
||||||
tls:
|
tls:
|
||||||
enabled: true
|
enabled: true
|
||||||
certs: default
|
certs: default
|
||||||
|
auth_server: true
|
||||||
|
|
||||||
app:
|
app:
|
||||||
domain: app.z.com
|
domain: app.z.com
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
BIN
tmp/main
BIN
tmp/main
Binary file not shown.
Loading…
Reference in a new issue