@@ -37,18 +37,17 @@ const (
3737)
3838
3939type routes struct {
40- upstream * url.URL
41- handler http.Handler
42- label string
43- labelValue string
40+ upstream * url.URL
41+ handler http.Handler
42+ label string
43+ el ExtractLabeler
4444
4545 mux http.Handler
4646 modifiers map [string ]func (* http.Response ) error
4747 errorOnReplace bool
4848}
4949
5050type options struct {
51- labelValue string
5251 enableLabelAPIs bool
5352 passthroughPaths []string
5453 errorOnReplace bool
@@ -96,14 +95,6 @@ func WithErrorOnReplace() Option {
9695 })
9796}
9897
99- // WithLabelValue enforces a specific value for the multi-tenancy label.
100- // If not specified, the value has to be provided as a URL parameter.
101- func WithLabelValue (value string ) Option {
102- return optionFunc (func (o * options ) {
103- o .labelValue = value
104- })
105- }
106-
10798// mux abstracts away the behavior we expect from the http.ServeMux type in this package.
10899type mux interface {
109100 http.Handler
@@ -170,11 +161,110 @@ func (i *instrumentedMux) Handle(pattern string, handler http.Handler) {
170161 i .mux .Handle (pattern , i .i .NewHandler (prometheus.Labels {"handler" : pattern }, handler ))
171162}
172163
173- func NewRoutes (upstream * url.URL , label string , opts ... Option ) (* routes , error ) {
164+ // ExtractLabeler is an HTTP handler that extract the label value to be
165+ // enforced from the HTTP request. If a valid label value is found, it should
166+ // store it in the request's context. Otherwise it should return an error in
167+ // the HTTP response (usually 400 or 500).
168+ type ExtractLabeler interface {
169+ ExtractLabel (next http.HandlerFunc ) http.Handler
170+ }
171+
172+ // HTTPFormEnforcer enforces a label value extracted from the HTTP form and query parameters.
173+ type HTTPFormEnforcer struct {
174+ ParameterName string
175+ }
176+
177+ // ExtractLabel implements the ExtractLabeler interface.
178+ func (hff HTTPFormEnforcer ) ExtractLabel (next http.HandlerFunc ) http.Handler {
179+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
180+ labelValue , err := hff .getLabelValue (r )
181+ if err != nil {
182+ prometheusAPIError (w , humanFriendlyErrorMessage (err ), http .StatusBadRequest )
183+ return
184+ }
185+
186+ // Remove the proxy label from the query parameters.
187+ q := r .URL .Query ()
188+ q .Del (hff .ParameterName )
189+ r .URL .RawQuery = q .Encode ()
190+
191+ // Remove the param from the PostForm.
192+ if r .Method == http .MethodPost {
193+ if err := r .ParseForm (); err != nil {
194+ prometheusAPIError (w , fmt .Sprintf ("Failed to parse the PostForm: %v" , err ), http .StatusInternalServerError )
195+ return
196+ }
197+ if r .PostForm .Get (hff .ParameterName ) != "" {
198+ r .PostForm .Del (hff .ParameterName )
199+ newBody := r .PostForm .Encode ()
200+ // We are replacing request body, close previous one (r.FormValue ensures it is read fully and not nil).
201+ _ = r .Body .Close ()
202+ r .Body = io .NopCloser (strings .NewReader (newBody ))
203+ r .ContentLength = int64 (len (newBody ))
204+ }
205+ }
206+
207+ next .ServeHTTP (w , r .WithContext (WithLabelValue (r .Context (), labelValue )))
208+ })
209+ }
210+
211+ func (hff HTTPFormEnforcer ) getLabelValue (r * http.Request ) (string , error ) {
212+ formValue := r .FormValue (hff .ParameterName )
213+ if formValue == "" {
214+ return "" , fmt .Errorf ("the %q query parameter must be provided" , hff .ParameterName )
215+ }
216+
217+ return formValue , nil
218+ }
219+
220+ // HTTPHeaderEnforcer enforces a label value extracted from the HTTP headers.
221+ type HTTPHeaderEnforcer struct {
222+ Name string
223+ }
224+
225+ // ExtractLabel implements the ExtractLabeler interface.
226+ func (hhe HTTPHeaderEnforcer ) ExtractLabel (next http.HandlerFunc ) http.Handler {
227+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
228+ labelValue , err := hhe .getLabelValue (r )
229+ if err != nil {
230+ prometheusAPIError (w , humanFriendlyErrorMessage (err ), http .StatusBadRequest )
231+ return
232+ }
233+
234+ next .ServeHTTP (w , r .WithContext (WithLabelValue (r .Context (), labelValue )))
235+ })
236+ }
237+
238+ func (hhe HTTPHeaderEnforcer ) getLabelValue (r * http.Request ) (string , error ) {
239+ headerValues := r .Header [hhe .Name ]
240+
241+ if len (headerValues ) == 0 {
242+ return "" , fmt .Errorf ("missing HTTP header %q" , hhe .Name )
243+ }
244+
245+ if len (headerValues ) > 1 {
246+ return "" , fmt .Errorf ("multiple values for the http header %q" , hhe .Name )
247+ }
248+
249+ return headerValues [0 ], nil
250+ }
251+
252+ // StaticLabelEnforcer enforces a static label value.
253+ type StaticLabelEnforcer string
254+
255+ // ExtractLabel implements the ExtractLabeler interface.
256+ func (sle StaticLabelEnforcer ) ExtractLabel (next http.HandlerFunc ) http.Handler {
257+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
258+ next (w , r .WithContext (WithLabelValue (r .Context (), string (sle ))))
259+ })
260+ }
261+
262+ func NewRoutes (upstream * url.URL , label string , extractLabeler ExtractLabeler , opts ... Option ) (* routes , error ) {
174263 opt := options {}
175264 for _ , o := range opts {
176265 o .apply (& opt )
177266 }
267+
178268 if opt .registerer == nil {
179269 opt .registerer = prometheus .NewRegistry ()
180270 }
@@ -185,35 +275,35 @@ func NewRoutes(upstream *url.URL, label string, opts ...Option) (*routes, error)
185275 upstream : upstream ,
186276 handler : proxy ,
187277 label : label ,
188- labelValue : opt . labelValue ,
278+ el : extractLabeler ,
189279 errorOnReplace : opt .errorOnReplace ,
190280 }
191281 mux := newStrictMux (newInstrumentedMux (http .NewServeMux (), opt .registerer ))
192282
193283 errs := merrors .New (
194- mux .Handle ("/federate" , r .enforceLabel (enforceMethods (r .matcher , "GET" ))),
195- mux .Handle ("/api/v1/query" , r .enforceLabel (enforceMethods (r .query , "GET" , "POST" ))),
196- mux .Handle ("/api/v1/query_range" , r .enforceLabel (enforceMethods (r .query , "GET" , "POST" ))),
197- mux .Handle ("/api/v1/alerts" , r .enforceLabel (enforceMethods (r .passthrough , "GET" ))),
198- mux .Handle ("/api/v1/rules" , r .enforceLabel (enforceMethods (r .passthrough , "GET" ))),
199- mux .Handle ("/api/v1/series" , r .enforceLabel (enforceMethods (r .matcher , "GET" , "POST" ))),
200- mux .Handle ("/api/v1/query_exemplars" , r .enforceLabel (enforceMethods (r .query , "GET" , "POST" ))),
284+ mux .Handle ("/federate" , r .el . ExtractLabel (enforceMethods (r .matcher , "GET" ))),
285+ mux .Handle ("/api/v1/query" , r .el . ExtractLabel (enforceMethods (r .query , "GET" , "POST" ))),
286+ mux .Handle ("/api/v1/query_range" , r .el . ExtractLabel (enforceMethods (r .query , "GET" , "POST" ))),
287+ mux .Handle ("/api/v1/alerts" , r .el . ExtractLabel (enforceMethods (r .passthrough , "GET" ))),
288+ mux .Handle ("/api/v1/rules" , r .el . ExtractLabel (enforceMethods (r .passthrough , "GET" ))),
289+ mux .Handle ("/api/v1/series" , r .el . ExtractLabel (enforceMethods (r .matcher , "GET" , "POST" ))),
290+ mux .Handle ("/api/v1/query_exemplars" , r .el . ExtractLabel (enforceMethods (r .query , "GET" , "POST" ))),
201291 )
202292
203293 if opt .enableLabelAPIs {
204294 errs .Add (
205- mux .Handle ("/api/v1/labels" , r .enforceLabel (enforceMethods (r .matcher , "GET" , "POST" ))),
295+ mux .Handle ("/api/v1/labels" , r .el . ExtractLabel (enforceMethods (r .matcher , "GET" , "POST" ))),
206296 // Full path is /api/v1/label/<label_name>/values but http mux does not support patterns.
207297 // This is fine though as we don't care about name for matcher injector.
208- mux .Handle ("/api/v1/label/" , r .enforceLabel (enforceMethods (r .matcher , "GET" ))),
298+ mux .Handle ("/api/v1/label/" , r .el . ExtractLabel (enforceMethods (r .matcher , "GET" ))),
209299 )
210300 }
211301
212302 errs .Add (
213- mux .Handle ("/api/v2/silences" , r .enforceLabel (enforceMethods (r .silences , "GET" , "POST" ))),
214- mux .Handle ("/api/v2/silence/" , r .enforceLabel (enforceMethods (r .deleteSilence , "DELETE" ))),
215- mux .Handle ("/api/v2/alerts/groups" , r .enforceLabel (enforceMethods (r .enforceFilterParameter , "GET" ))),
216- mux .Handle ("/api/v2/alerts" , r .enforceLabel (enforceMethods (r .alerts , "GET" ))),
303+ mux .Handle ("/api/v2/silences" , r .el . ExtractLabel (enforceMethods (r .silences , "GET" , "POST" ))),
304+ mux .Handle ("/api/v2/silence/" , r .el . ExtractLabel (enforceMethods (r .deleteSilence , "DELETE" ))),
305+ mux .Handle ("/api/v2/alerts/groups" , r .el . ExtractLabel (enforceMethods (r .enforceFilterParameter , "GET" ))),
306+ mux .Handle ("/api/v2/alerts" , r .el . ExtractLabel (enforceMethods (r .alerts , "GET" ))),
217307 )
218308
219309 errs .Add (
@@ -256,62 +346,6 @@ func NewRoutes(upstream *url.URL, label string, opts ...Option) (*routes, error)
256346 return r , nil
257347}
258348
259- func (r * routes ) enforceLabel (h http.HandlerFunc ) http.Handler {
260- return http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
261- lvalue , err := r .getLabelValue (req )
262- if err != nil {
263- prometheusAPIError (w , humanFriendlyErrorMessage (err ), http .StatusBadRequest )
264- return
265- }
266-
267- req = req .WithContext (withLabelValue (req .Context (), lvalue ))
268-
269- // Remove the proxy label from the query parameters.
270- q := req .URL .Query ()
271- if q .Get (r .label ) != "" {
272- q .Del (r .label )
273- }
274- req .URL .RawQuery = q .Encode ()
275- // Remove the proxy label from the PostForm.
276- if req .Method == http .MethodPost {
277- if err := req .ParseForm (); err != nil {
278- prometheusAPIError (w , fmt .Sprintf ("Failed to parse the PostForm: %v" , err ), http .StatusInternalServerError )
279- return
280- }
281- if req .PostForm .Get (r .label ) != "" {
282- req .PostForm .Del (r .label )
283- newBody := req .PostForm .Encode ()
284- // We are replacing request body, close previous one (req.FormValue ensures it is read fully and not nil).
285- _ = req .Body .Close ()
286- req .Body = io .NopCloser (strings .NewReader (newBody ))
287- req .ContentLength = int64 (len (newBody ))
288- }
289- }
290-
291- h .ServeHTTP (w , req )
292- })
293- }
294-
295- // getLabelValue returns the statically set label value, or the label value
296- // sent through a URL parameter.
297- // It returns an error when either the value is found in both places, or is not found at all.
298- func (r * routes ) getLabelValue (req * http.Request ) (string , error ) {
299- formValue := req .FormValue (r .label )
300- if r .labelValue != "" && formValue != "" {
301- return "" , fmt .Errorf ("a static value for the %s label has already been specified" , r .label )
302- }
303-
304- if r .labelValue == "" && formValue == "" {
305- return "" , fmt .Errorf ("the %q query parameter must be provided" , r .label )
306- }
307-
308- if r .labelValue != "" {
309- return r .labelValue , nil
310- }
311-
312- return formValue , nil
313- }
314-
315349func (r * routes ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
316350 r .mux .ServeHTTP (w , req )
317351}
@@ -341,7 +375,10 @@ type ctxKey int
341375
342376const keyLabel ctxKey = iota
343377
344- func mustLabelValue (ctx context.Context ) string {
378+ // MustLabelValue returns a label (previously stored using WithLabelValue())
379+ // from the given context.
380+ // It will panic if no label is found or the value is empty.
381+ func MustLabelValue (ctx context.Context ) string {
345382 label , ok := ctx .Value (keyLabel ).(string )
346383 if ! ok {
347384 panic (fmt .Sprintf ("can't find the %q value in the context" , keyLabel ))
@@ -352,7 +389,8 @@ func mustLabelValue(ctx context.Context) string {
352389 return label
353390}
354391
355- func withLabelValue (ctx context.Context , label string ) context.Context {
392+ // WithLabelValue stores a label in the given context.
393+ func WithLabelValue (ctx context.Context , label string ) context.Context {
356394 return context .WithValue (ctx , keyLabel , label )
357395}
358396
@@ -365,7 +403,7 @@ func (r *routes) query(w http.ResponseWriter, req *http.Request) {
365403 []* labels.Matcher {{
366404 Name : r .label ,
367405 Type : labels .MatchEqual ,
368- Value : mustLabelValue (req .Context ()),
406+ Value : MustLabelValue (req .Context ()),
369407 }}... )
370408
371409 // The `query` can come in the URL query string and/or the POST body.
@@ -451,7 +489,7 @@ func (r *routes) matcher(w http.ResponseWriter, req *http.Request) {
451489 matcher := & labels.Matcher {
452490 Name : r .label ,
453491 Type : labels .MatchEqual ,
454- Value : mustLabelValue (req .Context ()),
492+ Value : MustLabelValue (req .Context ()),
455493 }
456494 q := req .URL .Query ()
457495
0 commit comments