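// Package reverseproxy wraps net/http/httputil.ReverseProxy, keeps the
// upstream host in a context.Context, and provides helpers for building
// Director functions.
package reverseproxy

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
)

// ReverseProxy embeds *httputil.ReverseProxy and carries the context, and the
// key, under which the upstream host is stored.
type ReverseProxy struct {
	CtxKey  CtxKey
	Context context.Context
	*httputil.ReverseProxy
}

// CtxKey is the type of the context key that holds the upstream host. A named
// type keeps it from colliding with keys defined in other packages.
type CtxKey string

// NewSTD stores host in ctx under the key "host" and wraps a zero-value
// httputil.ReverseProxy. The returned proxy has no Director, so one must be
// installed with DirectorFunc or DefaultDirectorFunc before it can serve
// requests.
func NewSTD(ctx context.Context, host string) *ReverseProxy {
	ctxKey := CtxKey("host")
	return &ReverseProxy{
		CtxKey:       ctxKey,
		Context:      context.WithValue(ctx, ctxKey, host),
		ReverseProxy: &httputil.ReverseProxy{},
	}
}

// New returns a proxy that forwards every request to host, which should be an
// absolute URL such as "http://localhost:8080". The path and query of host are
// joined with those of each incoming request. New panics if host cannot be
// parsed as a URL.
func New(ctx context.Context, host string) *ReverseProxy {
	rp := &ReverseProxy{CtxKey: CtxKey("host")}
	rp.Context = context.WithValue(ctx, rp.CtxKey, host)

	target, err := url.Parse(host)
	if err != nil {
		// Fail fast: a nil target would make the Director panic on every request.
		panic(fmt.Sprintf("reverseproxy: invalid host %q: %v", host, err))
	}

	rp.ReverseProxy = &httputil.ReverseProxy{
		Director: func(r *http.Request) {
			// A Director can only mutate r in place. Rebinding r (for example
			// r = r.WithContext(...)) never reaches the outgoing request.
			rewriteToTarget(r, target)
		},
	}
	return rp
}

// SetContext replaces rp.Context. Directors installed by New or
// DefaultDirectorFunc resolve their target when they are created, so changing
// the context afterwards does not reroute them.
func (rp *ReverseProxy) SetContext(ctx context.Context) {
	rp.Context = ctx
}

// DirectorFunc installs the Director returned by df. df receives JoinURLPath
// so it can build the upstream path the same way the default Director does.
func (rp *ReverseProxy) DirectorFunc(df func(jup JoinURLPathFunc) DirectorFunc) {
	rp.Director = df(JoinURLPath)
}

// SingleJoiningSlash joins a and b with exactly one slash between them. It
// mirrors the unexported singleJoiningSlash helper in net/http/httputil.
func SingleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}

// JoinURLPath joins the paths of a and b with a single slash. It returns the
// decoded path and, when either URL carries a RawPath, the escaped path as
// well (otherwise rawpath is empty). It mirrors the unexported joinURLPath
// helper in net/http/httputil.
func JoinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return SingleJoiningSlash(a.Path, b.Path), ""
	}
	// Same as SingleJoiningSlash, but uses EscapedPath to determine
	// whether a slash should be added.
	apath := a.EscapedPath()
	bpath := b.EscapedPath()

	aslash := strings.HasSuffix(apath, "/")
	bslash := strings.HasPrefix(bpath, "/")

	switch {
	case aslash && bslash:
		return a.Path + b.Path[1:], apath + bpath[1:]
	case !aslash && !bslash:
		return a.Path + "/" + b.Path, apath + "/" + bpath
	}
	return a.Path + b.Path, apath + bpath
}

// JoinURLPathFunc has the signature of JoinURLPath.
type JoinURLPathFunc func(*url.URL, *url.URL) (string, string)

// type SingleJoiningSlashFunc func(string, string) string

// DirectorFunc has the signature of httputil.ReverseProxy.Director.
type DirectorFunc func(*http.Request)

// func StripPrefix(r *http.Request, prefixPath string) *http.Request {
// 	if prefixPath == "/" {
// 		newPath := strings.TrimPrefix(r.URL.Path, prefixPath)
// 		if newPath == "" {
// 			newPath = "/"
// 		}
// 		r.URL.Path = newPath
// 	}
// 	return r
// }

// DefaultDirectorFunc installs a Director that first passes each request
// through fn (which may be nil) and then rewrites it to the host stored in ctx
// under ctxKey, joining paths and queries the same way New does. fn must
// modify the request in place; the Director forwards the request it was
// given, not a different one returned by fn. DefaultDirectorFunc panics if
// ctx holds no string under ctxKey or if that string is not a valid URL.
func (rp *ReverseProxy) DefaultDirectorFunc(ctx context.Context, ctxKey CtxKey, fn func(*http.Request) *http.Request) {
	host, ok := ctx.Value(ctxKey).(string)
	if !ok {
		panic(fmt.Sprintf("reverseproxy: no host string in context under key %q", ctxKey))
	}
	target, err := url.Parse(host)
	if err != nil {
		panic(fmt.Sprintf("reverseproxy: invalid host %q: %v", host, err))
	}

	rp.Director = func(r *http.Request) {
		if fn != nil {
			fn(r)
		}
		rewriteToTarget(r, target)
	}
}

// rewriteToTarget points r at target: it copies the scheme and host, joins
// target's path with r's path, and concatenates the two query strings. This
// mirrors the URL rewriting done by httputil.NewSingleHostReverseProxy.
func rewriteToTarget(r *http.Request, target *url.URL) {
	targetQuery := target.RawQuery
	r.URL.Scheme = target.Scheme
	r.URL.Host = target.Host
	r.URL.Path, r.URL.RawPath = JoinURLPath(target, r.URL)
	if targetQuery == "" || r.URL.RawQuery == "" {
		r.URL.RawQuery = targetQuery + r.URL.RawQuery
	} else {
		r.URL.RawQuery = targetQuery + "&" + r.URL.RawQuery
	}
}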
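
// StripPathPrefix is a hypothetical helper, not part of the original API. It
// is a minimal sketch of what the commented-out StripPrefix appears to aim
// for, reshaped to fit the fn parameter of DefaultDirectorFunc. The returned
// function removes prefix from the request path and keeps a leading "/". It
// uses the same matching rule as http.StripPrefix (the prefix must match Path
// and, when present, RawPath), but an unmatched request passes through
// unchanged instead of receiving a 404. It mutates r in place and returns the
// same pointer, because a Director can only affect the outgoing request
// through the *http.Request it was handed.
//
// Usage sketch (assumes a backend listening at http://localhost:8081):
//
//	rp := NewSTD(context.Background(), "http://localhost:8081")
//	rp.DefaultDirectorFunc(rp.Context, CtxKey("host"), StripPathPrefix("/api"))
//	http.Handle("/api/", rp)
func StripPathPrefix(prefix string) func(*http.Request) *http.Request {
	return func(r *http.Request) *http.Request {
		// An empty or root prefix has nothing meaningful to strip.
		if prefix == "" || prefix == "/" {
			return r
		}
		p := strings.TrimPrefix(r.URL.Path, prefix)
		rawp := strings.TrimPrefix(r.URL.RawPath, prefix)
		// Leave the request alone unless the prefix matched Path and, if an
		// escaped form is present, RawPath as well.
		if len(p) == len(r.URL.Path) || (r.URL.RawPath != "" && len(rawp) == len(r.URL.RawPath)) {
			return r
		}
		// "/api" stripped from "/api" leaves "", which becomes "/".
		if !strings.HasPrefix(p, "/") {
			p = "/" + p
		}
		if rawp != "" && !strings.HasPrefix(rawp, "/") {
			rawp = "/" + rawp
		}
		r.URL.Path = p
		r.URL.RawPath = rawp
		return r
	}
}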