diff --git a/ldap.go b/ldap.go index 7933452..7952566 100644 --- a/ldap.go +++ b/ldap.go @@ -19,6 +19,7 @@ type MyLdap struct { User string Pass string BaseDN string + AuthBaseDN string } var ( @@ -101,7 +102,7 @@ func searchByCn(myldap *MyLdap, baseDn, cn, class, attributes string) (*ldap.Sea } else { filter = cn } - return doLdapSearch(myldap, baseDn, filter, class, attributes) + return doLdapSearch(myldap, baseDn, false, filter, class, attributes) } func searchByDn(myldap *MyLdap, dn, attributes string) (*ldap.SearchResult, error) { @@ -110,10 +111,10 @@ func searchByDn(myldap *MyLdap, dn, attributes string) (*ldap.SearchResult, erro rem := strings.Split(dn, ",")[1:] bdn := strings.Join(rem, ",") bdn = strings.Replace(bdn, fmt.Sprintf(",%s", myldap.BaseDN), "", 1) - return doLdapSearch(myldap, bdn, filter, "ALL", "ALL") + return doLdapSearch(myldap, bdn, false, filter, "ALL", "ALL") } -func doLdapSearch(myldap *MyLdap, baseDn, filter, class, attributes string) (*ldap.SearchResult, error) { +func doLdapSearch(myldap *MyLdap, baseDn string, baseDnIsAbsolute bool, filter, class, attributes string) (*ldap.SearchResult, error) { var fFilter string var realBaseDn string var realAttributes []string @@ -135,7 +136,11 @@ func doLdapSearch(myldap *MyLdap, baseDn, filter, class, attributes string) (*ld if strings.EqualFold(baseDn, "ALL") || len(baseDn) == 0 { realBaseDn = fmt.Sprintf("%s", myldap.BaseDN) } else { - realBaseDn = fmt.Sprintf("%s,%s", baseDn, myldap.BaseDN) + if len(baseDn) > 0 && baseDnIsAbsolute { + realBaseDn = fmt.Sprintf("%s", baseDn) + } else { + realBaseDn = fmt.Sprintf("%s,%s", baseDn, myldap.BaseDN) + } } log.Debugf("LDAP search base dn: %s", realBaseDn) @@ -162,7 +167,7 @@ func doLdapSearch(myldap *MyLdap, baseDn, filter, class, attributes string) (*ld func findUserFullDN(myldap *MyLdap, username string) (string, error) { filter := fmt.Sprintf("cn=%s", username) - sr, err := doLdapSearch(myldap, "", filter, "ALL", "") + sr, err := doLdapSearch(myldap, myldap.AuthBaseDN, true, filter, "ALL", "") if err != nil { return "", err } diff --git a/main.go b/main.go index b5027cd..0de238d 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,7 @@ import ( ) var ( - gVersion = "0.5.4" + gVersion = "0.5.5" gRoLdap *MyLdap ) @@ -95,7 +95,6 @@ func sendResponse(c *gin.Context, res *ldap.SearchResult, format string) { txtRes := marshalResultToText(res, "", true, false) log.Debugf("%v\n", string(txtRes)) c.String(http.StatusOK, string(txtRes)) - } } @@ -566,6 +565,7 @@ func main() { var ldapUser string var ldapPass string var ldapBaseDN string + var ldapAuthBaseDN string var tlsPrivKey string var tlsCert string var doTls bool @@ -577,6 +577,7 @@ func main() { flag.StringVar(&ldapUser, "ldap-user", "", "ldap read-only username") flag.StringVar(&ldapPass, "ldap-pass", "", "ldap password") flag.StringVar(&ldapBaseDN, "ldap-base-dn", "", "ldap base DN") + flag.StringVar(&ldapAuthBaseDN, "ldap-auth-base-dn", "", "ldap base DN to find authenticating users") flag.BoolVar(&doTls, "https", false, "Serve over TLS") flag.StringVar(&tlsPrivKey, "ssl-private-key", "", "SSL Private key") flag.StringVar(&tlsCert, "ssl-certificate", "", "SSL certificate (PEM format)") @@ -635,6 +636,14 @@ func main() { log.Fatal("No ldap-base-dn defined!") } } + if len(ldapAuthBaseDN) == 0 { + l := viper.GetString("LDAP_AUTH_BASE_DN") + if len(l) > 0 { + ldapAuthBaseDN = l + } else { + log.Fatal("No ldap-auth-base-dn defined!") + } + } if false == doTls { doTls = viper.GetBool("HTTPS") } @@ -662,7 +671,7 @@ func main() { r := gin.Default() - gRoLdap = &MyLdap{Host: ldapHost, User: ldapUser, Pass: ldapPass, BaseDN: ldapBaseDN} + gRoLdap = &MyLdap{Host: ldapHost, User: ldapUser, Pass: ldapPass, BaseDN: ldapBaseDN, AuthBaseDN: ldapAuthBaseDN} _, err := connectLdap(gRoLdap) if err != nil { log.Fatalf("Cannot connect to ldap: %v", err)