diff --git a/ocf/codec.go b/ocf/codec.go index 62bc64c..8029af3 100644 --- a/ocf/codec.go +++ b/ocf/codec.go @@ -32,6 +32,10 @@ type codecOptions struct { type zstdOptions struct { EOptions []zstd.EOption DOptions []zstd.DOption + // Encoder and Decoder allow sharing pre-created instances across multiple codecs. + // When set, EOptions/DOptions are ignored for that component. + Encoder *zstd.Encoder + Decoder *zstd.Decoder } func resolveCodec(name CodecName, codecOpts codecOptions) (Codec, error) { @@ -138,16 +142,36 @@ func (*SnappyCodec) Encode(b []byte) []byte { // ZStandardCodec is a zstandard compression codec. type ZStandardCodec struct { - decoder *zstd.Decoder - encoder *zstd.Encoder + decoder *zstd.Decoder + encoder *zstd.Encoder + sharedDecoder bool // true if decoder was provided externally and should not be closed + sharedEncoder bool // true if encoder was provided externally and should not be closed } func newZStandardCodec(opts zstdOptions) *ZStandardCodec { - decoder, _ := zstd.NewReader(nil, opts.DOptions...) - encoder, _ := zstd.NewWriter(nil, opts.EOptions...) + var decoder *zstd.Decoder + var encoder *zstd.Encoder + var sharedDecoder, sharedEncoder bool + + if opts.Decoder != nil { + decoder = opts.Decoder + sharedDecoder = true + } else { + decoder, _ = zstd.NewReader(nil, opts.DOptions...) + } + + if opts.Encoder != nil { + encoder = opts.Encoder + sharedEncoder = true + } else { + encoder, _ = zstd.NewWriter(nil, opts.EOptions...) + } + return &ZStandardCodec{ - decoder: decoder, - encoder: encoder, + decoder: decoder, + encoder: encoder, + sharedDecoder: sharedDecoder, + sharedEncoder: sharedEncoder, } } @@ -164,11 +188,12 @@ func (zstdCodec *ZStandardCodec) Encode(b []byte) []byte { } // Close closes the zstandard encoder and decoder, releasing resources. +// Shared instances (provided via WithZStandardEncoder/WithZStandardDecoder) are not closed. func (zstdCodec *ZStandardCodec) Close() error { - if zstdCodec.decoder != nil { + if zstdCodec.decoder != nil && !zstdCodec.sharedDecoder { zstdCodec.decoder.Close() } - if zstdCodec.encoder != nil { + if zstdCodec.encoder != nil && !zstdCodec.sharedEncoder { return zstdCodec.encoder.Close() } return nil diff --git a/ocf/ocf.go b/ocf/ocf.go index bc6d62a..f23034e 100644 --- a/ocf/ocf.go +++ b/ocf/ocf.go @@ -84,6 +84,15 @@ func WithZStandardDecoderOptions(opts ...zstd.DOption) DecoderFunc { } } +// WithZStandardDecoder sets a pre-created ZStandard decoder to be reused. +// This allows sharing a single decoder across multiple OCF decoders for efficiency. +// The caller is responsible for closing the decoder after all OCF decoders are done. +func WithZStandardDecoder(dec *zstd.Decoder) DecoderFunc { + return func(cfg *decoderConfig) { + cfg.CodecOptions.ZStandardOptions.Decoder = dec + } +} + // Decoder reads and decodes Avro values from a container file. type Decoder struct { reader *avro.Reader @@ -276,6 +285,15 @@ func WithZStandardEncoderOptions(opts ...zstd.EOption) EncoderFunc { } } +// WithZStandardEncoder sets a pre-created ZStandard encoder to be reused. +// This allows sharing a single encoder across multiple OCF encoders for efficiency. +// The caller is responsible for closing the encoder after all OCF encoders are done. +func WithZStandardEncoder(enc *zstd.Encoder) EncoderFunc { + return func(cfg *encoderConfig) { + cfg.CodecOptions.ZStandardOptions.Encoder = enc + } +} + // WithMetadata sets the metadata on the encoder header. func WithMetadata(meta map[string][]byte) EncoderFunc { return func(cfg *encoderConfig) { diff --git a/ocf/ocf_test.go b/ocf/ocf_test.go index 25566f9..090b7ec 100644 --- a/ocf/ocf_test.go +++ b/ocf/ocf_test.go @@ -1313,6 +1313,80 @@ func (*errorHeaderWriter) Write(p []byte) (int, error) { return 0, errors.New("test") } +func TestSharedZstdEncoder(t *testing.T) { + schema := `{"type": "string"}` + + // Create a shared zstd encoder + sharedEncoder, err := zstd.NewWriter(nil) + require.NoError(t, err) + defer sharedEncoder.Close() + + // Use the shared encoder with multiple OCF encoders + var buf1, buf2 bytes.Buffer + + enc1, err := ocf.NewEncoder(schema, &buf1, ocf.WithCodec(ocf.ZStandard), ocf.WithZStandardEncoder(sharedEncoder)) + require.NoError(t, err) + require.NoError(t, enc1.Encode("hello from encoder 1")) + require.NoError(t, enc1.Close()) + + enc2, err := ocf.NewEncoder(schema, &buf2, ocf.WithCodec(ocf.ZStandard), ocf.WithZStandardEncoder(sharedEncoder)) + require.NoError(t, err) + require.NoError(t, enc2.Encode("hello from encoder 2")) + require.NoError(t, enc2.Close()) + + // Verify both files can be read + dec1, err := ocf.NewDecoder(&buf1) + require.NoError(t, err) + require.True(t, dec1.HasNext()) + var result1 string + require.NoError(t, dec1.Decode(&result1)) + assert.Equal(t, "hello from encoder 1", result1) + + dec2, err := ocf.NewDecoder(&buf2) + require.NoError(t, err) + require.True(t, dec2.HasNext()) + var result2 string + require.NoError(t, dec2.Decode(&result2)) + assert.Equal(t, "hello from encoder 2", result2) +} + +func TestSharedZstdDecoder(t *testing.T) { + schema := `{"type": "string"}` + + // Create two OCF files + var buf1, buf2 bytes.Buffer + + enc1, err := ocf.NewEncoder(schema, &buf1, ocf.WithCodec(ocf.ZStandard)) + require.NoError(t, err) + require.NoError(t, enc1.Encode("data in file 1")) + require.NoError(t, enc1.Close()) + + enc2, err := ocf.NewEncoder(schema, &buf2, ocf.WithCodec(ocf.ZStandard)) + require.NoError(t, err) + require.NoError(t, enc2.Encode("data in file 2")) + require.NoError(t, enc2.Close()) + + // Create a shared zstd decoder + sharedDecoder, err := zstd.NewReader(nil) + require.NoError(t, err) + defer sharedDecoder.Close() + + // Use the shared decoder with multiple OCF decoders + dec1, err := ocf.NewDecoder(&buf1, ocf.WithZStandardDecoder(sharedDecoder)) + require.NoError(t, err) + require.True(t, dec1.HasNext()) + var result1 string + require.NoError(t, dec1.Decode(&result1)) + assert.Equal(t, "data in file 1", result1) + + dec2, err := ocf.NewDecoder(&buf2, ocf.WithZStandardDecoder(sharedDecoder)) + require.NoError(t, err) + require.True(t, dec2.HasNext()) + var result2 string + require.NoError(t, dec2.Decode(&result2)) + assert.Equal(t, "data in file 2", result2) +} + // TestEncoder_Reset tests that Reset allows reusing encoder for multiple files. func TestEncoder_Reset(t *testing.T) { record1 := FullRecord{