go-zprox/cmd/server/main.go
Zeev Diukman 5a6eed8c57 2
2025-03-02 18:16:57 +00:00

415 lines
12 KiB
Go

package main
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gookit/goutil/dump"
"zeevdiukman.com/zprox/internal/config"
"zeevdiukman.com/zprox/internal/logic"
"zeevdiukman.com/zprox/internal/reverse_proxy"
"zeevdiukman.com/zprox/internal/router"
"zeevdiukman.com/zprox/pkg/helper"
)
const DEVELOPMENT bool = true
type EntryPoints map[string]EntryPoint
type EntryPoint struct {
Name string
Group string
*http.Server
}
type ReverseProxies map[string]ReverseProxy
type ReverseProxy *httputil.ReverseProxy
var app = logic.NewApp()
func main() {
helper.AppRunner(func() {
config.Wrapper(func(c *config.Config) {
groups := logic.NewGroups()
mainRouter := router.New()
groups.ForEach(func(k string, g *logic.Group) {
groupSubRouter := mainRouter.Mux.NewRoute().Subrouter()
groupSubRouter.Use(Domain_Middleware)
for k := range g.ReverseProxies {
rpConfig := c.ReverseProxies[k]
domain := rpConfig.Domain
proxy := reverse_proxy.New(rpConfig.Host)
proxy.Name = domain
newRoute := groupSubRouter.NewRoute()
subRouter := newRoute.Host(domain).Subrouter()
if rpConfig.Auth != "" {
if _, ok := c.Auth[rpConfig.Auth]; !ok {
err := errors.New("Error: Auth " + rpConfig.Auth + " not exist!")
panic(err.Error())
}
pths := c.Auth[rpConfig.Auth].Paths
authRoute := subRouter.NewRoute()
subRouter.Use(Middleware_SetHeaders)
authSubRouter := authRoute.PathPrefix(pths.Prefix).Subrouter()
authSubRouter.Path(pths.Login).Handler(http.HandlerFunc(LoginHandler))
authSubRouter.Path(pths.Logout).Handler(http.HandlerFunc(LogoutHandler))
authSubRouter.Path(pths.Callback).Handler(http.HandlerFunc(CallbackHandler))
subRouter.Use(authMiddleware)
}
subRouter.PathPrefix("/").Handler(proxy.Httputil)
// Filter out static file requests first
}
if len(g.ReverseProxies) > 0 {
tlsConfig := &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
// crt, key := "", ""
crt, key := c.GetCertsPairByDomain(info.ServerName)
if crt == "" && key == "" {
// crt = c.TLS.Certs["default"].Cert
// key = c.TLS.Certs["default"].Key
// panic("Error: TLS cert and key not found!")
}
cert, err := loadCertificate(crt, key)
if err != nil {
return nil, err
}
return &cert, nil
},
}
server := &http.Server{
Addr: ":" + g.Port,
Handler: app.SessionManager.LoadAndSave(groupSubRouter),
TLSConfig: tlsConfig,
}
var err error
go func() {
ipAddr := helper.GetIP()
log.Println("Test server is running at http://" + ipAddr + ":" + g.Port)
if g.TLS {
err = server.ListenAndServeTLS("", "")
} else {
err = server.ListenAndServe()
}
if err != nil {
log.Println(err.Error())
}
}()
}
})
helper.StartTestHTTPServer(3000)
})
})
}
// //////////////////////////////////////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////////////////////////////////////
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
config.Wrapper(func(c *config.Config) {
currentPath := r.URL.Path
authName := c.GetAuthNameByDomain(r.Host)
loginPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Login
logoutPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Logout
callbackPath := c.Auth[authName].Paths.Prefix + c.Auth[authName].Paths.Callback
// TODO: mark auth reverse proxy in yaml
if r.Host == "keycloak.z.com" {
next.ServeHTTP(w, r)
}
switch currentPath {
case loginPath:
{
// fmt.Fprintln(w, "LOGIN")
next.ServeHTTP(w, r)
}
case logoutPath:
{
next.ServeHTTP(w, r)
// return
}
case callbackPath:
{
next.ServeHTTP(w, r)
// return
}
default:
{
accessToken := app.SessionManager.GetString(r.Context(), "access_token")
if accessToken == "" {
authName := c.DataMaps.DomainToAuth[r.Host]
http.Redirect(w, r, c.Auth[authName].Paths.Prefix+c.Auth[authName].Paths.Login, http.StatusFound)
return
}
// auth.SetAuthHeader(w, accessToken)
a := c.Auth[authName]
pths := a.Paths
prefix := pths.Prefix
login := pths.Login
logout := pths.Logout
loginPath := prefix + login
logoutPath := prefix + logout
if loginPath == r.URL.Path || logoutPath == r.URL.Path {
next.ServeHTTP(w, r)
// return
}
tokenOk := IsAuthorizedJWT(accessToken, c, "default")
if !tokenOk {
http.Redirect(w, r, loginPath, http.StatusFound)
return
}
next.ServeHTTP(w, r)
}
}
})
})
}
func Domain_Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// c := config.Get()
// requestedPath := r.URL.Path
// authName := c.GetAuthNameByDomain(r.Host)
// auth := c.Auth[authName]
// excludedPaths := []string{
// auth.Paths.Prefix + auth.Paths.Login,
// auth.Paths.Prefix + auth.Paths.Callback,
// auth.Paths.Prefix + auth.Paths.Logout,
// }
// contains := helper.IsSliceContains(excludedPaths, requestedPath)
// contains := slices.Contains(excludedPaths, requestedPath)
// if !contains {
// app.SessionManager.Put(r.Context(), "original_path", requestedPath)
// }
// dump.P(requestedPath)
next.ServeHTTP(w, r)
})
}
func Middleware_SetHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Header.Set("X-Forwarded-Proto", getProto(r))
r.Header.Set("X-Forwarded-For", r.RemoteAddr)
r.Header.Set("X-Forwarded-Host", r.Host)
r.Header.Set("X-Real-IP", r.RemoteAddr)
next.ServeHTTP(w, r)
})
}
func getProto(req *http.Request) string {
if req.TLS != nil {
return "https"
} else {
return "http"
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshExpiresIn int `json:"refresh_expires_in"`
TokenType string `json:"token_type"`
NotBeforePolicy int `json:"not-before-policy"`
SessionState string `json:"session_state"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
type Res401Struct struct {
Status string `json:"status" example:"FAILED"`
HTTPCode int `json:"httpCode" example:"401"`
Message string `json:"message" example:"authorisation failed"`
}
type Claims struct {
ResourceAccess client `json:"resource_access,omitempty"`
JTI string `json:"jti,omitempty"`
}
type client struct {
DemoServiceClient clientRoles `json:"DemoServiceClient,omitempty"`
}
type clientRoles struct {
Roles []string `json:"roles,omitempty"`
}
type HandlerFuncConfigWrapper func(*config.Config, http.ResponseWriter, *http.Request) *http.Handler
// HANDLERS
// ////////////
// AUTH FUNCTIONS
// //////////////////
func generateState(redirectURL string) string {
b := make([]byte, 32)
rand.Read(b)
stateData := map[string]string{"redirect_uri": redirectURL, "random": base64.StdEncoding.EncodeToString(b)}
jsonData, _ := json.Marshal(stateData)
return base64.StdEncoding.EncodeToString(jsonData)
}
func decodeState(encodedState string) (string, error) {
input := encodedState
b64data := input[strings.IndexByte(input, ',')+1:]
decoded, err := base64.StdEncoding.DecodeString(b64data)
if err != nil {
return "", err
}
var stateData map[string]string
err = json.Unmarshal(decoded, &stateData)
if err != nil {
return "", err
}
return stateData["redirect_uri"], nil
}
func exchangeCode(code string, verifier string, c *config.Config, authName string) (*TokenResponse, string, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("client_id", c.Auth[authName].OpenID.ClientID)
data.Set("client_secret", c.Auth[authName].OpenID.ClientSecert)
data.Set("redirect_uri", c.Auth[authName].OpenID.RedirectURI)
data.Set("code", code)
data.Set("scope", "openid zapp")
if verifier != "" {
data.Set("code_verifier", verifier)
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: DEVELOPMENT},
}
client := &http.Client{Transport: tr}
u := c.Auth[authName].OpenID.EndPoints.Token
r, _ := http.NewRequest(http.MethodPost, u, strings.NewReader(data.Encode()))
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := client.Do(r)
if err != nil {
dump.Println("ERROR exchange code: " + err.Error())
return nil, "", err
}
respBytes, err := io.ReadAll(resp.Body)
tokenResponse := &TokenResponse{}
json.Unmarshal(respBytes, &tokenResponse)
if err != nil {
dump.Println("ERROR exchange code Unmarshal: " + err.Error())
return nil, "", err
}
if tokenResponse.Error != "" {
dump.Println(tokenResponse.Error + ": " + tokenResponse.ErrorDescription)
}
fullResponse := string(respBytes)
return tokenResponse, fullResponse, nil
}
func generateCodeVerifier() (string, error) {
verifier := make([]byte, 32)
_, err := rand.Read(verifier)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(verifier), nil
}
func generateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(hash[:])
}
func IsAuthorizedJWT(rawAccessToken string, c *config.Config, authName string) bool {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{
Timeout: time.Duration(6000) * time.Second,
Transport: tr,
}
ctx := oidc.ClientContext(context.Background(), client)
provider, err := oidc.NewProvider(ctx, c.Auth[authName].OpenID.EndPoints.Issuer)
if err != nil {
dump.Println("authorisation failed while getting the provider: " + err.Error())
return false
}
oidcConfig := &oidc.Config{
ClientID: c.Auth[authName].OpenID.ClientID,
}
verifier := provider.Verifier(oidcConfig)
idToken, err := verifier.Verify(ctx, rawAccessToken)
if err != nil {
dump.Println("authorisation failed while verifying the token: " + err.Error())
return false
}
var IDTokenClaims Claims // ID Token payload is just JSON.
if err := idToken.Claims(&IDTokenClaims); err != nil {
dump.Println("claims: " + err.Error())
return false
}
return true
}
///////////////////////////////////
///////////////////////////////////
func isStaticFileRequest(path string) bool {
// Check for common static file prefixes
if strings.HasPrefix(path, "/static/") || strings.HasPrefix(path, "/assets/") {
return true
}
// Check for common static file extensions
staticExtensions := []string{
".css", ".js", ".jpg", ".jpeg", ".png", ".gif", ".svg", ".ico",
}
for _, ext := range staticExtensions {
if strings.HasSuffix(path, ext) {
return true
}
}
return false
}