Skip to content

Commit bb346b4

Browse files
authored
Merge pull request #715 from huanghaoyuanhhy/fix-az-ep
storage: fix azure list iter
2 parents 835e021 + d4a1040 commit bb346b4

File tree

4 files changed

+174
-38
lines changed

4 files changed

+174
-38
lines changed

core/storage/azure.go

Lines changed: 117 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,19 @@ func (a *AzureReader) Close() error { return nil }
5555
var _ Client = (*AzureClient)(nil)
5656

5757
func newAzureClient(cfg Config) (*AzureClient, error) {
58+
// Kept for backward compatibility: the "blob" service subdomain is hard-coded here rather than being supplied through the configured endpoint.
59+
ep := fmt.Sprintf("https://%s.blob.%s", cfg.Credential.AzureAccountName, cfg.Endpoint)
5860
switch cfg.Credential.Type {
5961
case IAM:
6062
cred, err := azidentity.NewDefaultAzureCredential(nil)
6163
if err != nil {
6264
return nil, fmt.Errorf("storage: new azure default azure credential %w", err)
6365
}
64-
cli, err := azblob.NewClient(cfg.Endpoint, cred, nil)
66+
cli, err := azblob.NewClient(ep, cred, nil)
6567
if err != nil {
6668
return nil, fmt.Errorf("storage: new azure client %w", err)
6769
}
68-
sasCli, err := service.NewClient(cfg.Endpoint, cred, nil)
70+
sasCli, err := service.NewClient(ep, cred, nil)
6971
if err != nil {
7072
return nil, fmt.Errorf("storage: new azure service client %w", err)
7173
}
@@ -76,11 +78,11 @@ func newAzureClient(cfg Config) (*AzureClient, error) {
7678
if err != nil {
7779
return nil, fmt.Errorf("storage: new azure shared key credential %w", err)
7880
}
79-
cli, err := azblob.NewClientWithSharedKeyCredential(cfg.Endpoint, cred, nil)
81+
cli, err := azblob.NewClientWithSharedKeyCredential(ep, cred, nil)
8082
if err != nil {
8183
return nil, fmt.Errorf("storage: new azure client %w", err)
8284
}
83-
sasCli, err := service.NewClientWithSharedKeyCredential(cfg.Endpoint, cred, nil)
85+
sasCli, err := service.NewClientWithSharedKeyCredential(ep, cred, nil)
8486
return &AzureClient{cfg: cfg, cli: cli, sasCli: sasCli}, nil
8587
default:
8688
return nil, fmt.Errorf("storage: azure unsupported credential type: %s", cfg.Credential.Type.String())
@@ -219,46 +221,139 @@ func (a *AzureClient) UploadObject(ctx context.Context, i UploadObjectInput) err
219221
return nil
220222
}
221223

222-
type AzureObjectIterator struct {
224+
type AzureObjectFlatIterator struct {
223225
cli *AzureClient
224226

225-
pager *runtime.Pager[azblob.ListBlobsFlatResponse]
226-
currPage azblob.ListBlobsFlatResponse
227-
currIndex int
227+
pager *runtime.Pager[azblob.ListBlobsFlatResponse]
228+
229+
currPage []ObjectAttr
230+
nextIdx int
231+
}
232+
233+
func (flatIter *AzureObjectFlatIterator) HasNext() bool {
234+
// current page has more entries
235+
if flatIter.nextIdx < len(flatIter.currPage) {
236+
return true
237+
}
238+
239+
// current page is the last page
240+
if !flatIter.pager.More() {
241+
return false
242+
}
243+
244+
// try to get next page
245+
page, err := flatIter.pager.NextPage(context.Background())
246+
if err != nil {
247+
log.Warn("failed to get next page", zap.Error(err))
248+
return false
249+
}
250+
flatIter.currPage = flatIter.currPage[:0]
251+
for _, blob := range page.Segment.BlobItems {
252+
attr := ObjectAttr{Key: *blob.Name, Length: *blob.Properties.ContentLength}
253+
flatIter.currPage = append(flatIter.currPage, attr)
254+
}
255+
flatIter.nextIdx = 0
256+
return true
257+
}
258+
259+
func (flatIter *AzureObjectFlatIterator) Next() (ObjectAttr, error) {
260+
attr := flatIter.currPage[flatIter.nextIdx]
261+
flatIter.nextIdx += 1
262+
263+
return attr, nil
264+
}
265+
266+
type AzureObjectHierarchyIterator struct {
267+
cli *AzureClient
268+
269+
pager *runtime.Pager[container.ListBlobsHierarchyResponse]
270+
271+
currPage []ObjectAttr
272+
nextIdx int
228273
}
229274

230-
func (a *AzureObjectIterator) HasNext() bool {
275+
func (hierIter *AzureObjectHierarchyIterator) HasNext() bool {
231276
// current page has more entries
232-
if a.currIndex < len(a.currPage.Segment.BlobItems) {
277+
if hierIter.nextIdx < len(hierIter.currPage) {
233278
return true
234279
}
235280

236-
// current page is the last page, try to get next page
237-
if !a.pager.More() {
281+
// no more page
282+
if !hierIter.pager.More() {
238283
return false
239284
}
240285

241-
page, err := a.pager.NextPage(context.Background())
286+
// try to get next page
287+
page, err := hierIter.pager.NextPage(context.Background())
242288
if err != nil {
243289
log.Warn("failed to get next page", zap.Error(err))
244290
return false
245291
}
246-
a.currPage = page
247-
a.currIndex = 0
292+
hierIter.currPage = hierIter.currPage[:0]
293+
for _, blob := range page.Segment.BlobItems {
294+
attr := ObjectAttr{Key: *blob.Name, Length: *blob.Properties.ContentLength}
295+
hierIter.currPage = append(hierIter.currPage, attr)
296+
}
297+
for _, prefix := range page.Segment.BlobPrefixes {
298+
hierIter.currPage = append(hierIter.currPage, ObjectAttr{Key: *prefix.Name})
299+
}
300+
hierIter.nextIdx = 0
248301
return true
249302
}
250303

251-
func (a *AzureObjectIterator) Next() (ObjectAttr, error) {
252-
blob := a.currPage.Segment.BlobItems[a.currIndex]
253-
a.currIndex++
304+
func (hierIter *AzureObjectHierarchyIterator) Next() (ObjectAttr, error) {
305+
attr := hierIter.currPage[hierIter.nextIdx]
306+
hierIter.nextIdx += 1
254307

255-
return ObjectAttr{Key: *blob.Name, Length: *blob.Properties.ContentLength}, nil
308+
return attr, nil
256309
}
257310

258-
func (a *AzureClient) ListPrefix(_ context.Context, prefix string, _ bool) (ObjectIterator, error) {
259-
// currently only support list prefix recursively
311+
func (a *AzureClient) ListPrefix(_ context.Context, prefix string, recursive bool) (ObjectIterator, error) {
312+
if recursive {
313+
return a.listPrefixRecursive(prefix)
314+
}
315+
return a.listPrefixNonRecursive(prefix)
316+
}
317+
318+
func (a *AzureClient) listPrefixRecursive(prefix string) (*AzureObjectFlatIterator, error) {
260319
pager := a.cli.NewListBlobsFlatPager(a.cfg.Bucket, &azblob.ListBlobsFlatOptions{Prefix: to.Ptr(prefix)})
261-
return &AzureObjectIterator{cli: a, pager: pager}, nil
320+
page, err := pager.NextPage(context.Background())
321+
if err != nil {
322+
return nil, fmt.Errorf("storage: azure list prefix %w", err)
323+
}
324+
325+
var currPage []ObjectAttr
326+
if page.Segment != nil {
327+
for _, blob := range page.Segment.BlobItems {
328+
attr := ObjectAttr{Key: *blob.Name, Length: *blob.Properties.ContentLength}
329+
currPage = append(currPage, attr)
330+
}
331+
}
332+
333+
return &AzureObjectFlatIterator{cli: a, pager: pager, currPage: currPage}, nil
334+
}
335+
336+
func (a *AzureClient) listPrefixNonRecursive(prefix string) (*AzureObjectHierarchyIterator, error) {
337+
pager := a.cli.ServiceClient().
338+
NewContainerClient(a.cfg.Bucket).
339+
NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{Prefix: to.Ptr(prefix)})
340+
page, err := pager.NextPage(context.Background())
341+
if err != nil {
342+
return nil, fmt.Errorf("storage: azure list prefix %w", err)
343+
}
344+
345+
var currPage []ObjectAttr
346+
if page.Segment != nil {
347+
for _, blob := range page.Segment.BlobItems {
348+
attr := ObjectAttr{Key: *blob.Name, Length: *blob.Properties.ContentLength}
349+
currPage = append(currPage, attr)
350+
}
351+
for _, pre := range page.Segment.BlobPrefixes {
352+
currPage = append(currPage, ObjectAttr{Key: *pre.Name})
353+
}
354+
}
355+
356+
return &AzureObjectHierarchyIterator{cli: a, pager: pager, currPage: currPage}, nil
262357
}
263358

264359
func (a *AzureClient) DeleteObject(ctx context.Context, prefix string) error {

core/storage/client.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ type Credential struct {
6262

6363
// MinioCredential
6464
MinioCredProvider minioCred.Provider
65+
66+
// Azure specific: storage account name, used to build the blob service endpoint
67+
AzureAccountName string
6568
}
6669

6770
type CredentialType uint8

core/storage/factory.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,29 @@ import (
1212
)
1313

1414
func newBackupCredential(params *paramtable.BackupParams) Credential {
15+
var cred Credential
16+
if params.MinioCfg.BackupStorageType == paramtable.CloudProviderAzure {
17+
cred.AzureAccountName = params.MinioCfg.BackupAccessKeyID
18+
}
19+
1520
if params.MinioCfg.BackupUseIAM {
16-
return Credential{Type: IAM, IAMEndpoint: params.MinioCfg.BackupIAMEndpoint}
21+
cred.Type = IAM
22+
cred.IAMEndpoint = params.MinioCfg.BackupIAMEndpoint
23+
return cred
1724
}
1825

1926
if params.MinioCfg.BackupStorageType == paramtable.CloudProviderGCPNative &&
2027
params.MinioCfg.BackupGcpCredentialJSON != "" {
21-
return Credential{Type: GCPCredJSON, GCPCredJSON: params.MinioCfg.BackupGcpCredentialJSON}
28+
cred.Type = GCPCredJSON
29+
cred.GCPCredJSON = params.MinioCfg.BackupGcpCredentialJSON
30+
return cred
2231
}
2332

24-
return Credential{
25-
Type: Static,
26-
AK: params.MinioCfg.BackupAccessKeyID,
27-
SK: params.MinioCfg.BackupSecretAccessKey,
28-
Token: params.MinioCfg.BackupToken,
29-
}
33+
cred.Type = Static
34+
cred.AK = params.MinioCfg.BackupAccessKeyID
35+
cred.SK = params.MinioCfg.BackupSecretAccessKey
36+
cred.Token = params.MinioCfg.BackupToken
37+
return cred
3038
}
3139

3240
func NewBackupStorage(ctx context.Context, params *paramtable.BackupParams) (Client, error) {
@@ -57,21 +65,29 @@ func NewBackupStorage(ctx context.Context, params *paramtable.BackupParams) (Cli
5765
}
5866

5967
func newMilvusCredential(params *paramtable.BackupParams) Credential {
68+
var cred Credential
69+
if params.MinioCfg.StorageType == paramtable.CloudProviderAzure {
70+
cred.AzureAccountName = params.MinioCfg.AccessKeyID
71+
}
72+
6073
if params.MinioCfg.UseIAM {
61-
return Credential{Type: IAM, IAMEndpoint: params.MinioCfg.IAMEndpoint}
74+
cred.Type = IAM
75+
cred.IAMEndpoint = params.MinioCfg.IAMEndpoint
76+
return cred
6277
}
6378

6479
if params.MinioCfg.StorageType == paramtable.CloudProviderGCPNative &&
6580
params.MinioCfg.GcpCredentialJSON != "" {
66-
return Credential{Type: GCPCredJSON, GCPCredJSON: params.MinioCfg.GcpCredentialJSON}
81+
cred.Type = GCPCredJSON
82+
cred.GCPCredJSON = params.MinioCfg.GcpCredentialJSON
83+
return cred
6784
}
6885

69-
return Credential{
70-
Type: Static,
71-
AK: params.MinioCfg.AccessKeyID,
72-
SK: params.MinioCfg.SecretAccessKey,
73-
Token: params.MinioCfg.Token,
74-
}
86+
cred.Type = Static
87+
cred.AK = params.MinioCfg.AccessKeyID
88+
cred.SK = params.MinioCfg.SecretAccessKey
89+
cred.Token = params.MinioCfg.Token
90+
return cred
7591
}
7692

7793
func NewMilvusStorage(ctx context.Context, params *paramtable.BackupParams) (Client, error) {

core/storage/factory_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ import (
99
)
1010

1111
func TestNewBackupCredential(t *testing.T) {
12+
t.Run("Azure", func(t *testing.T) {
13+
params := &paramtable.BackupParams{MinioCfg: paramtable.MinioConfig{
14+
BackupStorageType: paramtable.CloudProviderAzure,
15+
BackupAccessKeyID: "accountName",
16+
BackupUseIAM: true,
17+
}}
18+
cred := newBackupCredential(params)
19+
assert.Equal(t, IAM, cred.Type)
20+
assert.Equal(t, "accountName", cred.AzureAccountName)
21+
})
22+
1223
t.Run("IAM", func(t *testing.T) {
1324
params := &paramtable.BackupParams{MinioCfg: paramtable.MinioConfig{
1425
BackupUseIAM: true,
@@ -44,6 +55,17 @@ func TestNewBackupCredential(t *testing.T) {
4455
}
4556

4657
func TestNewMilvusCredential(t *testing.T) {
58+
t.Run("Azure", func(t *testing.T) {
59+
params := &paramtable.BackupParams{MinioCfg: paramtable.MinioConfig{
60+
StorageType: paramtable.CloudProviderAzure,
61+
AccessKeyID: "accountName",
62+
UseIAM: true,
63+
}}
64+
cred := newMilvusCredential(params)
65+
assert.Equal(t, IAM, cred.Type)
66+
assert.Equal(t, "accountName", cred.AzureAccountName)
67+
})
68+
4769
t.Run("IAM", func(t *testing.T) {
4870
params := &paramtable.BackupParams{MinioCfg: paramtable.MinioConfig{
4971
UseIAM: true,

0 commit comments

Comments
 (0)