426 lines
12 KiB
Go
426 lines
12 KiB
Go
// MyNetTCPTable is a Postfix tcp_table(5) service that checks whether a client
// IP is listed in LDAP, either as an ipHostNumber or within an ipNetworkNumber
// network.
|
|
//
|
|
// Copyright (c) 2022 yo000 <johan@nosd.in>
|
|
//
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
// "log"
|
|
"errors"
|
|
"flag"
|
|
"log/syslog"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-ldap/ldap/v3"
|
|
"github.com/peterbourgon/ff"
|
|
"github.com/sirupsen/logrus"
|
|
lSyslog "github.com/sirupsen/logrus/hooks/syslog"
|
|
"github.com/tabalt/pidfile"
|
|
)
|
|
|
|
// version is the release number printed in the usage banner and at startup.
const version = "1.0.2"
|
|
|
|
var (
|
|
logstream *logrus.Logger
|
|
conLdap *ldap.Conn
|
|
mutex sync.Mutex
|
|
logTo *string
|
|
logLevel *string
|
|
listen *string
|
|
ldapURL *string
|
|
ldapBaseDN *string
|
|
ldapUser *string
|
|
ldapPass *string
|
|
refreshInterval *int
|
|
pidFilePath *string
|
|
timeout *int
|
|
|
|
netCache []NoAuthNet
|
|
)
|
|
|
|
// NoAuthNet is one network loaded from an LDAP ipNetworkNumber entry.
// Present is the refresh flag: it is cleared before a cache rebuild, set again
// for every network still returned by LDAP, and entries left with
// Present == false are removed afterwards.
type NoAuthNet struct {
	Net     *net.IPNet // network whose addresses are accepted
	Present bool       // set while refreshing netCache if LDAP still lists Net
}
|
|
|
|
// Test if a net is present in cache, and set "Present" flag to true
|
|
// Flag is necessary when reviewing cache content
|
|
func checkNetCacheContainsAndFlag(ipnet *net.IPNet) bool {
|
|
for _, n := range netCache {
|
|
if strings.EqualFold(ipnet.String(), n.Net.String()) {
|
|
n.Present = true
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Remove all items with Present = false
|
|
func clearNetCache() {
|
|
// Manually manage loop index so we won't pass over the item following a deleted one
|
|
i := 0
|
|
for {
|
|
if i >= len(netCache) {
|
|
break
|
|
}
|
|
n := netCache[i]
|
|
if n.Present == false {
|
|
logstream.Debugf("Delete %s from netCache\n", n.Net.String())
|
|
// Nil the pointer to avoid memory leak
|
|
n.Net = nil
|
|
if i+1 < len(netCache) {
|
|
netCache = append(netCache[:i], netCache[i+1:]...)
|
|
} else {
|
|
netCache = netCache[:i]
|
|
}
|
|
i--
|
|
}
|
|
i++
|
|
}
|
|
}
|
|
|
|
// Flag all items as "Present = false"
|
|
func unsetNetCachePresentFlag() {
|
|
for _, n := range netCache {
|
|
n.Present = false
|
|
}
|
|
}
|
|
|
|
func buildNetCacheFromIPNetwork(conLdap *ldap.Conn) error {
|
|
attribute := "ipNetworkNumber"
|
|
|
|
filter := "(objectClass=ipNetwork)"
|
|
searchReq := ldap.NewSearchRequest(*ldapBaseDN, ldap.ScopeWholeSubtree, 0, 0, 0,
|
|
false, filter, []string{attribute}, []ldap.Control{})
|
|
result, err := searchLdap(searchReq, 0)
|
|
if err != nil {
|
|
logstream.Errorf("Error searching %s into LDAP: %v\n", filter, err)
|
|
return err
|
|
}
|
|
logstream.Debugf("Received %d results to ipNetwork query\n", len(result.Entries))
|
|
|
|
// First flag off all elements of netCache
|
|
unsetNetCachePresentFlag()
|
|
|
|
for _, r := range result.Entries {
|
|
if len(r.Attributes) == 0 {
|
|
logstream.Info(fmt.Sprintf("Error searching into LDAP: Attribute %s not found for entry %s\n", attribute, r))
|
|
continue
|
|
} else {
|
|
_, ipnet, err := net.ParseCIDR(r.Attributes[0].Values[0])
|
|
if err != nil {
|
|
logstream.Info(err.Error())
|
|
continue
|
|
}
|
|
if false == checkNetCacheContainsAndFlag(ipnet) {
|
|
netCache = append(netCache, NoAuthNet{Net: ipnet, Present: true})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Finally delete items not previously accessed
|
|
clearNetCache()
|
|
|
|
logstream.Debug("Dump netcache:")
|
|
for _, n := range netCache {
|
|
logstream.Debugf("%s\n", n.Net.String())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func isIPContainedInNetCache(string_ip string) (bool, error) {
|
|
ip := net.ParseIP(string_ip)
|
|
if ip == nil {
|
|
return false, errors.New(fmt.Sprintf("Invalid IP: %s", string_ip))
|
|
}
|
|
|
|
for _, n := range netCache {
|
|
if n.Net.Contains(ip) {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// Handle incoming request
|
|
func handleConnection(connClt net.Conn, conLdap *ldap.Conn) {
|
|
buf := make([]byte, 1024)
|
|
|
|
// TODO : Maybe keep it open and process following requests? See "warning: read TCP map reply from srv-ldap:8080: unexpected EOF (Application error)"
|
|
// Close client connection when this function ends
|
|
defer func() {
|
|
logstream.Debug("Closing connection")
|
|
connClt.Close()
|
|
}()
|
|
|
|
timeoutDuration := time.Duration(*timeout) * time.Second
|
|
|
|
for {
|
|
// Set a deadline for reading. Read operation will fail if no data
|
|
// is received after deadline.
|
|
connClt.SetReadDeadline(time.Now().Add(timeoutDuration))
|
|
|
|
readlen, err := connClt.Read(buf)
|
|
if err != nil {
|
|
// 10/05/2022 : Drop this conn if client closed connection or timeout occured
|
|
// Dont notice if client closed connection
|
|
if err.Error() != "EOF" && !strings.HasSuffix(err.Error(), "i/o timeout") {
|
|
logstream.Errorf("Error reading connection: %v\n", err.Error())
|
|
}
|
|
return
|
|
}
|
|
|
|
logstream.Debug(fmt.Sprintf("Received: %s\n", string(buf[:readlen-1])))
|
|
|
|
if strings.EqualFold(string(buf[:readlen-1]), "quit") {
|
|
logstream.Infof("Received \"quit\" instruction from %s, closing connection to the client", connClt.RemoteAddr().String())
|
|
connClt.Close()
|
|
return
|
|
}
|
|
|
|
// "set refresh" sent on listening port will refresh netCache from LDAP
|
|
if strings.EqualFold(string(buf[:readlen-1]), "set refresh") {
|
|
logstream.Infof("Received \"set refresh\" instruction from %s, refreshing netCache", connClt.RemoteAddr().String())
|
|
buildNetCacheFromIPNetwork(conLdap)
|
|
sendResponse(connClt, "Refreshing cache", 200)
|
|
continue
|
|
}
|
|
|
|
// "get loglevel" sent on listening port will return current loglevel
|
|
if strings.EqualFold(string(buf[:readlen-1]), "get loglevel") {
|
|
logstream.Infof("Received \"get loglevel\" instruction from %s", connClt.RemoteAddr().String())
|
|
sendResponse(connClt, logstream.Level.String(), 200)
|
|
continue
|
|
}
|
|
|
|
// "set loglevel level" sent on listening port will set current loglevel
|
|
if readlen > 14 && strings.EqualFold(string(buf[:12]), "set loglevel") {
|
|
logstream.Infof("Received \"%s\" instruction from %s", string(buf[:readlen-1]), connClt.RemoteAddr().String())
|
|
level, err := logrus.ParseLevel(string(buf[13 : readlen-1]))
|
|
if err != nil {
|
|
sendResponse(connClt, fmt.Sprintf("Invalid log level: %s", string(buf[13:readlen-1])), 500)
|
|
} else {
|
|
logstream.Level = level
|
|
sendResponse(connClt, "loglevel set", 200)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Get IP sent by postfix
|
|
// tcp_table request is formated like "get the_ip\n"
|
|
if false == strings.HasPrefix(string(buf[:readlen-1]), "get ") {
|
|
sendResponse(connClt, fmt.Sprintf("Invalid request: %s", buf[:readlen-1]), 500)
|
|
continue
|
|
}
|
|
ip := string(buf[4 : readlen-1])
|
|
|
|
// First query netCache built with ipNetworkNumber
|
|
res, err := isIPContainedInNetCache(ip)
|
|
if err != nil {
|
|
if strings.EqualFold(err.Error(), fmt.Sprintf("Invalid IP: %s", ip)) {
|
|
// We don't want those msg to pollute logs
|
|
logstream.Info(err.Error())
|
|
} else {
|
|
logstream.Error(err.Error())
|
|
}
|
|
sendResponse(connClt, err.Error(), 500)
|
|
continue
|
|
}
|
|
|
|
// IP is allowed, return the IP with code 200
|
|
if res == true {
|
|
sendResponse(connClt, ip, 200)
|
|
continue
|
|
}
|
|
|
|
// Then, if no result, query LDAP for exact IP
|
|
filter := fmt.Sprintf("(ipHostNumber=%s)", ldap.EscapeFilter(ip))
|
|
searchReq := ldap.NewSearchRequest(*ldapBaseDN, ldap.ScopeWholeSubtree, 0, 0, 0,
|
|
false, filter, []string{"ipHostNumber"}, []ldap.Control{})
|
|
result, err := searchLdap(searchReq, 0)
|
|
if err != nil {
|
|
logstream.Errorf("Error searching into LDAP: %v\n", err)
|
|
sendResponse(connClt, err.Error(), 500)
|
|
continue
|
|
}
|
|
if len(result.Entries) < 1 {
|
|
s := "IP not authorized"
|
|
sendResponse(connClt, s, 500)
|
|
continue
|
|
}
|
|
if len(result.Entries) > 1 {
|
|
logstream.Infof("More than one match for IP %s", ip)
|
|
}
|
|
|
|
sendResponse(connClt, ip, 200)
|
|
|
|
}
|
|
}
|
|
|
|
// sendResponse writes one tcp_table reply line, "<respCode> <respMsg>\n", to
// con, with respMsg %XX-encoded so it can never break the line framing.
// It returns the error from the write, if any.
func sendResponse(con net.Conn, respMsg string, respCode int) error {
	response := fmt.Sprintf("%d %s\n", respCode, encodeTCPTableText(respMsg))
	_, err := con.Write([]byte(response))
	return err
}

// encodeTCPTableText applies tcp_table(5) %XX encoding to s: '%', space and
// other ASCII control characters, DEL and every non-ASCII byte become %XX
// (upper-case hex); all other bytes are copied unchanged.
func encodeTCPTableText(s string) string {
	const hexDigits = "0123456789ABCDEF"
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c == '%' || c <= ' ' || c >= 0x7f {
			b.WriteByte('%')
			b.WriteByte(hexDigits[c>>4])
			b.WriteByte(hexDigits[c&0x0f])
			continue
		}
		b.WriteByte(c)
	}
	return b.String()
}
|
|
|
|
func searchLdap(searchReq *ldap.SearchRequest, attempt int) (*ldap.SearchResult, error) {
|
|
mutex.Lock()
|
|
result, err := conLdap.Search(searchReq)
|
|
mutex.Unlock()
|
|
// Let's just manage connection errors here
|
|
if (err != nil && (strings.HasSuffix(err.Error(), "ldap: connection closed")|| strings.HasSuffix(err.Error(), "ldap: conn is nil, expected net.Conn"))) {
|
|
logstream.Error("LDAP connection closed, retrying")
|
|
mutex.Lock()
|
|
// 16/01/2023: panic: runtime error: invalid memory address or nil pointer dereference
|
|
// probably bc connection is already closed
|
|
if conLdap != nil {
|
|
conLdap.Close()
|
|
}
|
|
conLdap, err = connectLdap()
|
|
mutex.Unlock()
|
|
if err != nil {
|
|
return result, err
|
|
} else {
|
|
attempt = attempt + 1
|
|
return searchLdap(searchReq, attempt)
|
|
}
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
func connectLdap() (*ldap.Conn, error) {
|
|
var err error
|
|
conLdap, err = ldap.DialURL(*ldapURL)
|
|
if err != nil {
|
|
logstream.Errorf("Error dialing LDAP on %s: %v\n", *ldapURL, err)
|
|
return nil, err
|
|
}
|
|
err = conLdap.Bind(*ldapUser, *ldapPass)
|
|
if err != nil {
|
|
logstream.Errorf("Error binding LDAP: %s", err)
|
|
return nil, err
|
|
}
|
|
return conLdap, err
|
|
}
|
|
|
|
// TODO : buildNetCache should have its own LDAP connection
|
|
func periodicallyUpdateCache(conLdap *ldap.Conn) {
|
|
// On initialise immediatement le cache
|
|
buildNetCacheFromIPNetwork(conLdap)
|
|
|
|
for range time.Tick(time.Second * time.Duration(*refreshInterval)) {
|
|
buildNetCacheFromIPNetwork(conLdap)
|
|
}
|
|
}
|
|
|
|
func run() {
|
|
listener, err := net.Listen("tcp", *listen)
|
|
if err != nil {
|
|
logstream.Fatalf("Error listening on %s: %v\n", *listen, err)
|
|
}
|
|
conLdap, err := connectLdap()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
go periodicallyUpdateCache(conLdap)
|
|
|
|
// Spawn a go routine for incoming connection
|
|
for {
|
|
connClt, err := listener.Accept()
|
|
if err != nil {
|
|
logstream.Errorf("Error accepting: ", err)
|
|
}
|
|
go handleConnection(connClt, conLdap)
|
|
}
|
|
conLdap.Close()
|
|
}
|
|
|
|
func main() {
|
|
fs := flag.NewFlagSet("mynettcptable", flag.ExitOnError)
|
|
listen = fs.String("listen-addr", "127.0.0.1:8080", "listen address for server (also via LISTEN env var)")
|
|
logTo = fs.String("logTo", "syslog", "Where to output logs. Valid values are \"stdout\", \"syslog\"")
|
|
logLevel = fs.String("logLevel", "warn", "Log level. Valid values are \"fatal\", \"error\", \"warn\", \"info\", \"debug\"")
|
|
ldapURL = fs.String("ldap", "", "LDAP Server URL (also via LDAP env var)")
|
|
ldapBaseDN = fs.String("ldapDN", "", "LDAP base DN (also via LDAPDN env var)")
|
|
ldapUser = fs.String("ldapUser", "", "LDAP user DN (also via LDAPUSER env var)")
|
|
ldapPass = fs.String("ldapPass", "", "LDAP user password (also via LDAPPASS env var)")
|
|
pidFilePath = fs.String("pidfile", "", "PID File (also via PIDFILE env var). Creates pidfile only if defined")
|
|
refreshInterval = fs.Int("refresh", 300, "Net cache update interval in seconds")
|
|
timeout = fs.Int("timeout", 5, "timeout in seconds")
|
|
_ = fs.String("config", "", "config file (optional)")
|
|
// Surcharge de la fonction Usage()
|
|
fs.Usage = func() {
|
|
fmt.Fprintf(flag.CommandLine.Output(), "%s version %s\n", os.Args[0], version)
|
|
fmt.Fprintf(flag.CommandLine.Output(), "Usage:\n")
|
|
fmt.Fprintf(flag.CommandLine.Output(), "Each argument can be set in config file.\n")
|
|
fs.PrintDefaults()
|
|
}
|
|
ff.Parse(fs, os.Args[1:], ff.WithEnvVarNoPrefix(), ff.WithConfigFileFlag("config"), ff.WithConfigFileParser(ff.PlainParser))
|
|
if len(*ldapURL) == 0 || len(*ldapBaseDN) == 0 || len(*ldapUser) == 0 || len(*ldapPass) == 0 {
|
|
fs.Usage()
|
|
return
|
|
}
|
|
|
|
fmt.Printf("%s: MyNetTCPTable v.%s starting\n", time.Now().Format(time.RFC3339), version)
|
|
|
|
logstream = logrus.New()
|
|
level, err := logrus.ParseLevel(*logLevel)
|
|
if err != nil {
|
|
fmt.Printf("Invalid log level: %s\n", *logLevel)
|
|
os.Exit(-1)
|
|
} else {
|
|
logstream.Level = level
|
|
}
|
|
|
|
if strings.EqualFold(*logTo, "stdout") {
|
|
logstream.Out = os.Stdout
|
|
}
|
|
if strings.EqualFold(*logTo, "syslog") {
|
|
// level != priority
|
|
prio := syslog.LOG_MAIL
|
|
switch *logLevel {
|
|
case "fatal":
|
|
prio += syslog.LOG_CRIT
|
|
case "error":
|
|
prio += syslog.LOG_ERR
|
|
case "warn":
|
|
prio += syslog.LOG_WARNING
|
|
case "info":
|
|
prio += syslog.LOG_INFO
|
|
case "debug":
|
|
prio += syslog.LOG_DEBUG
|
|
}
|
|
hook, err := lSyslog.NewSyslogHook("", "", prio, "mynettcptable")
|
|
if err != nil {
|
|
fmt.Printf("Error opening syslog: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
logstream.Hooks.Add(hook)
|
|
}
|
|
|
|
if len(*pidFilePath) > 0 {
|
|
if pid, err := pidfile.Create(*pidFilePath); err != nil {
|
|
logstream.Fatal(err)
|
|
} else {
|
|
defer pid.Clear()
|
|
}
|
|
}
|
|
|
|
logstream.Infof("Start listening for incoming connections on %s\n", *listen)
|
|
run()
|
|
}
|