diff --git a/policyfile.go b/policyfile.go index 02f15a0..fda443e 100644 --- a/policyfile.go +++ b/policyfile.go @@ -221,6 +221,32 @@ func (pr *PolicyFileResource) Set(ctx context.Context, acl any, etag string) err return pr.do(req, nil) } +// Set sets the [ACL] for the tailnet and returns the resulting [ACL]. +// etag is an optional value that, if supplied, will be used in the "If-Match" HTTP request header. +func (pr *PolicyFileResource) SetAndGet(ctx context.Context, acl ACL, etag string) (*ACL, error) { + headers := make(map[string]string) + if etag != "" { + headers["If-Match"] = fmt.Sprintf("%q", etag) + } + + reqOpts := []requestOption{ + requestHeaders(headers), + requestBody(acl), + } + + req, err := pr.buildRequest(ctx, http.MethodPost, pr.buildTailnetURL("acl"), reqOpts...) + if err != nil { + return nil, err + } + + out, header, err := bodyWithResponseHeader[ACL](pr, req) + if err != nil { + return nil, err + } + out.ETag = header.Get("Etag") + return out, nil +} + // Validate validates the provided ACL via the API. acl can either be an [ACL], or a HuJSON string. func (pr *PolicyFileResource) Validate(ctx context.Context, acl any) error { reqOpts := []requestOption{ diff --git a/policyfile_test.go b/policyfile_test.go index 557a4ff..53c3703 100644 --- a/policyfile_test.go +++ b/policyfile_test.go @@ -283,6 +283,62 @@ func TestClient_SetACL(t *testing.T) { assert.EqualValues(t, expectedACL, actualACL) } +func TestClient_SetAndGetACL(t *testing.T) { + t.Parallel() + + client, server := NewTestHarness(t) + server.ResponseCode = http.StatusOK + server.ResponseHeader.Set("ETag", "abcdefg") + in := ACL{ + ACLs: []ACLEntry{ + { + Action: "accept", + Ports: []string{"*:*"}, + Users: []string{"*"}, + }, + }, + TagOwners: map[string][]string{ + "tag:example": {"group:example"}, + }, + Hosts: map[string]string{ + "example-host-1": "100.100.100.100", + "example-host-2": "100.100.101.100/24", + }, + Groups: map[string][]string{ + "group:example": { + "user1@example.com", + "user2@example.com", + }, + }, + Tests: []ACLTest{ + { + User: "user1@example.com", + Allow: []string{"example-host-1:22", "example-host-2:80"}, + Deny: []string{"exapmle-host-2:100"}, + }, + { + User: "user2@example.com", + Allow: []string{"100.60.3.4:22"}, + }, + }, + ETag: "abcdefg", + } + server.ResponseBody = in + + out, err := client.PolicyFile().SetAndGet(context.Background(), in, "abcdefg") + assert.NoError(t, err) + assert.Equal(t, http.MethodPost, server.Method) + assert.Equal(t, "/api/v2/tailnet/example.com/acl", server.Path) + assert.Equal(t, `"abcdefg"`, server.Header.Get("If-Match")) + assert.EqualValues(t, "application/json", server.Header.Get("Content-Type")) + assert.EqualValues(t, &in, out) + + var actualACL ACL + assert.NoError(t, json.Unmarshal(server.Body.Bytes(), &actualACL)) + in.ETag = "" + assert.EqualValues(t, in, actualACL) +} + func TestClient_SetACL_HuJSON(t *testing.T) { t.Parallel()