glapi/main.go

377 lines
9.3 KiB
Go

// Go Ldap Api
// Copyright (c) 2022 yo000 <johan@nosd.in>
//
package main
import (
"os"
"fmt"
"flag"
"time"
"strings"
"net/http"
"encoding/json"
"github.com/spf13/viper"
"github.com/gin-gonic/gin"
"github.com/go-ldap/ldap/v3"
log "github.com/sirupsen/logrus"
)
var (
// gVersion is the program version string; main logs it at startup.
gVersion = "0.5.1"
)
// marshalResultToText renders an LDAP search result as plain text, one value
// per line.
//
// Parameters:
//   - res: the search result to render; a nil result yields "".
//   - delimiter: separator written between a name and its value
//     (e.g. "=" or ": "). Unused when showValueName is true.
//   - showValueName: despite its name, when true only the bare values are
//     written (no attribute name, no "dn" label); when false each line is
//     "<name><delimiter><value>".
//   - showDN: when true each entry starts with its DN line and is followed by
//     an empty line; when false entries are written back to back.
func marshalResultToText(res *ldap.SearchResult, delimiter string, showValueName, showDN bool) string {
	if res == nil {
		return ""
	}

	// A Builder avoids re-copying the whole accumulated output on every line.
	var b strings.Builder
	for _, e := range res.Entries {
		if showDN {
			if showValueName {
				fmt.Fprintf(&b, "%s\n", e.DN)
			} else {
				fmt.Fprintf(&b, "dn%s%s\n", delimiter, e.DN)
			}
		}
		for _, a := range e.Attributes {
			for _, v := range a.Values {
				if showValueName {
					fmt.Fprintf(&b, "%s\n", v)
				} else {
					fmt.Fprintf(&b, "%s%s%s\n", a.Name, delimiter, v)
				}
			}
		}
		// No DN = No linefeed between entries
		if showDN {
			b.WriteByte('\n')
		}
	}
	return b.String()
}
// sendResponse writes the search result res to the client in the requested
// format (compared case-insensitively):
//
//   - "json":           JSON array of the entries
//   - "text":           "name=value" lines, DN included
//   - "ldif":           "name: value" lines, DN included
//   - "textvalue":      bare values, DN included
//   - "textvalue-nodn": bare values, no DN
//
// Status codes: 404 when the result has no entries, 400 when format is not
// one of the above, 500 when the entries cannot be marshalled to JSON,
// 200 otherwise.
func sendResponse(c *gin.Context, res *ldap.SearchResult, format string) {
	asJSON := strings.EqualFold(format, "json")

	// 404 Not found
	if res == nil || len(res.Entries) == 0 {
		if asJSON {
			c.JSON(http.StatusNotFound, gin.H{"error": "No result"})
		} else {
			c.String(http.StatusNotFound, "No result")
		}
		return
	}

	log.Debugf("Got %d results", len(res.Entries))

	if asJSON {
		jsonRes, err := json.Marshal(res.Entries)
		if err != nil {
			// Do not answer 200 with an empty body when encoding failed.
			log.Errorf("Error marshalling result to json: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Error marshalling result"})
			return
		}
		log.Debugf("%v\n", string(jsonRes))
		// Label the payload as JSON, consistent with the 404 branch above.
		c.Data(http.StatusOK, "application/json; charset=utf-8", jsonRes)
		return
	}

	var txtRes string
	switch {
	case strings.EqualFold(format, "text"):
		txtRes = marshalResultToText(res, "=", false, true)
	case strings.EqualFold(format, "ldif"):
		txtRes = marshalResultToText(res, ": ", false, true)
	case strings.EqualFold(format, "textvalue"):
		txtRes = marshalResultToText(res, "", true, true)
	case strings.EqualFold(format, "textvalue-nodn"):
		txtRes = marshalResultToText(res, "", true, false)
	default:
		// Previously nothing was written, producing an empty 200.
		c.String(http.StatusBadRequest, "Unknown format %q", format)
		return
	}

	log.Debugf("%v\n", txtRes)
	c.String(http.StatusOK, txtRes)
}
// checkIfModifiedSince reports whether the objects matching cn/class under
// baseDn should be considered modified with respect to the request's
// If-Modified-Since header.
//
// The comparison is currently disabled (see FIXME below), so the function
// always returns (true, nil). The logic kept behind the switch works as
// follows once enabled:
//   - no header, or a header that cannot be parsed: modified
//   - LDAP search error: (true, err)
//   - any entry whose modifyTimestamp is later than the header, or cannot be
//     parsed: modified
//   - otherwise: not modified
//
// The attributes parameter is accepted for call-site symmetry but is not
// used; only modifyTimestamp is requested from LDAP.
func checkIfModifiedSince(c *gin.Context, myldap *MyLdap, baseDn, cn, class, attributes string) (bool, error) {
	// FIXME: We need to cache the last result, because if an item is deleted from LDAP we won't see it and
	// we will return 304. So deletions will never make their way to Rspamd
	// For now, lets always return "Modified"
	const conditionalRequestsEnabled = false
	if !conditionalRequestsEnabled {
		return true, nil
	}

	header := c.GetHeader("If-Modified-Since")
	if header == "" {
		return true, nil
	}
	ifModifiedSince, err := http.ParseTime(header)
	if err != nil {
		// An unusable validator must not turn into a 304.
		log.Debugf("Ignoring unparsable If-Modified-Since %q: %v", header, err)
		return true, nil
	}
	log.Debugf("ifModifiedSince: %s", ifModifiedSince)

	res, err := doLdapSearch(myldap, baseDn, cn, class, "modifyTimestamp")
	if err != nil {
		log.Errorf("Error searching modifyTimestamp for %s in %s : %v", cn, baseDn, err)
		return true, err
	}

	// modifyTimestamp format
	const mtFmt = "20060102150405Z"

	// Compare each object timestamp
	for _, e := range res.Entries {
		for _, a := range e.Attributes {
			if !strings.EqualFold(a.Name, "modifyTimestamp") || len(a.Values) == 0 {
				continue
			}
			mt, err := time.Parse(mtFmt, a.Values[0])
			if err != nil {
				// Unknown age: err on the side of sending fresh data.
				log.Warnf("%s has unparsable modifyTimestamp %q: %v", e.DN, a.Values[0], err)
				return true, nil
			}
			log.Debugf("%s modifyTimestamp: %s", e.DN, mt)
			if mt.After(ifModifiedSince) {
				log.Debugf("%s is newer than %s: %s", e.DN, ifModifiedSince, mt)
				return true, nil
			}
		}
	}
	return false, nil
}
// Basic Authentication handler
//
// basicAuth is a gin middleware that lets the request through only when it
// carries HTTP Basic credentials matching the built-in account. Otherwise the
// chain is aborted with 401 and a WWW-Authenticate challenge.
//
// TODO: Where to store accounts? The credentials are hard-coded for now.
func basicAuth(c *gin.Context) {
	user, password, hasAuth := c.Request.BasicAuth()
	if !hasAuth || user != "admin" || password != "admin" {
		// The challenge header has to be set before AbortWithStatus:
		// that call writes the response headers immediately, so anything
		// set afterwards never reaches the client.
		c.Writer.Header().Set("WWW-Authenticate", "Basic realm=Restricted")
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}
	log.Infof("[%s]: User %s successfully authenticated", c.Request.RemoteAddr, user)
}
// initRouter registers every HTTP route on r. myldap is the LDAP connection
// description handed to each search.
//
// Routes:
//   - GET  /ping                          unauthenticated liveness check
//   - GET  /:ou/:cn/:class                search, all attributes
//   - HEAD /:ou/:cn/:class                same, gated by checkIfModifiedSince
//   - GET  /:ou/:cn/:class/:attribute     search, one attribute
//   - HEAD /:ou/:cn/:class/:attribute     same, gated by checkIfModifiedSince
//
// All routes but /ping go through basicAuth. The output format is taken from
// the "format" query parameter and defaults to json.
func initRouter(r *gin.Engine, myldap *MyLdap) {
	r.GET("/ping", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"message": "pong",
		})
	})

	// How a route decides which attributes to request from LDAP.
	allAttributes := func(*gin.Context) string { return "ALL" }
	pathAttribute := func(c *gin.Context) string { return c.Param("attribute") }

	// search builds the handler shared by the four LDAP routes. When
	// conditional is true the search only runs if checkIfModifiedSince
	// reports a change; otherwise 304 is returned.
	search := func(attributesOf func(*gin.Context) string, conditional bool) gin.HandlerFunc {
		return func(c *gin.Context) {
			ou := c.Param("ou")
			cn := c.Param("cn")
			class := c.Param("class")
			attrs := attributesOf(c)
			// json format is the default
			format := c.DefaultQuery("format", "json")
			log.Printf("Format : %s", format)

			if conditional {
				modified, err := checkIfModifiedSince(c, myldap, ou, cn, class, attrs)
				if err != nil {
					c.AbortWithError(http.StatusInternalServerError, err)
					return
				}
				if !modified {
					c.String(http.StatusNotModified, "")
					return
				}
			}

			res, err := doLdapSearch(myldap, ou, cn, class, attrs)
			// If OU does not exist, we'll get err='LDAP Result Code 32 "No Such Object"'
			if err != nil {
				log.Errorf("Error searching %s in %s : %v", cn, ou, err)
				c.AbortWithError(http.StatusInternalServerError, err)
				return
			}
			sendResponse(c, res, format)
		}
	}

	// All following routes need authentication
	r.GET("/:ou/:cn/:class", basicAuth, search(allAttributes, false))
	r.HEAD("/:ou/:cn/:class", basicAuth, search(allAttributes, true))
	r.GET("/:ou/:cn/:class/:attribute", basicAuth, search(pathAttribute, false))
	r.HEAD("/:ou/:cn/:class/:attribute", basicAuth, search(pathAttribute, true))
}
// main parses the command line, optionally loads a config file, validates
// the resulting settings and starts the HTTP(S) server.
//
// Precedence: a command-line flag wins; when it is left empty the matching
// config-file key is used (LISTEN, LDAP_HOST, LDAP_USER, LDAP_PASS,
// LDAP_BASE_DN, HTTPS, SSL_CERTIFICATE, SSL_PRIVATE_KEY). The LDAP settings
// are mandatory, and so are the certificate and private key once HTTPS is
// enabled; a missing one terminates the program. The process exits with
// status 1 if the server cannot start or stops with an error.
func main() {
	const defaultListen = "0.0.0.0:8080"

	var (
		confFile   string
		listen     string
		ldapHost   string
		ldapUser   string
		ldapPass   string
		ldapBaseDN string
		tlsPrivKey string
		tlsCert    string
		doTls      bool
		debug      bool
	)

	flag.StringVar(&confFile, "config", "", "Path to the config file (optional)")
	flag.StringVar(&listen, "listen-addr", defaultListen, "listen address for server")
	flag.StringVar(&ldapHost, "ldap-host", "", "ldap host to connect to")
	flag.StringVar(&ldapUser, "ldap-user", "", "ldap username")
	flag.StringVar(&ldapPass, "ldap-pass", "", "ldap password")
	flag.StringVar(&ldapBaseDN, "ldap-base-dn", "", "ldap base DN")
	flag.BoolVar(&doTls, "https", false, "Serve over TLS")
	flag.StringVar(&tlsPrivKey, "ssl-private-key", "", "SSL Private key")
	flag.StringVar(&tlsCert, "ssl-certificate", "", "SSL certificate (PEM format)")
	flag.BoolVar(&debug, "debug", false, "Set log level to debug")
	flag.Parse()

	if len(confFile) > 0 {
		viper.SetConfigFile(confFile)
		if err := viper.ReadInConfig(); err != nil {
			// Fatalf exits the process itself.
			log.Fatalf("Could not open config file: %v", err)
		}
	}

	// required returns the flag value when set, else the config value stored
	// under key; it terminates the program with missing when both are empty.
	required := func(flagValue, key, missing string) string {
		if len(flagValue) > 0 {
			return flagValue
		}
		if v := viper.GetString(key); len(v) > 0 {
			return v
		}
		log.Fatal(missing)
		return ""
	}

	// The config file may override the listen address only while the flag
	// still holds its default value.
	if strings.EqualFold(listen, defaultListen) && len(confFile) > 0 {
		if l := viper.GetString("LISTEN"); len(l) > 0 {
			listen = l
		}
	}

	ldapHost = required(ldapHost, "LDAP_HOST", "No ldap-host defined!")
	ldapUser = required(ldapUser, "LDAP_USER", "No ldap-user defined!")
	ldapPass = required(ldapPass, "LDAP_PASS", "No ldap-pass defined!")
	ldapBaseDN = required(ldapBaseDN, "LDAP_BASE_DN", "No ldap-base-dn defined!")

	if !doTls {
		doTls = viper.GetBool("HTTPS")
	}
	if doTls {
		tlsCert = required(tlsCert, "SSL_CERTIFICATE", "SSL certificate must be set to use https!")
		tlsPrivKey = required(tlsPrivKey, "SSL_PRIVATE_KEY", "SSL private key must be set to use https!")
	}

	log.Println("Starting Go Ldap API v.", gVersion)
	if debug {
		log.SetLevel(log.DebugLevel)
	}

	r := gin.Default()

	// Named conn rather than ldap so the imported ldap package is not shadowed.
	conn := MyLdap{Host: ldapHost, User: ldapUser, Pass: ldapPass, BaseDN: ldapBaseDN}
	initRouter(r, &conn)

	var err error
	if doTls {
		err = r.RunTLS(listen, tlsCert, tlsPrivKey)
	} else {
		err = r.Run(listen)
	}
	// Run/RunTLS only return on failure (e.g. address in use, bad certificate).
	if err != nil {
		log.Errorf("Server stopped: %v", err)
		os.Exit(1)
	}
}