From c3a9a9179f9720729be7203062ecd67c5e99f5d0 Mon Sep 17 00:00:00 2001 From: Vassilis Bekiaris Date: Fri, 6 Jun 2025 13:33:25 +0100 Subject: [PATCH 1/2] feat: add IO registry Introduces IO registry --- catalog/glue/glue.go | 4 +- catalog/internal/utils.go | 6 +- catalog/rest/rest.go | 2 +- catalog/sql/sql.go | 2 +- io/azure.go | 27 +++++++ io/gcs.go | 23 ++++++ io/io.go | 66 ----------------- io/local.go | 10 +++ io/mem.go | 32 ++++++++ io/registry.go | 139 +++++++++++++++++++++++++++++++++++ io/registry_test.go | 151 ++++++++++++++++++++++++++++++++++++++ io/s3.go | 26 +++++++ 12 files changed, 415 insertions(+), 73 deletions(-) create mode 100644 io/mem.go create mode 100644 io/registry.go create mode 100644 io/registry_test.go diff --git a/catalog/glue/glue.go b/catalog/glue/glue.go index ab15c2d04..6e9af4675 100644 --- a/catalog/glue/glue.go +++ b/catalog/glue/glue.go @@ -220,7 +220,7 @@ func (c *Catalog) LoadTable(ctx context.Context, identifier table.Identifier, pr ctx = utils.WithAwsConfig(ctx, c.awsCfg) // TODO: consider providing a way to directly access the S3 iofs to enable testing of the catalog. - iofs, err := io.LoadFS(ctx, props, location) + iofs, err := io.Load(ctx, props, location) if err != nil { return nil, fmt.Errorf("failed to load table %s.%s: %w", database, tableName, err) } @@ -313,7 +313,7 @@ func (c *Catalog) RegisterTable(ctx context.Context, identifier table.Identifier } // Load the metadata file to get table properties ctx = utils.WithAwsConfig(ctx, c.awsCfg) - iofs, err := io.LoadFS(ctx, nil, metadataLocation) + iofs, err := io.Load(ctx, nil, metadataLocation) if err != nil { return nil, fmt.Errorf("failed to load metadata file at %s: %w", metadataLocation, err) } diff --git a/catalog/internal/utils.go b/catalog/internal/utils.go index 7e4c8ca5b..bf9a05520 100644 --- a/catalog/internal/utils.go +++ b/catalog/internal/utils.go @@ -52,7 +52,7 @@ func WriteTableMetadata(metadata table.Metadata, fs io.WriteFileIO, loc string) } func WriteMetadata(ctx context.Context, metadata table.Metadata, loc string, props iceberg.Properties) error { - fs, err := io.LoadFS(ctx, props, loc) + fs, err := io.Load(ctx, props, loc) if err != nil { return err } @@ -136,7 +136,7 @@ func CreateStagedTable(ctx context.Context, catprops iceberg.Properties, nsprops ioProps := maps.Clone(catprops) maps.Copy(ioProps, cfg.Properties) - fs, err := io.LoadFS(ctx, ioProps, metadataLoc) + fs, err := io.Load(ctx, ioProps, metadataLoc) if err != nil { return table.StagedTable{}, err } @@ -239,7 +239,7 @@ func UpdateAndStageTable(ctx context.Context, current *table.Table, ident table. return nil, err } - fs, err := io.LoadFS(ctx, updated.Properties(), newLocation) + fs, err := io.Load(ctx, updated.Properties(), newLocation) if err != nil { return nil, err } diff --git a/catalog/rest/rest.go b/catalog/rest/rest.go index 949164972..5f5d19f15 100644 --- a/catalog/rest/rest.go +++ b/catalog/rest/rest.go @@ -653,7 +653,7 @@ func checkValidNamespace(ident table.Identifier) error { } func (r *Catalog) tableFromResponse(ctx context.Context, identifier []string, metadata table.Metadata, loc string, config iceberg.Properties) (*table.Table, error) { - iofs, err := iceio.LoadFS(ctx, config, loc) + iofs, err := iceio.Load(ctx, config, loc) if err != nil { return nil, err } diff --git a/catalog/sql/sql.go b/catalog/sql/sql.go index bb8a083f3..c967a3d1d 100644 --- a/catalog/sql/sql.go +++ b/catalog/sql/sql.go @@ -416,7 +416,7 @@ func (c *Catalog) LoadTable(ctx context.Context, identifier table.Identifier, pr tblProps := maps.Clone(c.props) maps.Copy(props, tblProps) - iofs, err := io.LoadFS(ctx, tblProps, result.MetadataLocation.String) + iofs, err := io.Load(ctx, tblProps, result.MetadataLocation.String) if err != nil { return nil, err } diff --git a/io/azure.go b/io/azure.go index 43ba54a6b..57103ea78 100644 --- a/io/azure.go +++ b/io/azure.go @@ -111,3 +111,30 @@ func createAzureBucket(ctx context.Context, parsed *url.URL, props map[string]st return azureblob.OpenBucket(ctx, client, nil) } + +func init() { + azureRegistrar := RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + // We need a warehouse location to extract the bucket name + location := props["warehouse"] + if location == "" { + return nil, fmt.Errorf("warehouse location required for Azure IO") + } + + parsed, err := url.Parse(location) + if err != nil { + return nil, fmt.Errorf("failed to parse Azure location: %w", err) + } + + bucket, err := createAzureBucket(ctx, parsed, props) + if err != nil { + return nil, err + } + + return createBlobFS(ctx, bucket, parsed.Host), nil + }) + + Register("abfs", azureRegistrar) + Register("abfss", azureRegistrar) + Register("wasb", azureRegistrar) + Register("wasbs", azureRegistrar) +} diff --git a/io/gcs.go b/io/gcs.go index 8f462d1d6..abc6b51a5 100644 --- a/io/gcs.go +++ b/io/gcs.go @@ -19,6 +19,7 @@ package io import ( "context" + "fmt" "net/url" "gocloud.dev/blob" @@ -63,3 +64,25 @@ func createGCSBucket(ctx context.Context, parsed *url.URL, props map[string]stri return bucket, nil } + +func init() { + Register("gs", RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + // We need a warehouse location to extract the bucket name + location := props["warehouse"] + if location == "" { + return nil, fmt.Errorf("warehouse location required for GCS IO") + } + + parsed, err := url.Parse(location) + if err != nil { + return nil, fmt.Errorf("failed to parse GCS location: %w", err) + } + + bucket, err := createGCSBucket(ctx, parsed, props) + if err != nil { + return nil, err + } + + return createBlobFS(ctx, bucket, parsed.Host), nil + })) +} diff --git a/io/io.go b/io/io.go index 523f7c575..2cdc619f8 100644 --- a/io/io.go +++ b/io/io.go @@ -18,16 +18,10 @@ package io import ( - "context" "errors" - "fmt" "io" "io/fs" - "net/url" "strings" - - "gocloud.dev/blob" - "gocloud.dev/blob/memblob" ) // IO is an interface to a hierarchical file system. @@ -234,63 +228,3 @@ func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) { return d.ReadDir(count) } - -func inferFileIOFromSchema(ctx context.Context, path string, props map[string]string) (IO, error) { - parsed, err := url.Parse(path) - if err != nil { - return nil, err - } - var bucket *blob.Bucket - - switch parsed.Scheme { - case "s3", "s3a", "s3n": - bucket, err = createS3Bucket(ctx, parsed, props) - if err != nil { - return nil, err - } - case "gs": - bucket, err = createGCSBucket(ctx, parsed, props) - if err != nil { - return nil, err - } - case "mem": - // memblob doesn't use the URL host or path - bucket = memblob.OpenBucket(nil) - case "file", "": - return LocalFS{}, nil - case "abfs", "abfss", "wasb", "wasbs": - bucket, err = createAzureBucket(ctx, parsed, props) - if err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("IO for file '%s' not implemented", path) - } - - return createBlobFS(ctx, bucket, parsed.Host), nil -} - -// LoadFS takes a map of properties and an optional URI location -// and attempts to infer an IO object from it. -// -// A schema of "file://" or an empty string will result in a LocalFS -// implementation. Otherwise this will return an error if the schema -// does not yet have an implementation here. -// -// Currently local, S3, GCS, and In-Memory FSs are implemented. -func LoadFS(ctx context.Context, props map[string]string, location string) (IO, error) { - if location == "" { - location = props["warehouse"] - } - - iofs, err := inferFileIOFromSchema(ctx, location, props) - if err != nil { - return nil, err - } - - if iofs == nil { - iofs = LocalFS{} - } - - return iofs, nil -} diff --git a/io/local.go b/io/local.go index ff4e1cb4f..8f95ad85c 100644 --- a/io/local.go +++ b/io/local.go @@ -18,6 +18,7 @@ package io import ( + "context" "os" "path/filepath" "strings" @@ -47,3 +48,12 @@ func (LocalFS) WriteFile(name string, content []byte) error { func (LocalFS) Remove(name string) error { return os.Remove(strings.TrimPrefix(name, "file://")) } + +func init() { + Register("file", RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + return LocalFS{}, nil + })) + Register("", RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + return LocalFS{}, nil + })) +} diff --git a/io/mem.go b/io/mem.go new file mode 100644 index 000000000..3e31e1f01 --- /dev/null +++ b/io/mem.go @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package io + +import ( + "context" + + "gocloud.dev/blob/memblob" +) + +func init() { + Register("mem", RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + // memblob doesn't use the URL host or path + bucket := memblob.OpenBucket(nil) + return createBlobFS(ctx, bucket, ""), nil + })) +} diff --git a/io/registry.go b/io/registry.go new file mode 100644 index 000000000..e8de26588 --- /dev/null +++ b/io/registry.go @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package io + +import ( + "context" + "errors" + "fmt" + "maps" + "net/url" + "slices" + "strings" + "sync" +) + +type registry map[string]Registrar + +func (r registry) getKeys() []string { + regMutex.Lock() + defer regMutex.Unlock() + + return slices.Collect(maps.Keys(r)) +} + +func (r registry) set(ioType string, reg Registrar) { + regMutex.Lock() + defer regMutex.Unlock() + r[ioType] = reg +} + +func (r registry) get(ioType string) (Registrar, bool) { + regMutex.Lock() + defer regMutex.Unlock() + reg, ok := r[ioType] + + return reg, ok +} + +func (r registry) remove(ioType string) { + regMutex.Lock() + defer regMutex.Unlock() + delete(r, ioType) +} + +var ( + regMutex sync.Mutex + defaultRegistry = registry{} + ErrIONotFound = errors.New("IO type not registered") +) + +// Registrar is a factory for creating IO instances, used for registering to use +// with Load. +type Registrar interface { + GetIO(ctx context.Context, props map[string]string) (IO, error) +} + +type RegistrarFunc func(context.Context, map[string]string) (IO, error) + +func (f RegistrarFunc) GetIO(ctx context.Context, props map[string]string) (IO, error) { + return f(ctx, props) +} + +// Register adds the new IO type to the registry. If the IO type is already registered, it will be replaced. +func Register(ioType string, reg Registrar) { + if reg == nil { + panic("io: RegisterIO factory is nil") + } + defaultRegistry.set(ioType, reg) +} + +// Unregister removes the requested IO factory from the registry. +func Unregister(ioType string) { + defaultRegistry.remove(ioType) +} + +// GetRegisteredIOs returns the list of registered IO names that can +// be looked up via Load. +func GetRegisteredIOs() []string { + return defaultRegistry.getKeys() +} + +// Load allows loading a specific IO implementation by scheme and properties. +// +// This is utilized alongside Register/Unregister to not only allow +// easier IO loading but also to allow for custom IO implementations to +// be registered and loaded external to this module. +// +// The scheme parameter is extracted from the URI to determine which IO +// implementation to use. For example, "s3://bucket/path" would use the +// "s3" IO implementation. +// +// Currently, the following IO types are supported by default: +// +// - "file" or "" for local filesystem +// - "s3", "s3a", "s3n" for Amazon S3 +// - "gs" for Google Cloud Storage +// - "abfs", "abfss", "wasb", "wasbs" for Azure Blob Storage +// - "mem" for in-memory storage +func Load(ctx context.Context, props map[string]string, location string) (IO, error) { + if location == "" { + location = props["warehouse"] + } + + scheme := "" + if strings.Contains(location, "://") { + parsed, err := url.Parse(location) + if err != nil { + return nil, fmt.Errorf("failed to parse IO location: %w", err) + } + scheme = parsed.Scheme + } + + // Default to local filesystem if no scheme + if scheme == "" { + scheme = "file" + } + + reg, ok := defaultRegistry.get(scheme) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrIONotFound, location) + } + + return reg.GetIO(ctx, props) +} diff --git a/io/registry_test.go b/io/registry_test.go new file mode 100644 index 000000000..0a536eda1 --- /dev/null +++ b/io/registry_test.go @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package io + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRegister(t *testing.T) { + // Save original registry state + original := make(registry) + for k, v := range defaultRegistry { + original[k] = v + } + defer func() { + // Restore original registry + defaultRegistry = original + }() + + // Clear registry for clean test + defaultRegistry = make(registry) + + testRegistrar := RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + return LocalFS{}, nil + }) + + Register("test", testRegistrar) + + registeredIOs := GetRegisteredIOs() + assert.Contains(t, registeredIOs, "test") +} + +func TestUnregister(t *testing.T) { + // Save original registry state + original := make(registry) + for k, v := range defaultRegistry { + original[k] = v + } + defer func() { + // Restore original registry + defaultRegistry = original + }() + + testRegistrar := RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + return LocalFS{}, nil + }) + + Register("test", testRegistrar) + assert.Contains(t, GetRegisteredIOs(), "test") + + Unregister("test") + assert.NotContains(t, GetRegisteredIOs(), "test") +} + +func TestDefaultRegisteredIOs(t *testing.T) { + registeredIOs := GetRegisteredIOs() + + // Check that default IO implementations are registered + expectedIOs := []string{"file", "", "s3", "s3a", "s3n", "gs", "abfs", "abfss", "wasb", "wasbs", "mem"} + for _, expected := range expectedIOs { + assert.Contains(t, registeredIOs, expected, "IO type %s should be registered", expected) + } +} + +func TestLoadWithRegisteredIO(t *testing.T) { + ctx := context.Background() + + // Test loading local filesystem + io, err := Load(ctx, map[string]string{}, "file:///tmp/test") + require.NoError(t, err) + assert.IsType(t, LocalFS{}, io) + + // Test loading with empty scheme (defaults to local) + io, err = Load(ctx, map[string]string{}, "/tmp/test") + require.NoError(t, err) + assert.IsType(t, LocalFS{}, io) + + // Test loading memory filesystem + io, err = Load(ctx, map[string]string{}, "mem://bucket/path") + require.NoError(t, err) + // Should return a blobFileIO (which implements IO) + assert.NotNil(t, io) +} + +func TestLoadWithWarehouseFromProps(t *testing.T) { + ctx := context.Background() + + // Test loading from warehouse property + io, err := Load(ctx, map[string]string{"warehouse": "file:///tmp/warehouse"}, "") + require.NoError(t, err) + assert.IsType(t, LocalFS{}, io) +} + +func TestLoadWhenUnknownScheme(t *testing.T) { + ctx := context.Background() + + _, err := Load(ctx, map[string]string{}, "unknown://bucket/path") + assert.Error(t, err) + assert.ErrorIs(t, err, ErrIONotFound) +} + +func TestRegisterPanic(t *testing.T) { + assert.Panics(t, func() { + Register("test", nil) + }) +} + +func TestConcurrentAccess(t *testing.T) { + // Test that concurrent access to registry doesn't cause data races + ctx := context.Background() + + done := make(chan bool, 100) + + // Start multiple goroutines doing registry operations + for i := 0; i < 10; i++ { + go func(i int) { + defer func() { done <- true }() + + // Get registered IOs + GetRegisteredIOs() + + // Try to load IO + _, err := Load(ctx, map[string]string{}, "file:///tmp") + require.NoError(t, err) + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/io/s3.go b/io/s3.go index 98ad4c634..a8464197a 100644 --- a/io/s3.go +++ b/io/s3.go @@ -149,3 +149,29 @@ func createS3Bucket(ctx context.Context, parsed *url.URL, props map[string]strin return bucket, nil } + +func init() { + s3Registrar := RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + // We need a warehouse location to extract the bucket name + location := props["warehouse"] + if location == "" { + return nil, fmt.Errorf("warehouse location required for S3 IO") + } + + parsed, err := url.Parse(location) + if err != nil { + return nil, fmt.Errorf("failed to parse S3 location: %w", err) + } + + bucket, err := createS3Bucket(ctx, parsed, props) + if err != nil { + return nil, err + } + + return createBlobFS(ctx, bucket, parsed.Host), nil + }) + + Register("s3", s3Registrar) + Register("s3a", s3Registrar) + Register("s3n", s3Registrar) +} From f538fab9a9937c44f2c1c24786278de05399f7e5 Mon Sep 17 00:00:00 2001 From: Vassilis Bekiaris Date: Fri, 6 Jun 2025 15:06:12 +0100 Subject: [PATCH 2/2] add test for overriding default IO mappings --- io/registry_test.go | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/io/registry_test.go b/io/registry_test.go index 0a536eda1..05f92ffea 100644 --- a/io/registry_test.go +++ b/io/registry_test.go @@ -82,7 +82,7 @@ func TestDefaultRegisteredIOs(t *testing.T) { } func TestLoadWithRegisteredIO(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Test loading local filesystem io, err := Load(ctx, map[string]string{}, "file:///tmp/test") @@ -101,8 +101,21 @@ func TestLoadWithRegisteredIO(t *testing.T) { assert.NotNil(t, io) } +func TestDefaultsCanBeOverridden(t *testing.T) { + ctx := t.Context() + registrar := RegistrarFunc(func(ctx context.Context, props map[string]string) (IO, error) { + return LocalFS{}, nil + }) + + // override default registration for mem scheme + Register("mem", registrar) + io, err := Load(ctx, map[string]string{}, "mem://bucket/path") + require.NoError(t, err) + assert.IsType(t, LocalFS{}, io) +} + func TestLoadWithWarehouseFromProps(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Test loading from warehouse property io, err := Load(ctx, map[string]string{"warehouse": "file:///tmp/warehouse"}, "") @@ -111,7 +124,7 @@ func TestLoadWithWarehouseFromProps(t *testing.T) { } func TestLoadWhenUnknownScheme(t *testing.T) { - ctx := context.Background() + ctx := t.Context() _, err := Load(ctx, map[string]string{}, "unknown://bucket/path") assert.Error(t, err) @@ -126,7 +139,7 @@ func TestRegisterPanic(t *testing.T) { func TestConcurrentAccess(t *testing.T) { // Test that concurrent access to registry doesn't cause data races - ctx := context.Background() + ctx := t.Context() done := make(chan bool, 100)