summaryrefslogtreecommitdiffstats
path: root/tools/chrome_proxy/testserver/server_test.go
blob: d215131194a231d3233c647a1f8e01b4c910913c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package server

import (
	"encoding/base64"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"strconv"
	"testing"
)

// composeQuery returns path with response-override query parameters merged
// into whatever query string path already carries:
//
//   - respStatus: the decimal status code, set only when code > 0.
//   - respHeader: base64url(JSON(headers)), set only when headers is non-nil
//     (an empty but non-nil header map is still encoded, as "{}").
//   - respBody: base64url(body), set only when body is non-empty.
//
// Parameters are emitted in key-sorted order by url.Values.Encode. An error is
// returned if path cannot be parsed as a URL or headers cannot be marshaled.
func composeQuery(path string, code int, headers http.Header, body []byte) (string, error) {
	target, err := url.Parse(path)
	if err != nil {
		return "", err
	}

	encode := base64.URLEncoding.EncodeToString
	params := target.Query()

	// Set order is irrelevant: Encode sorts keys, and each key is set at most once.
	if len(body) > 0 {
		params.Set("respBody", encode(body))
	}
	if code > 0 {
		params.Set("respStatus", strconv.Itoa(code))
	}
	if headers != nil {
		raw, marshalErr := json.Marshal(headers)
		if marshalErr != nil {
			return "", marshalErr
		}
		params.Set("respHeader", encode(raw))
	}

	target.RawQuery = params.Encode()
	return target.String(), nil
}

// TestResponseOverride checks that defaultResponse honors the status code,
// headers, and body requested through the respStatus, respHeader, and
// respBody query parameters produced by composeQuery.
//
// Every case runs even if an earlier one fails. Within a case, every
// mismatching aspect (code, each header, body) is reported, so a single run
// surfaces all broken overrides instead of stopping at the first.
func TestResponseOverride(t *testing.T) {
	tests := []struct {
		name    string
		code    int
		headers http.Header
		body    []byte
	}{
		{name: "code", code: 204},
		{name: "body", body: []byte("new body")},
		{
			name: "headers",
			headers: http.Header{
				"Via":          []string{"Via1", "Via2"},
				"Content-Type": []string{"random content"},
			},
		},
		{
			name: "everything",
			code: 204,
			body: []byte("new body"),
			headers: http.Header{
				"Via":          []string{"Via1", "Via2"},
				"Content-Type": []string{"random content"},
			},
		},
	}

	for _, test := range tests {
		u, err := composeQuery("http://test.com/override", test.code, test.headers, test.body)
		if err != nil {
			// Without a request there is nothing to check for this case;
			// move on to the next one rather than abandoning the whole table.
			t.Errorf("%s: composeQuery: %v", test.name, err)
			continue
		}
		req, err := http.NewRequest("GET", u, nil)
		if err != nil {
			t.Errorf("%s: http.NewRequest: %v", test.name, err)
			continue
		}
		w := httptest.NewRecorder()
		defaultResponse(w, req)

		// A zero code means the case requested no status override.
		if test.code > 0 {
			if got, want := w.Code, test.code; got != want {
				t.Errorf("%s: response code: got %d want %d", test.name, got, want)
			}
		}
		// w.Header() returns the recorder's header map (the same map as the
		// deprecated HeaderMap field). Ranging over a nil test.headers is a no-op.
		for k, want := range test.headers {
			got, ok := w.Header()[k]
			if !ok || !reflect.DeepEqual(got, want) {
				t.Errorf("%s: header %s: got %v want %v", test.name, k, got, want)
			}
		}
		if test.body != nil {
			if got, want := w.Body.String(), string(test.body); got != want {
				t.Errorf("%s: body: got %q want %q", test.name, got, want)
			}
		}
	}
}