test/app/dns/dns.go

185 lines
3.8 KiB
Go

package dns
import (
"context"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"time"
miekgDNS "github.com/miekg/dns"
"github.com/spf13/viper"
"github.com/zeevdiukman/test/app/helper"
)
// Aliases for miekg/dns types, exposed under this package's name.
type (
	// ResponseWriter is an alias of miekgDNS.ResponseWriter.
	ResponseWriter = miekgDNS.ResponseWriter
	// Msg is an alias of miekgDNS.Msg.
	Msg = miekgDNS.Msg
	// HandlerFunc is an alias of miekgDNS.HandlerFunc, whose underlying type
	// is func(miekgDNS.ResponseWriter, *miekgDNS.Msg).
	HandlerFunc = miekgDNS.HandlerFunc
)
// DNS bundles everything the server needs: the loaded configuration, the
// miekg server and its request multiplexer, the locally served records, and
// the resolver used for names without a local record. Build one with New.
type DNS struct {
	Config   Config   // configuration loaded by ConfigInit
	Server   Server   // miekg server built by ServerInit
	Mux      Mux      // request multiplexer built by MuxInit
	Records  Records  // locally served records
	Resolver Resolver // resolver for names with no local record
}
type (
	// Config wraps the viper instance that holds the loaded configuration
	// (populated by ConfigInit).
	Config struct {
		*viper.Viper
	}

	// Server wraps the miekg DNS server built by ServerInit.
	Server struct {
		*miekgDNS.Server
	}

	// Mux wraps the miekg request multiplexer built by MuxInit.
	Mux struct {
		*miekgDNS.ServeMux
	}

	// Records holds the records this server answers from local data.
	Records struct {
		// TypeA maps a domain name, or a "*."-prefixed wildcard pattern, to an
		// IPv4 address string. New fills it from the "records.type_a" key.
		TypeA map[string]string
	}

	// Resolver wraps the net.Resolver used for names that have no local
	// record (see NewResolver).
	Resolver struct {
		*net.Resolver
	}
)
// New builds a ready-to-run DNS server from the config file at filePath.
//
// It loads the configuration, reads the local A records from
// "records.type_a", creates the server and multiplexer, and points the
// fallback resolver at "alternate_resolver.ip" / "alternate_resolver.port".
func New(filePath string) *DNS {
	app := &DNS{}
	app.ConfigInit(filePath)
	app.Records.TypeA = app.Config.GetStringMapString("records.type_a")

	// JoinHostPort brackets IPv6 literals ("[::1]:53"); plain concatenation
	// would yield "::1:53", which cannot be dialed.
	resolverAddr := net.JoinHostPort(
		app.Config.GetString("alternate_resolver.ip"),
		strconv.Itoa(app.Config.GetInt("alternate_resolver.port")),
	)

	app.ServerInit()
	app.MuxInit()
	app.Resolver = NewResolver(resolverAddr)
	// ServerInit leaves the handler nil; route every request through the mux.
	app.Server.Handler = app.Mux
	return app
}
// Run starts the DNS server via ListenAndServe. An error from it is printed
// to stdout rather than returned to the caller.
func (a *DNS) Run() {
	if serveErr := a.Server.ListenAndServe(); serveErr != nil {
		fmt.Println(serveErr.Error())
	}
}
// ConfigInit creates a fresh viper instance in a.Config and loads the config
// file at filePath into it. The format is taken from the file extension
// (e.g. "yaml" for "dns.yaml").
//
// The method has no error return, so a read failure is printed to stdout.
// The config then stays empty, and subsequent Get* calls yield zero values.
func (a *DNS) ConfigInit(filePath string) {
	a.Config.Viper = viper.New()
	// Load exactly filePath. A directory search by base name can pick up a
	// sibling file with another extension, or find nothing when filePath has
	// no directory component.
	a.Config.SetConfigFile(filePath)
	a.Config.SetConfigType(strings.TrimPrefix(filepath.Ext(filePath), "."))
	if err := a.Config.ReadInConfig(); err != nil {
		fmt.Println(err.Error())
	}
}
// ServerInit builds the underlying miekg server from config. It binds to all
// interfaces on the configured "port" and uses the configured "network" as the
// server's Net setting. The handler is left nil here; New assigns the mux
// afterwards.
//
// NOTE(review): the "started" log line is emitted at construction time; the
// server is not listening until Run calls ListenAndServe.
func (a *DNS) ServerInit() {
	listenPort := a.Config.GetInt("port")
	transport := a.Config.GetString("network")

	srv := new(miekgDNS.Server)
	srv.Addr = ":" + strconv.Itoa(listenPort)
	srv.Net = transport
	a.Server.Server = srv

	helper.P("DNS server started at port ", listenPort)
}
// MuxInit creates the request multiplexer and registers HandleTypeA for the
// root pattern ".".
func (a *DNS) MuxInit() {
	mux := miekgDNS.NewServeMux()
	a.Mux.ServeMux = mux
	mux.HandleFunc(".", a.HandleTypeA)
}
// NewResolver returns a Resolver backed by Go's built-in resolver that sends
// every query to DNSserverAddr (a "host:port" string) instead of the
// system-configured nameservers. Each dial times out after 5 seconds.
//
// The Dial hook keeps the transport the resolver asks for. Go's resolver
// queries over "udp" and retries over "tcp" when a reply is truncated, so
// forcing "udp" on every dial would make that retry fail.
func NewResolver(DNSserverAddr string) Resolver {
	// net.Dialer is safe for concurrent use, so one instance serves all dials.
	dialer := &net.Dialer{Timeout: 5 * time.Second}
	return Resolver{
		&net.Resolver{
			PreferGo: true,
			Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, network, DNSserverAddr)
			},
		},
	}
}
// Handler adapts f, a handler that also needs the *DNS instance, into a plain
// miekgDNS.HandlerFunc. The returned function closes over the receiver a and
// passes it to f on every call.
func (a *DNS) Handler(f func(a *DNS, w miekgDNS.ResponseWriter, r *miekgDNS.Msg)) miekgDNS.HandlerFunc {
	bound := func(rw miekgDNS.ResponseWriter, req *miekgDNS.Msg) {
		f(a, rw, req)
	}
	return bound
}
// Lookup resolves lookupAddr through r and returns a single address string.
//
// An IPv4 address is preferred, because the handler in this file builds A
// records from the result. If only addresses of other families come back, the
// first one is returned. On a lookup error the error is printed and "" is
// returned; this avoids indexing into an empty result.
func (r *Resolver) Lookup(lookupAddr string) string {
	addrs, err := r.LookupHost(context.Background(), lookupAddr)
	if err != nil {
		fmt.Println(err.Error())
		return ""
	}
	for _, addr := range addrs {
		// To4 is nil-safe, so an unparsable entry is simply skipped.
		if net.ParseIP(addr).To4() != nil {
			return addr
		}
	}
	if len(addrs) == 0 {
		return ""
	}
	return addrs[0]
}
// HandleTypeA answers a DNS query; MuxInit registers it for the root pattern.
//
// For an A (or ANY) question the address comes, in order, from:
//  1. an exact entry in a.Records.TypeA,
//  2. the most specific "*.<suffix>" entry, where <suffix> keeps at least the
//     last two labels ("a.b.example.com" tries "*.b.example.com", then
//     "*.example.com"),
//  3. the alternate resolver (a.Resolver.Lookup).
//
// Replies:
//   - request with no question: FORMERR (the request is untrusted input and
//     must not be indexed blindly);
//   - question type other than A/ANY: empty NOERROR reply, since only A data
//     is served here;
//   - no usable IPv4 address from any source: SERVFAIL, instead of an A
//     record with empty RDATA;
//   - otherwise a single A record with a TTL of 3600.
//
// The queried name is lower-cased before matching because DNS names are
// case-insensitive and resolvers may vary letter case in queries.
func (a *DNS) HandleTypeA(w miekgDNS.ResponseWriter, r *miekgDNS.Msg) {
	started := time.Now()
	msg := &miekgDNS.Msg{}
	msg.SetReply(r)
	msg.Authoritative = true

	if len(r.Question) == 0 {
		msg.Rcode = miekgDNS.RcodeFormatError
		w.WriteMsg(msg)
		return
	}
	q := r.Question[0]
	domainName := strings.ToLower(helper.FtoD(q.Name))

	if q.Qtype == miekgDNS.TypeA || q.Qtype == miekgDNS.TypeANY {
		if ip := net.ParseIP(a.resolveTypeA(domainName)).To4(); ip != nil {
			msg.Answer = append(msg.Answer, &miekgDNS.A{
				Hdr: miekgDNS.RR_Header{
					// Echo the owner name exactly as asked (original letter case).
					Name:   miekgDNS.Fqdn(q.Name),
					Rrtype: miekgDNS.TypeA,
					Class:  miekgDNS.ClassINET,
					Ttl:    3600,
				},
				A: ip,
			})
		} else {
			msg.Rcode = miekgDNS.RcodeServerFailure
		}
	}

	helper.P(time.Since(started), " => ", domainName)
	w.WriteMsg(msg)
}

// resolveTypeA returns the address string for domainName: an exact local
// record, else the closest wildcard record, else whatever a.Resolver.Lookup
// returns for the name.
func (a *DNS) resolveTypeA(domainName string) string {
	if ip, ok := a.Records.TypeA[domainName]; ok {
		helper.P("FOUND => ", domainName)
		return ip
	}
	if ip, ok := a.wildcardTypeA(domainName); ok {
		return ip
	}
	return a.Resolver.Lookup(domainName)
}

// wildcardTypeA looks up "*.<suffix>" records for domainName, from the longest
// proper suffix down to the last two labels, and returns the first hit.
func (a *DNS) wildcardTypeA(domainName string) (string, bool) {
	labels := strings.Split(domainName, ".")
	// Start at 1 so the name never matches itself; stop while two labels
	// remain so no "*.com"-style catch-all is consulted.
	for i := 1; len(labels)-i >= 2; i++ {
		if ip, ok := a.Records.TypeA["*."+strings.Join(labels[i:], ".")]; ok {
			return ip, true
		}
	}
	return "", false
}