437 lines
14 KiB
Go
437 lines
14 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"crypto/tls"
	"encoding/base64"
	"encoding/json"
	"errors"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gookit/goutil/dump"
	"zeevdiukman.com/zprox/internal/config"
	"zeevdiukman.com/zprox/internal/logic"
	"zeevdiukman.com/zprox/internal/reverse_proxy"
	"zeevdiukman.com/zprox/internal/router"
	"zeevdiukman.com/zprox/pkg/helper"
)
|
|
|
|
const DEVELOPMENT bool = true
|
|
|
|
type EntryPoints map[string]EntryPoint
|
|
type EntryPoint struct {
|
|
Name string
|
|
Group string
|
|
*http.Server
|
|
}
|
|
type ReverseProxies map[string]ReverseProxy
|
|
type ReverseProxy *httputil.ReverseProxy
|
|
|
|
// app is the process-wide application state built by logic.NewApp(). Its
// SessionManager wraps each group server's handler (LoadAndSave) and holds
// the per-request session values ("access_token", "state", "code_verifier",
// ...) read and written by authMiddleware and the login/logout/callback
// handlers.
var app = logic.NewApp()
|
|
|
|
// main wires up and starts the proxy.
//
// For each group returned by logic.NewGroups it creates a subrouter on the
// shared main router. For every reverse proxy listed in the group it:
//   - mounts a Host(domain) subrouter whose catch-all "/" route forwards to
//     the proxy built from the proxy's configured Host;
//   - when the proxy names an auth profile, also mounts that profile's login,
//     logout and callback handlers under its path prefix and installs
//     Middleware_SetHeaders and authMiddleware on the host subrouter. An auth
//     name that is missing from c.Auth panics at startup.
//
// A group with at least one proxy gets its own http.Server on ":"+g.Port.
// The server's handler is wrapped in the session manager, and it is served in
// a background goroutine: over TLS when g.TLS is set, with the certificate
// chosen per SNI server name. helper.StartTestHTTPServer(3000) is called
// after all groups have been set up.
func main() {
	helper.AppRunner(func() {
		config.Wrapper(func(c *config.Config) {
			groups := logic.NewGroups()
			mainRouter := router.New()

			groups.ForEach(func(_ string, g *logic.Group) {
				groupSubRouter := mainRouter.Mux.NewRoute().Subrouter()
				// groupSubRouter.Use(Domain)

				for rpName := range g.ReverseProxies {
					rpConfig := c.ReverseProxies[rpName]
					domain := rpConfig.Domain

					proxy := reverse_proxy.New(rpConfig.Host)
					proxy.Name = domain

					subRouter := groupSubRouter.NewRoute().Host(domain).Subrouter()

					if rpConfig.Auth != "" {
						auth, ok := c.Auth[rpConfig.Auth]
						if !ok {
							err := errors.New("Error: Auth " + rpConfig.Auth + " not exist!")
							panic(err.Error())
						}
						pths := auth.Paths

						authRoute := subRouter.NewRoute()
						subRouter.Use(Middleware_SetHeaders)

						authSubRouter := authRoute.PathPrefix(pths.Prefix).Subrouter()
						authSubRouter.Path(pths.Login).Handler(http.HandlerFunc(LoginHandler))
						authSubRouter.Path(pths.Logout).Handler(http.HandlerFunc(LogoutHandler))
						authSubRouter.Path(pths.Callback).Handler(http.HandlerFunc(CallbackHandler))

						subRouter.Use(authMiddleware)
					}

					subRouter.PathPrefix("/").Handler(proxy.Httputil)
				}

				if len(g.ReverseProxies) == 0 {
					return
				}

				tlsConfig := &tls.Config{
					// Select the certificate per SNI server name so one listener
					// can serve every domain in the group. When no pair is
					// configured for the name, crt and key are empty and
					// loadCertificate decides the outcome.
					GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
						crt, key := c.GetCertsPairByDomain(info.ServerName)
						cert, err := loadCertificate(crt, key)
						if err != nil {
							return nil, err
						}
						return &cert, nil
					},
				}

				server := &http.Server{
					Addr:      ":" + g.Port,
					Handler:   app.SessionManager.LoadAndSave(groupSubRouter),
					TLSConfig: tlsConfig,
					// Bound how long a client may take to send request headers
					// (slowloris mitigation). Body/response timeouts stay unset
					// so long-running proxied responses are not cut off.
					ReadHeaderTimeout: 10 * time.Second,
				}

				scheme := "http"
				if g.TLS {
					scheme = "https"
				}

				go func() {
					ipAddr := helper.GetIP()
					log.Println("Test server is running at " + scheme + "://" + ipAddr + ":" + g.Port)

					var err error
					if g.TLS {
						// Empty file names: certificates come from
						// TLSConfig.GetCertificate.
						err = server.ListenAndServeTLS("", "")
					} else {
						err = server.ListenAndServe()
					}
					if err != nil {
						log.Println(err.Error())
					}
				}()
			})

			helper.StartTestHTTPServer(3000)
		})
	})
}
|
|
|
|
// //////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// //////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// //////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// //////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// //////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// //////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
// //////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// authMiddleware gates requests on a host's subrouter behind an access token
// held in the session.
//
// The following requests are passed straight to next:
//   - requests whose Host is "keycloak.z.com";
//   - requests for the auth profile's login, logout or callback path, where
//     each path is Paths.Prefix plus the specific path and the profile is
//     found via c.GetAuthNameByDomain(r.Host). This lets the auth handlers run.
//
// Any other request without an "access_token" in the session is redirected
// (302) to the login path. For GET requests, the request URI is first saved
// in the session under "original_path" so the callback can return the user
// there after login. Requests that carry a token are forwarded to next.
//
// NOTE(review): the token is only checked for presence; it is not validated
// here.
func authMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		config.Wrapper(func(c *config.Config) {
			// TODO: mark auth reverse proxy in yaml instead of hard-coding the host.
			if r.Host == "keycloak.z.com" {
				next.ServeHTTP(w, r)
				return
			}

			authName := c.GetAuthNameByDomain(r.Host)
			paths := c.Auth[authName].Paths
			loginPath := paths.Prefix + paths.Login

			switch r.URL.Path {
			case loginPath, paths.Prefix + paths.Logout, paths.Prefix + paths.Callback:
				next.ServeHTTP(w, r)
				return
			}

			if app.SessionManager.GetString(r.Context(), "access_token") == "" {
				if r.Method == http.MethodGet {
					// Remember where the user was heading so the callback can
					// send them back after a successful login.
					app.SessionManager.Put(r.Context(), "original_path", r.URL.RequestURI())
				}
				// Redirect to the same loginPath the switch above exempts, so
				// the redirect target can never itself trigger another redirect.
				http.Redirect(w, r, loginPath, http.StatusFound)
				return
			}

			next.ServeHTTP(w, r)
		})
	})
}
|
|
|
|
// Domain is a placeholder middleware. The handler it returns forwards every
// request to next without inspecting or modifying it.
func Domain(next http.Handler) http.Handler {
	passThrough := func(rw http.ResponseWriter, req *http.Request) {
		next.ServeHTTP(rw, req)
	}
	return http.HandlerFunc(passThrough)
}
|
|
|
|
// Middleware_SetHeaders stamps forwarding headers onto the incoming request
// before handing it to next. This lets the upstream behind the reverse proxy
// see the original scheme, client address and host:
//
//   - X-Forwarded-Proto: "https" if the request arrived over TLS, else "http"
//   - X-Forwarded-For, X-Real-IP: the client IP from r.RemoteAddr, without
//     the port (a bare "ip", never "ip:port")
//   - X-Forwarded-Host: r.Host
//
// Existing values of these headers are overwritten, so a client cannot
// inject its own.
//
// NOTE(review): if the proxy also appends to X-Forwarded-For (as
// httputil.ReverseProxy does in Director mode), the upstream will see the IP
// twice. Confirm how reverse_proxy.New configures the proxy.
func Middleware_SetHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip := remoteAddrIP(r.RemoteAddr)
		r.Header.Set("X-Forwarded-Proto", getProto(r))
		r.Header.Set("X-Forwarded-For", ip)
		r.Header.Set("X-Forwarded-Host", r.Host)
		r.Header.Set("X-Real-IP", ip)
		next.ServeHTTP(w, r)
	})
}

// remoteAddrIP returns the host part of a "host:port" address, with IPv6
// brackets removed. If addr has no port (for example a unix socket name), it
// is returned unchanged.
func remoteAddrIP(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}

// getProto reports the scheme the request arrived on: "https" when it came
// over TLS, otherwise "http".
func getProto(req *http.Request) string {
	if req.TLS != nil {
		return "https"
	}
	return "http"
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
type (
	// TokenResponse is the JSON body returned by the OpenID provider's token
	// endpoint; exchangeCode decodes into it. Error and ErrorDescription carry
	// the provider's "error" / "error_description" fields when present.
	TokenResponse struct {
		AccessToken      string `json:"access_token"`
		ExpiresIn        int    `json:"expires_in"`
		RefreshExpiresIn int    `json:"refresh_expires_in"`
		TokenType        string `json:"token_type"`
		NotBeforePolicy  int    `json:"not-before-policy"`
		SessionState     string `json:"session_state"`
		Scope            string `json:"scope"`
		RefreshToken     string `json:"refresh_token"`
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}

	// Res401Struct is a JSON payload shape for an authorisation-failure
	// response; the example tags show the intended values.
	Res401Struct struct {
		Status   string `json:"status" example:"FAILED"`
		HTTPCode int    `json:"httpCode" example:"401"`
		Message  string `json:"message" example:"authorisation failed"`
	}

	// Claims is the subset of token claims that IsAuthorizedJWT decodes.
	Claims struct {
		ResourceAccess client `json:"resource_access,omitempty"`
		JTI            string `json:"jti,omitempty"`
	}

	// client maps the "resource_access" claim's per-client entries; only
	// DemoServiceClient is modelled.
	client struct {
		DemoServiceClient clientRoles `json:"DemoServiceClient,omitempty"`
	}

	// clientRoles holds the role list granted for one client.
	clientRoles struct {
		Roles []string `json:"roles,omitempty"`
	}
)
|
|
// HandlerFuncConfigWrapper is a function type that receives the loaded config
// together with a request's ResponseWriter and Request and returns a
// *http.Handler. NOTE(review): a pointer to an interface is unusual in Go —
// confirm callers need it. No use of this type is visible here.
type HandlerFuncConfigWrapper func(*config.Config, http.ResponseWriter, *http.Request) *http.Handler
|
|
|
|
// HANDLERS
|
|
// ////////////
|
|
// CallbackHandler completes the authorization-code + PKCE login flow.
//
// It checks the "state" query parameter against the "state" value stored in
// the session, then exchanges the "code" query parameter, together with the
// session's "code_verifier", via exchangeCode. On success it stores the access
// token and the raw token response in the session under "access_token" and
// "full_token". Finally it redirects (302) to the session's "original_path",
// or to "/" when that value is missing or is not a local path.
//
// The session's state, code_verifier and original_path values are single-use
// and are removed once consumed.
//
// Error responses:
//   - 400: the verifier is missing, the state is empty or mismatched, or the
//     code is missing;
//   - 502: the token exchange fails or yields no access token.
//
// NOTE(review): consider renewing the session token after login to prevent
// session fixation, if the session manager supports it.
func CallbackHandler(w http.ResponseWriter, r *http.Request) {
	config.Wrapper(func(c *config.Config) {
		ctx := r.Context()
		query := r.URL.Query()
		code := query.Get("code")
		state := query.Get("state")

		verifier := app.SessionManager.GetString(ctx, "code_verifier")
		if verifier == "" {
			http.Error(w, "Code verifier not found in session", http.StatusBadRequest)
			return
		}

		expectedState := app.SessionManager.GetString(ctx, "state")
		// An empty expected state means no login was started in this session;
		// never accept it, even when the query also omits state.
		if expectedState == "" || state != expectedState {
			http.Error(w, "Invalid state parameter", http.StatusBadRequest)
			return
		}

		// Consume the one-time values so a replayed callback is rejected.
		app.SessionManager.Remove(ctx, "state")
		app.SessionManager.Remove(ctx, "code_verifier")

		if code == "" {
			http.Error(w, "Authorization code missing", http.StatusBadRequest)
			return
		}

		authName := c.GetAuthNameByDomain(r.Host)
		token, fullResponse, err := exchangeCode(code, verifier, c, authName)
		if err != nil {
			dump.Println("exchangeCode: " + err.Error())
			http.Error(w, "Token exchange failed", http.StatusBadGateway)
			return
		}
		// Provider-side OAuth errors arrive in token.Error with a nil Go error,
		// so an empty access token must be rejected explicitly.
		if token == nil || token.AccessToken == "" {
			msg := "exchangeCode: no access token"
			if token != nil && token.Error != "" {
				msg += ": " + token.Error + ": " + token.ErrorDescription
			}
			dump.Println(msg)
			http.Error(w, "Token exchange failed", http.StatusBadGateway)
			return
		}

		app.SessionManager.Put(ctx, "access_token", token.AccessToken)
		app.SessionManager.Put(ctx, "full_token", fullResponse)

		originalPath := app.SessionManager.GetString(ctx, "original_path")
		app.SessionManager.Remove(ctx, "original_path")

		http.Redirect(w, r, safeRedirectPath(originalPath), http.StatusFound)
	})
}

// safeRedirectPath returns p when it is a local absolute path ("/..."), and
// "/" otherwise. Protocol-relative ("//host") and backslash ("/\host") forms
// are rejected because browsers treat them as links to another host.
func safeRedirectPath(p string) string {
	if !strings.HasPrefix(p, "/") || strings.HasPrefix(p, "//") || strings.HasPrefix(p, "/\\") {
		return "/"
	}
	return p
}
|
|
// LogoutHandler drops the session's "access_token" and "full_token" values,
// then redirects (302) to the logout endpoint of the OpenID profile mapped to
// the request host in c.DataMaps.DomainToAuth.
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
	config.Wrapper(func(cfg *config.Config) {
		ctx := r.Context()
		for _, key := range []string{"access_token", "full_token"} {
			app.SessionManager.Remove(ctx, key)
		}

		profile := cfg.Auth[cfg.DataMaps.DomainToAuth[r.Host]]
		logoutURL := profile.OpenID.EndPoints.Logout
		http.Redirect(w, r, logoutURL, http.StatusFound)
	})
}
|
|
|
|
// LoginHandler starts the authorization-code + PKCE login flow.
//
// It generates a PKCE code verifier (with its S256 challenge), a state and a
// nonce, and builds the authorization URL from the OpenID profile mapped to
// the request host in c.DataMaps.DomainToAuth. The URL carries client_id,
// response_type=code, scope=openid, redirect_uri, the challenge, state and
// nonce. The state and verifier are stored in the session for the callback,
// and the client is redirected (302) to the authorization URL.
//
// It responds 500 if the verifier cannot be generated or the configured
// authorization endpoint is not a valid URL.
//
// NOTE(review): the nonce is sent but not stored, so it cannot be checked
// against the returned ID token.
func LoginHandler(w http.ResponseWriter, r *http.Request) {
	config.Wrapper(func(c *config.Config) {
		authName := c.DataMaps.DomainToAuth[r.Host]
		openID := c.Auth[authName].OpenID

		codeVerifier, err := generateCodeVerifier()
		if err != nil {
			// Without a random verifier the PKCE challenge would be
			// predictable; refuse to start the flow.
			dump.Println("LoginHandler: generating code verifier: " + err.Error())
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		authURL, err := url.Parse(openID.EndPoints.Auth)
		if err != nil {
			dump.Println("LoginHandler: parsing auth endpoint: " + err.Error())
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		codeChallenge := generateCodeChallenge(codeVerifier)
		state := helper.RandStringByBits(128)
		nonce := helper.RandStringByBits(128)

		query := authURL.Query()
		query.Set("client_id", openID.ClientID)
		query.Set("response_type", "code")
		query.Set("scope", "openid")
		query.Set("redirect_uri", openID.RedirectURI)
		query.Set("code_challenge", codeChallenge)
		query.Set("code_challenge_method", "S256")
		query.Set("state", state)
		query.Set("nonce", nonce)
		authURL.RawQuery = query.Encode()

		app.SessionManager.Put(r.Context(), "state", state)
		app.SessionManager.Put(r.Context(), "code_verifier", codeVerifier)

		http.Redirect(w, r, authURL.String(), http.StatusFound)
	})
}
|
|
|
|
// AUTH FUNCTIONS
|
|
////////////////////
|
|
|
|
// tokenHTTPClient sends token-endpoint requests. One shared client lets
// connections be reused, instead of building a new Transport on every
// exchange and leaving its idle connections open. The timeout keeps a stalled
// provider from hanging the callback. TLS certificate verification is skipped
// while DEVELOPMENT is true.
var tokenHTTPClient = &http.Client{
	Timeout: 15 * time.Second,
	Transport: &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: DEVELOPMENT},
	},
}

// maxTokenResponseBytes caps how much of a token-endpoint response is read.
const maxTokenResponseBytes = 1 << 20

// exchangeCode redeems an authorization code at the token endpoint of the
// OpenID profile c.Auth[authName].
//
// It POSTs a form-encoded authorization_code grant. The form carries the
// profile's client ID, client secret and redirect URI, the code, scope
// "openid zapp", and code_verifier when verifier is non-empty.
//
// Return values:
//   - the decoded response and the raw body as a string;
//   - an error when the request cannot be built or sent, or the body cannot
//     be read or decoded as JSON.
//
// An OAuth error returned by the provider is NOT a Go error. It is reported
// in TokenResponse.Error / ErrorDescription (and logged), so callers must
// check for an empty AccessToken.
func exchangeCode(code string, verifier string, c *config.Config, authName string) (*TokenResponse, string, error) {
	openID := c.Auth[authName].OpenID

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", openID.ClientID)
	form.Set("client_secret", openID.ClientSecert)
	form.Set("redirect_uri", openID.RedirectURI)
	form.Set("code", code)
	form.Set("scope", "openid zapp")
	if verifier != "" {
		form.Set("code_verifier", verifier)
	}

	req, err := http.NewRequest(http.MethodPost, openID.EndPoints.Token, strings.NewReader(form.Encode()))
	if err != nil {
		dump.Println("ERROR exchange code request: " + err.Error())
		return nil, "", err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := tokenHTTPClient.Do(req)
	if err != nil {
		dump.Println("ERROR exchange code: " + err.Error())
		return nil, "", err
	}
	defer resp.Body.Close()

	respBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes))
	if err != nil {
		dump.Println("ERROR exchange code read body: " + err.Error())
		return nil, "", err
	}

	tokenResponse := &TokenResponse{}
	if err := json.Unmarshal(respBytes, tokenResponse); err != nil {
		dump.Println("ERROR exchange code Unmarshal: " + err.Error())
		return nil, "", err
	}

	if tokenResponse.Error != "" {
		dump.Println(tokenResponse.Error + ": " + tokenResponse.ErrorDescription)
	}

	return tokenResponse, string(respBytes), nil
}
|
|
|
|
func generateCodeVerifier() (string, error) {
|
|
verifier := make([]byte, 32)
|
|
_, err := rand.Read(verifier)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(verifier), nil
|
|
}
|
|
|
|
// generateCodeChallenge derives the PKCE S256 code challenge for verifier:
// the SHA-256 digest of the verifier's bytes, encoded as unpadded base64url.
func generateCodeChallenge(verifier string) string {
	digest := sha256.New()
	digest.Write([]byte(verifier))
	return base64.RawURLEncoding.EncodeToString(digest.Sum(nil))
}
|
|
|
|
// IsAuthorizedJWT reports whether rawAccessToken verifies against the OpenID
// provider configured in c.Auth[authName].
//
// It discovers the provider from the profile's EndPoints.Issuer. It then
// verifies the token with go-oidc, using the profile's ClientID as the
// expected audience, and decodes the payload into Claims. Any failure is
// logged with dump.Println and reported as false.
//
// NOTE(review): provider discovery runs on every call, which is a network
// round trip per check. Cache the *oidc.Provider per issuer if this goes on
// the request path. Keycloak access tokens do not always list the client ID
// in "aud"; confirm the audience check matches the tokens passed in.
func IsAuthorizedJWT(rawAccessToken string, c *config.Config, authName string) bool {
	openID := c.Auth[authName].OpenID

	tr := &http.Transport{
		// Skip certificate verification only while DEVELOPMENT is true, the
		// same rule the token exchange uses.
		TLSClientConfig: &tls.Config{InsecureSkipVerify: DEVELOPMENT},
	}
	// The transport is private to this call; release its idle connections
	// on return so they do not accumulate across calls.
	defer tr.CloseIdleConnections()

	httpClient := &http.Client{
		// Bounded so an unreachable issuer cannot stall the caller; the
		// previous value was 6000s (100 minutes).
		Timeout:   10 * time.Second,
		Transport: tr,
	}
	ctx := oidc.ClientContext(context.Background(), httpClient)

	provider, err := oidc.NewProvider(ctx, openID.EndPoints.Issuer)
	if err != nil {
		dump.Println("authorisation failed while getting the provider: " + err.Error())
		return false
	}

	verifier := provider.Verifier(&oidc.Config{ClientID: openID.ClientID})
	idToken, err := verifier.Verify(ctx, rawAccessToken)
	if err != nil {
		dump.Println("authorisation failed while verifying the token: " + err.Error())
		return false
	}

	var claims Claims // the token payload is plain JSON
	if err := idToken.Claims(&claims); err != nil {
		dump.Println("claims: " + err.Error())
		return false
	}

	return true
}
|
|
|
|
///////////////////////////////////
|
|
///////////////////////////////////
|