package dns import ( "context" "fmt" "net" "path/filepath" "strconv" "strings" "time" "github.com/miekg/dns" "github.com/spf13/viper" "github.com/zeevdiukman/z/helper" ) type ResponseWriter = dns.ResponseWriter type Msg = dns.Msg type HandlerFunc = dns.HandlerFunc // func(dns.ResponseWriter, *dns.Msg) type DNS struct { Config Config Server Server Mux Mux Records Records Resolver Resolver } type Config struct { *viper.Viper } type Server struct { *dns.Server } type Mux struct { *dns.ServeMux } type Records struct { TypeA map[string]string } type Resolver struct { *net.Resolver } func New() *DNS { app := &DNS{} app.ConfigInit("./records.yaml") app.Records.TypeA = app.Config.GetStringMapString("records.type_a") app.ServerInit() app.MuxInit() app.Server.Handler = app.Mux return app } func (a *DNS) Run() { a.Server.ListenAndServe() } func (a *DNS) ConfigInit(filePath string) { dir, file := filepath.Split(filePath) fe := filepath.Ext(file) ext, _ := strings.CutPrefix(fe, ".") name, _ := strings.CutSuffix(file, fe) a.Config.Viper = viper.New() a.Config.AddConfigPath(dir) a.Config.SetConfigName(name) a.Config.SetConfigType(ext) a.Config.ReadInConfig() } func (a *DNS) ServerInit() { port := a.Config.GetInt("port") a.Server.Server = &dns.Server{ Addr: ":" + strconv.Itoa(port), Net: a.Config.GetString("network"), Handler: nil, } } func (a *DNS) MuxInit() { a.Mux.ServeMux = dns.NewServeMux() } func NewResolver(DNSserverAddr string) *Resolver { return &Resolver{ &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { d := net.Dialer{Timeout: 5 * time.Second} return d.DialContext(ctx, "udp", DNSserverAddr) }, }, } } func (a *DNS) Handler(f func(a *DNS, w dns.ResponseWriter, r *dns.Msg)) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { f(a, w, r) } } func (r *Resolver) Lookup(lookupAddr string, DNSserverAddr string) string { var resp []string var err error ctx := context.Background() resp, err = r.LookupHost(ctx, lookupAddr) if err != nil { fmt.Println(err.Error()) } return resp[0] } func (a *DNS) HandleTypeA(w dns.ResponseWriter, r *dns.Msg) { t := time.Now() msg := &dns.Msg{} msg.SetReply(r) q := r.Question[0] domainName := helper.FtoD(q.Name) ip := "" if ipValue, ok := a.Records.TypeA[domainName]; ok { ip = ipValue } else { dSlices := strings.Split(domainName, ".") if len(dSlices) > 2 { name := dSlices[len(dSlices)-2] tld := dSlices[len(dSlices)-1] cname := name + "." + tld wildCard := "*." + cname if ipValue, ok := a.Records.TypeA[wildCard]; ok { ip = ipValue // domainName = cname } } } RR_Header := dns.RR_Header{ Name: dns.Fqdn(domainName), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600, } answer_typeA := &dns.A{ Hdr: RR_Header, A: net.ParseIP(ip).To4(), } msg.Authoritative = true msg.RecursionDesired = false msg.SetRcode(r, dns.RcodeSuccess) msg.Answer = append(msg.Answer, answer_typeA) tt := time.Since(t) helper.P(tt, " => ", domainName) w.WriteMsg(msg) }