auth/oauth2.go (115 lines of code) (raw):
package auth
import (
"context"
"crypto/tls"
"crypto/x509"
"ddm-admin-console/service/k8s"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"golang.org/x/oauth2"
)
// OAuth2 holds the client credentials, the provider metadata fetched from the
// discovery endpoint, and the oauth2.Config derived from both.
type OAuth2 struct {
// clientID is copied into Config.ClientID by initConfig.
clientID string
// secret is copied into Config.ClientSecret by initConfig.
secret string
// discoveryURL is the value passed to InitOauth2, stored before any
// ".well-known" suffix is appended for the actual request.
discoveryURL string
// redirectURL is copied into Config.RedirectURL by initConfig.
redirectURL string
// httpClient performs the discovery request and is injected into the context
// used for the token exchange; UseInternalTokenService replaces its Transport.
httpClient *http.Client
// providerInfo is the decoded discovery document.
providerInfo providerInfo
// Config is built by initConfig and is exported for direct use by callers.
Config *oauth2.Config
}
// providerInfo lists the fields decoded from the JSON discovery document
// that InitOauth2 downloads.
type providerInfo struct {
// Issuer is decoded from the "issuer" key; it is not read elsewhere in this file.
Issuer string `json:"issuer"`
// AuthURL becomes Config.Endpoint.AuthURL.
AuthURL string `json:"authorization_endpoint"`
// TokenURL becomes Config.Endpoint.TokenURL.
TokenURL string `json:"token_endpoint"`
// ScopesSupported is used verbatim as Config.Scopes.
ScopesSupported []string `json:"scopes_supported"`
}
// InitOauth2 builds an OAuth2 helper by downloading the provider's discovery
// document and deriving an oauth2.Config from it.
//
// discoveryURL may be the full metadata URL or the provider's base URL: when
// it does not contain ".well-known", the path
// "/.well-known/oauth-authorization-server" is appended (after trimming one
// trailing slash). The OAuth2 value keeps the discoveryURL exactly as passed.
// A nil httpClient falls back to http.DefaultClient.
//
// An error is returned when the request fails, the body cannot be read, the
// provider answers with a status other than 200, or the body is not valid
// JSON for the expected document.
func InitOauth2(clientID, secret, discoveryURL, redirectURL string, httpClient *http.Client) (*OAuth2, error) {
	// Discovery documents are a few kilobytes; cap the read so a misbehaving
	// endpoint cannot make this function buffer an unbounded response.
	const maxDiscoveryBody = 1 << 20

	if httpClient == nil {
		httpClient = http.DefaultClient
	}

	oa2 := OAuth2{
		clientID:     clientID,
		secret:       secret,
		discoveryURL: discoveryURL,
		redirectURL:  redirectURL,
		httpClient:   httpClient,
	}

	if !strings.Contains(discoveryURL, ".well-known") {
		discoveryURL = strings.TrimSuffix(discoveryURL, "/") + "/.well-known/oauth-authorization-server"
	}

	rsp, err := httpClient.Get(discoveryURL)
	if err != nil {
		return nil, errors.Wrap(err, "unable to get discovery url")
	}
	defer rsp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(rsp.Body, maxDiscoveryBody))
	if err != nil {
		// Wrap instead of Errorf("%v") so the cause stays inspectable, like
		// the other error paths here; the rendered message is unchanged.
		return nil, errors.Wrap(err, "unable to read response body")
	}

	// The body is read before the status check so it can be included in the
	// error reported for a non-200 answer.
	if rsp.StatusCode != http.StatusOK {
		return nil, errors.Errorf("%s: %s", rsp.Status, string(body))
	}

	if err := json.Unmarshal(body, &oa2.providerInfo); err != nil {
		return nil, errors.Wrap(err, "unable to unmarshal discovery body")
	}

	oa2.initConfig()

	return &oa2, nil
}
// initConfig (re)creates o.Config from the stored client credentials, the
// redirect URL and the endpoints and scopes found in the discovery document.
func (o *OAuth2) initConfig() {
	info := o.providerInfo

	endpoint := oauth2.Endpoint{
		AuthURL:  info.AuthURL,
		TokenURL: info.TokenURL,
	}

	cfg := oauth2.Config{
		ClientID:     o.clientID,
		ClientSecret: o.secret,
		RedirectURL:  o.redirectURL,
		Endpoint:     endpoint,
		// Every scope the provider advertises is requested.
		Scopes: info.ScopesSupported,
	}

	o.Config = &cfg
}
// UseInternalTokenService points the token endpoint at serviceHost and makes
// the stored HTTP client trust the CA read from the "openshift-service-ca.crt"
// config map in the "openshift-config-managed" namespace.
//
// Only the host part of the token URL is replaced; scheme, path and query are
// kept. The CA is taken from the "service-ca.crt" key of the config map.
//
// An error is returned when the current token URL cannot be parsed, the config
// map cannot be fetched, the key is missing, or no certificate could be parsed
// from the PEM data. On error the receiver is left untouched: the token URL
// and the transport are only changed once every fallible step has succeeded,
// so the object never ends up addressing the internal host without the
// matching trust store.
//
// NOTE(review): the Transport is replaced in place on the client held by
// OAuth2. When InitOauth2 fell back to http.DefaultClient this alters the
// process-wide default client — confirm that is intended.
func (o *OAuth2) UseInternalTokenService(ctx context.Context, serviceHost string, k8sService k8s.ServiceInterface) error {
	tokenURL, err := url.Parse(o.Config.Endpoint.TokenURL)
	if err != nil {
		return errors.Wrap(err, "unable to parse token url")
	}
	tokenURL.Host = serviceHost

	cm, err := k8sService.GetConfigMap(ctx, "openshift-service-ca.crt", "openshift-config-managed")
	if err != nil {
		return errors.Wrap(err, "unable to get openshift ca config map")
	}

	ca, ok := cm.Data["service-ca.crt"]
	if !ok {
		return errors.New("no service ca found in config map")
	}

	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM([]byte(ca)) {
		return errors.New("unable to append certs from PEM")
	}

	// Commit both changes together now that nothing can fail any more.
	o.Config.Endpoint.TokenURL = tokenURL.String()
	o.httpClient.Transport = &http.Transport{TLSClientConfig: &tls.Config{
		RootCAs: caCertPool,
		// Pinned explicitly so the floor does not depend on the toolchain default.
		MinVersion: tls.VersionTLS12,
	}}

	return nil
}
// AuthCodeURL returns the provider's authorization URL for this configuration.
// The state parameter is "state-" followed by the current Unix time in seconds.
//
// NOTE(review): a time-derived state is predictable; if the callback relies on
// it for CSRF protection, an unguessable per-session value is needed — verify
// against the callback handler.
func (o *OAuth2) AuthCodeURL() string {
	issuedAt := time.Now().Unix()
	state := fmt.Sprintf("state-%d", issuedAt)

	return o.Config.AuthCodeURL(state)
}
// GetTokenClient exchanges an authorization code for a token and returns that
// token together with an HTTP client created from it.
//
// The stored HTTP client is placed in the context under oauth2.HTTPClient, and
// that same context is used both for the exchange and for building the
// returned client. When the exchange fails, token and client are nil and the
// error is wrapped with "unable to get access token".
func (o *OAuth2) GetTokenClient(ctx context.Context, code string) (token *oauth2.Token, oauthClient *http.Client,
	err error) {
	exchangeCtx := context.WithValue(ctx, oauth2.HTTPClient, o.httpClient)

	token, err = o.Config.Exchange(exchangeCtx, code)
	if err != nil {
		return nil, nil, errors.Wrap(err, "unable to get access token")
	}

	return token, o.Config.Client(exchangeCtx, token), nil
}
// GetHTTPClient returns the client that Config.Client builds for ctx and token.
//
// NOTE(review): unlike GetTokenClient, ctx is passed through as is and the
// stored HTTP client is not injected into it — confirm callers supply
// oauth2.HTTPClient themselves when a custom transport is required.
func (o *OAuth2) GetHTTPClient(ctx context.Context, token *oauth2.Token) *http.Client {
	authorized := o.Config.Client(ctx, token)

	return authorized
}