// Command zprox starts one HTTP(S) reverse-proxy server per configured group
// and optionally protects proxied domains with an OpenID Connect login flow.
package main

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"crypto/tls"
	"encoding/base64"
	"encoding/json"
	"errors"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gookit/goutil/dump"
	"zeevdiukman.com/zprox/internal/config"
	"zeevdiukman.com/zprox/internal/logic"
	"zeevdiukman.com/zprox/internal/reverse_proxy"
	"zeevdiukman.com/zprox/internal/router"
	"zeevdiukman.com/zprox/pkg/helper"
)

// DEVELOPMENT toggles development-only behaviour. In this file it controls
// whether outgoing HTTPS requests to the identity provider skip TLS
// certificate verification.
const DEVELOPMENT bool = true

// EntryPoints maps a key to its EntryPoint.
type EntryPoints map[string]EntryPoint

// EntryPoint is a named HTTP server that belongs to a group.
type EntryPoint struct {
	Name  string
	Group string
	*http.Server
}

// ReverseProxies maps a key to its ReverseProxy.
type ReverseProxies map[string]ReverseProxy

// ReverseProxy is a pointer to a standard library reverse proxy.
type ReverseProxy *httputil.ReverseProxy

// app is the process-wide application object; this file uses its
// SessionManager to load/save sessions and to read the stored access token.
var app = logic.NewApp()

// main builds one sub-router per group, registers a host-matched route for
// every reverse proxy of that group (plus the login/logout/callback routes and
// the auth middleware when the proxy has an Auth entry), starts one server per
// non-empty group and finally starts the helper test server on port 3000.
func main() {
	helper.AppRunner(func() {
		config.Wrapper(func(c *config.Config) {
			groups := logic.NewGroups()
			mainRouter := router.New()

			groups.ForEach(func(_ string, g *logic.Group) {
				groupSubRouter := mainRouter.Mux.NewRoute().Subrouter()
				groupSubRouter.Use(Domain_Middleware)

				for name := range g.ReverseProxies {
					rpConfig := c.ReverseProxies[name]
					proxy := reverse_proxy.New(rpConfig.Host)
					proxy.Name = rpConfig.Domain

					subRouter := groupSubRouter.NewRoute().Host(rpConfig.Domain).Subrouter()

					if rpConfig.Auth != "" {
						authCfg, ok := c.Auth[rpConfig.Auth]
						if !ok {
							// A proxy that references an unknown auth entry is a
							// configuration error; abort start-up.
							err := errors.New("Error: Auth " + rpConfig.Auth + " not exist!")
							panic(err.Error())
						}
						pths := authCfg.Paths

						subRouter.Use(Middleware_SetHeaders)

						// The auth endpoints are registered before the catch-all
						// proxy route below so they are matched first.
						authSubRouter := subRouter.NewRoute().PathPrefix(pths.Prefix).Subrouter()
						authSubRouter.Path(pths.Login).Handler(http.HandlerFunc(LoginHandler))
						authSubRouter.Path(pths.Logout).Handler(http.HandlerFunc(LogoutHandler))
						authSubRouter.Path(pths.Callback).Handler(http.HandlerFunc(CallbackHandler))

						subRouter.Use(authMiddleware)
					}

					// TODO: filter out static file requests first.
					subRouter.PathPrefix("/").Handler(proxy.Httputil)
				}

				if len(g.ReverseProxies) > 0 {
					serveGroup(g, app.SessionManager.LoadAndSave(groupSubRouter), newTLSConfig(c))
				}
			})

			helper.StartTestHTTPServer(3000)
		})
	})
}

// newTLSConfig returns a TLS configuration that picks the certificate per
// connection by asking the config for the pair registered for the requested
// server name.
func newTLSConfig(c *config.Config) *tls.Config {
	return &tls.Config{
		GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
			crt, key := c.GetCertsPairByDomain(info.ServerName)
			// TODO: when no pair is found (crt == "" && key == "") fall back to
			// c.TLS.Certs["default"] instead of letting loadCertificate fail.
			cert, err := loadCertificate(crt, key)
			if err != nil {
				return nil, err
			}
			return &cert, nil
		},
	}
}

// serveGroup starts the server of group g in a new goroutine. The goroutine
// has no stop mechanism here and ends only when the listener returns an
// error, which is then logged.
func serveGroup(g *logic.Group, handler http.Handler, tlsConfig *tls.Config) {
	server := &http.Server{
		Addr:      ":" + g.Port,
		Handler:   handler,
		TLSConfig: tlsConfig,
		// Bound the time a client may take to send its request headers. Body
		// and write timeouts are deliberately left unset because proxied
		// responses may be long-lived.
		ReadHeaderTimeout: 10 * time.Second,
	}

	go func() {
		scheme := "http"
		if g.TLS {
			scheme = "https"
		}
		log.Println("Test server is running at " + scheme + "://" + helper.GetIP() + ":" + g.Port)

		var err error
		if g.TLS {
			// Certificates come from TLSConfig.GetCertificate, hence no files.
			err = server.ListenAndServeTLS("", "")
		} else {
			err = server.ListenAndServe()
		}
		if err != nil {
			log.Println(err.Error())
		}
	}()
}

// ////////////////////////////////////////////////////////////////////////////
// MIDDLEWARE
// ////////////////////////////////////////////////////////////////////////////

// authMiddleware lets requests for the login, logout and callback paths
// through untouched and requires every other request to carry a session
// access token that passes IsAuthorizedJWT; otherwise the client is
// redirected to the login path of the auth entry mapped to the request host.
func authMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		config.Wrapper(func(c *config.Config) {
			// TODO: mark auth reverse proxy in yaml
			if r.Host == "keycloak.z.com" {
				next.ServeHTTP(w, r)
				// Without this return the request would be handled a second
				// time by the checks below.
				return
			}

			authName := c.GetAuthNameByDomain(r.Host)
			paths := c.Auth[authName].Paths
			loginPath := paths.Prefix + paths.Login
			logoutPath := paths.Prefix + paths.Logout
			callbackPath := paths.Prefix + paths.Callback

			switch r.URL.Path {
			case loginPath, logoutPath, callbackPath:
				next.ServeHTTP(w, r)
				return
			}

			accessToken := app.SessionManager.GetString(r.Context(), "access_token")
			if accessToken == "" {
				http.Redirect(w, r, loginPath, http.StatusFound)
				return
			}

			// NOTE(review): the token is verified against the "default" auth
			// entry rather than authName — confirm this is intended when more
			// than one auth entry is configured.
			if !IsAuthorizedJWT(accessToken, c, "default") {
				http.Redirect(w, r, loginPath, http.StatusFound)
				return
			}

			next.ServeHTTP(w, r)
		})
	})
}

// Domain_Middleware currently forwards every request unchanged. It is the
// intended place for remembering the originally requested path in the
// session ("original_path") for non-auth paths; that logic is not enabled.
func Domain_Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	})
}

// Middleware_SetHeaders sets the X-Forwarded-Proto, X-Forwarded-For,
// X-Forwarded-Host and X-Real-IP request headers before calling next.
func Middleware_SetHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// RemoteAddr is "IP:port"; the forwarding headers carry only the IP.
		// Fall back to the raw value if it cannot be split.
		clientIP := r.RemoteAddr
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			clientIP = host
		}

		r.Header.Set("X-Forwarded-Proto", getProto(r))
		r.Header.Set("X-Forwarded-For", clientIP)
		r.Header.Set("X-Forwarded-Host", r.Host)
		r.Header.Set("X-Real-IP", clientIP)
		next.ServeHTTP(w, r)
	})
}

// getProto returns "https" when the request arrived over TLS and "http"
// otherwise.
func getProto(req *http.Request) string {
	if req.TLS != nil {
		return "https"
	}
	return "http"
}

// ////////////////////////////////////////////////////////////////////////////
// TYPES
// ////////////////////////////////////////////////////////////////////////////

// TokenResponse is the JSON body decoded from the token endpoint response,
// covering both the success fields and the error fields.
type TokenResponse struct {
	AccessToken      string `json:"access_token"`
	ExpiresIn        int    `json:"expires_in"`
	RefreshExpiresIn int    `json:"refresh_expires_in"`
	TokenType        string `json:"token_type"`
	NotBeforePolicy  int    `json:"not-before-policy"`
	SessionState     string `json:"session_state"`
	Scope            string `json:"scope"`
	RefreshToken     string `json:"refresh_token"`
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description"`
}

// Res401Struct is a JSON payload describing a failed authorisation.
type Res401Struct struct {
	Status   string `json:"status" example:"FAILED"`
	HTTPCode int    `json:"httpCode" example:"401"`
	Message  string `json:"message" example:"authorisation failed"`
}

// Claims holds the token claims decoded by IsAuthorizedJWT.
type Claims struct {
	ResourceAccess client `json:"resource_access,omitempty"`
	JTI            string `json:"jti,omitempty"`
}

// client is the "resource_access" claim object.
type client struct {
	DemoServiceClient clientRoles `json:"DemoServiceClient,omitempty"`
}

// clientRoles lists the roles of one entry of the "resource_access" claim.
type clientRoles struct {
	Roles []string `json:"roles,omitempty"`
}

// HandlerFuncConfigWrapper is a handler-like function that additionally
// receives the configuration.
type HandlerFuncConfigWrapper func(*config.Config, http.ResponseWriter, *http.Request) *http.Handler

// HANDLERS
// ////////////

// AUTH FUNCTIONS
// //////////////////

// generateState returns a base64 (standard encoding) JSON object holding
// redirectURL under "redirect_uri" and 32 random bytes under "random".
// It panics if the system random source fails, because a state value without
// entropy would be predictable.
func generateState(redirectURL string) string {
	nonce := make([]byte, 32)
	if _, err := rand.Read(nonce); err != nil {
		panic("generating state nonce: " + err.Error())
	}

	stateData := map[string]string{
		"redirect_uri": redirectURL,
		"random":       base64.StdEncoding.EncodeToString(nonce),
	}
	// Marshalling a map[string]string cannot fail, so the error is ignored.
	jsonData, _ := json.Marshal(stateData)
	return base64.StdEncoding.EncodeToString(jsonData)
}

// decodeState reverses generateState and returns the "redirect_uri" value.
// Anything up to and including the first comma of encodedState is discarded
// before decoding. The result is "" with a nil error when the decoded JSON
// has no "redirect_uri" key.
func decodeState(encodedState string) (string, error) {
	// IndexByte yields -1 when there is no comma, so the slice then starts at
	// 0 and keeps the whole input.
	b64data := encodedState[strings.IndexByte(encodedState, ',')+1:]

	decoded, err := base64.StdEncoding.DecodeString(b64data)
	if err != nil {
		return "", err
	}

	var stateData map[string]string
	if err := json.Unmarshal(decoded, &stateData); err != nil {
		return "", err
	}
	return stateData["redirect_uri"], nil
}

// exchangeCode posts an authorization_code grant to the token endpoint of the
// auth entry authName and returns the decoded response together with the raw
// response body. verifier is sent as "code_verifier" only when non-empty.
//
// A non-nil error is returned when the request cannot be built or sent, or
// when the body cannot be read or decoded as JSON. An error reported by the
// provider inside the JSON body is logged and left in the returned
// TokenResponse (Error / ErrorDescription) with a nil error.
func exchangeCode(code string, verifier string, c *config.Config, authName string) (*TokenResponse, string, error) {
	openID := c.Auth[authName].OpenID

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", openID.ClientID)
	form.Set("client_secret", openID.ClientSecert)
	form.Set("redirect_uri", openID.RedirectURI)
	form.Set("code", code)
	form.Set("scope", "openid zapp")
	if verifier != "" {
		form.Set("code_verifier", verifier)
	}

	req, err := http.NewRequest(http.MethodPost, openID.EndPoints.Token, strings.NewReader(form.Encode()))
	if err != nil {
		dump.Println("ERROR exchange code: " + err.Error())
		return nil, "", err
	}
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

	transport := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: DEVELOPMENT},
	}
	// The transport lives only for this call; release its pooled connections.
	defer transport.CloseIdleConnections()
	httpClient := &http.Client{
		Timeout:   30 * time.Second,
		Transport: transport,
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		dump.Println("ERROR exchange code: " + err.Error())
		return nil, "", err
	}
	defer resp.Body.Close()

	respBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		dump.Println("ERROR exchange code read body: " + err.Error())
		return nil, "", err
	}

	tokenResponse := &TokenResponse{}
	if err := json.Unmarshal(respBytes, tokenResponse); err != nil {
		dump.Println("ERROR exchange code Unmarshal: " + err.Error())
		return nil, "", err
	}
	if tokenResponse.Error != "" {
		dump.Println(tokenResponse.Error + ": " + tokenResponse.ErrorDescription)
	}

	return tokenResponse, string(respBytes), nil
}

// generateCodeVerifier returns 32 random bytes encoded as unpadded URL-safe
// base64, for use as a PKCE code verifier.
func generateCodeVerifier() (string, error) {
	verifier := make([]byte, 32)
	if _, err := rand.Read(verifier); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(verifier), nil
}

// generateCodeChallenge returns the unpadded URL-safe base64 encoding of the
// SHA-256 digest of verifier.
func generateCodeChallenge(verifier string) string {
	hash := sha256.Sum256([]byte(verifier))
	return base64.RawURLEncoding.EncodeToString(hash[:])
}

// IsAuthorizedJWT reports whether rawAccessToken is accepted by the OIDC
// verifier built from the issuer and client ID of the auth entry authName and
// whether its claims decode into Claims. Every failure is logged and yields
// false.
//
// NOTE(review): the provider is re-created on every call and the client
// timeout is 6000 seconds; consider caching the provider and shortening the
// timeout.
func IsAuthorizedJWT(rawAccessToken string, c *config.Config, authName string) bool {
	openID := c.Auth[authName].OpenID

	transport := &http.Transport{
		// Same switch as exchangeCode: certificate checks are skipped only in
		// development.
		TLSClientConfig: &tls.Config{InsecureSkipVerify: DEVELOPMENT},
	}
	// The transport lives only for this call; release its pooled connections.
	defer transport.CloseIdleConnections()
	httpClient := &http.Client{
		Timeout:   time.Duration(6000) * time.Second,
		Transport: transport,
	}

	ctx := oidc.ClientContext(context.Background(), httpClient)
	provider, err := oidc.NewProvider(ctx, openID.EndPoints.Issuer)
	if err != nil {
		dump.Println("authorisation failed while getting the provider: " + err.Error())
		return false
	}

	verifier := provider.Verifier(&oidc.Config{ClientID: openID.ClientID})
	idToken, err := verifier.Verify(ctx, rawAccessToken)
	if err != nil {
		dump.Println("authorisation failed while verifying the token: " + err.Error())
		return false
	}

	// The token payload is plain JSON; decoding it validates its shape.
	var claims Claims
	if err := idToken.Claims(&claims); err != nil {
		dump.Println("claims: " + err.Error())
		return false
	}
	return true
}

// ///////////////////////////////////////////////////////////////////////////

// isStaticFileRequest reports whether path starts with "/static/" or
// "/assets/" or ends with one of a fixed set of static file extensions.
// The comparison is case-sensitive.
func isStaticFileRequest(path string) bool {
	// Check for common static file prefixes.
	if strings.HasPrefix(path, "/static/") || strings.HasPrefix(path, "/assets/") {
		return true
	}

	// Check for common static file extensions.
	staticExtensions := []string{
		".css", ".js", ".jpg", ".jpeg", ".png", ".gif", ".svg", ".ico",
	}
	for _, ext := range staticExtensions {
		if strings.HasSuffix(path, ext) {
			return true
		}
	}
	return false
}