diff --git a/ocf/ocf.go b/ocf/ocf.go index bc6d62a..02442f7 100644 --- a/ocf/ocf.go +++ b/ocf/ocf.go @@ -59,6 +59,20 @@ type decoderConfig struct { CodecOptions codecOptions } +func newDecoderConfig(opts ...DecoderFunc) *decoderConfig { + cfg := decoderConfig{ + DecoderConfig: avro.DefaultConfig, + SchemaCache: avro.DefaultSchemaCache, + CodecOptions: codecOptions{ + DeflateCompressionLevel: flate.DefaultCompression, + }, + } + for _, opt := range opts { + opt(&cfg) + } + return &cfg +} + // DecoderFunc represents a configuration function for Decoder. type DecoderFunc func(cfg *decoderConfig) @@ -96,23 +110,16 @@ type Decoder struct { codec Codec count int64 + size int64 + n int64 } // NewDecoder returns a new decoder that reads from reader r. func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) { - cfg := decoderConfig{ - DecoderConfig: avro.DefaultConfig, - SchemaCache: avro.DefaultSchemaCache, - CodecOptions: codecOptions{ - DeflateCompressionLevel: flate.DefaultCompression, - }, - } - for _, opt := range opts { - opt(&cfg) - } - reader := avro.NewReader(r, 1024) + cfg := newDecoderConfig(opts...) + h, err := readHeader(reader, cfg.SchemaCache, cfg.CodecOptions) if err != nil { return nil, fmt.Errorf("decoder: %w", err) @@ -131,6 +138,31 @@ func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) { }, nil } +// NewDecoderWithHeader returns a new decoder that reads from reader r using the provided header. +func NewDecoderWithHeader(r *avro.Reader, h *OCFHeader, opts ...DecoderFunc) (*Decoder, error) { + cfg := newDecoderConfig(opts...) + decReader := bytesx.NewResetReader([]byte{}) + return &Decoder{ + reader: r, + resetReader: decReader, + decoder: cfg.DecoderConfig.NewDecoder(h.Schema, decReader), + meta: h.Meta, + sync: h.Sync, + codec: h.Codec, + schema: h.Schema, + }, nil +} + +// DecodeHeader reads and decodes the Avro container file header from r. +func DecodeHeader(r *avro.Reader, opts ...DecoderFunc) (*OCFHeader, error) { + cfg := newDecoderConfig(opts...) + h, err := readHeader(r, cfg.SchemaCache, cfg.CodecOptions) + if err != nil { + return nil, fmt.Errorf("decoder: %w", err) + } + return h, nil +} + // Metadata returns the header metadata. func (d *Decoder) Metadata() map[string][]byte { return d.meta @@ -145,8 +177,8 @@ func (d *Decoder) Schema() avro.Schema { // HasNext determines if there is another value to read. func (d *Decoder) HasNext() bool { if d.count <= 0 { - count := d.readBlock() - d.count = count + d.count, d.size = d.readBlock() + d.n = d.count } if d.reader.Error != nil { @@ -184,11 +216,29 @@ func (d *Decoder) Close() error { return nil } -func (d *Decoder) readBlock() int64 { +// BlockStatus represents the status of the current block. +type BlockStatus struct { + Current int64 + Count int64 + Size int64 + Offset int64 +} + +// BlockStatus returns the current block status. +func (d *Decoder) BlockStatus() *BlockStatus { + return &BlockStatus{ + Current: d.n - d.count + 1, + Count: d.n, + Size: d.size, + Offset: d.reader.InputOffset(), + } +} + +func (d *Decoder) readBlock() (int64, int64) { _ = d.reader.Peek() if errors.Is(d.reader.Error, io.EOF) { // There is no next block - return 0 + return 0, 0 } count := d.reader.ReadLong() @@ -220,7 +270,7 @@ func (d *Decoder) readBlock() int64 { d.reader.Error = errors.New("decoder: invalid block") } - return count + return count, size } type encoderConfig struct { @@ -379,7 +429,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e return nil, err } - writer := avro.NewWriter(w, 512, avro.WithWriterConfig(cfg.EncodingConfig)) + writer := avro.NewWriter(w, 512, avro.WithWriterConfig(avro.DefaultConfig)) buf := &bytes.Buffer{} e := &Encoder{ writer: writer, @@ -420,7 +470,7 @@ func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, e return nil, err } - writer := avro.NewWriter(w, 512, avro.WithWriterConfig(cfg.EncodingConfig)) + writer := avro.NewWriter(w, 512, avro.WithWriterConfig(avro.DefaultConfig)) writer.WriteVal(HeaderSchema, header) if err = writer.Flush(); err != nil { return nil, err @@ -567,14 +617,15 @@ func (e *Encoder) writerBlock() error { return e.writer.Flush() } -type ocfHeader struct { +// OCFHeader represents the parsed header of an OCF file. +type OCFHeader struct { //nolint:revive Schema avro.Schema Codec Codec Meta map[string][]byte Sync [16]byte } -func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache, codecOpts codecOptions) (*ocfHeader, error) { +func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache, codecOpts codecOptions) (*OCFHeader, error) { var h Header reader.ReadVal(HeaderSchema, &h) if reader.Error != nil { @@ -594,7 +645,7 @@ func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache, codecOpts co return nil, err } - return &ocfHeader{ + return &OCFHeader{ Schema: schema, Codec: codec, Meta: h.Meta, diff --git a/ocf/ocf_test.go b/ocf/ocf_test.go index 25566f9..5d9c98c 100644 --- a/ocf/ocf_test.go +++ b/ocf/ocf_test.go @@ -6,9 +6,12 @@ import ( "encoding/json" "errors" "flag" + "fmt" "io" "os" + "sort" "strings" + "sync" "testing" "github.com/hamba/avro/v2" @@ -1274,6 +1277,7 @@ func TestWithSchemaMarshaler(t *testing.T) { want, err := os.ReadFile("testdata/full-schema.json") require.NoError(t, err) + want = bytes.ReplaceAll(want, []byte("\r\n"), []byte("\n")) assert.Equal(t, want, got) } @@ -1313,6 +1317,134 @@ func (*errorHeaderWriter) Write(p []byte) (int, error) { return 0, errors.New("test") } +func TestConcurrentDecode(t *testing.T) { + // build an in-memory OCF with many records + unionStr := "union value" + base := FullRecord{ + Strings: []string{"s1", "s2"}, + Longs: []int64{1, 2}, + Enum: "A", + Map: map[string]int{"k": 1}, + Nullable: &unionStr, + Fixed: [16]byte{0x01, 0x02, 0x03}, + Record: &TestRecord{Long: 1, String: "r", Int: 0, Float: 1.1, Double: 2.2, Bool: true}, + } + + const total = 200 + buf := &bytes.Buffer{} + enc, err := ocf.NewEncoder(schema, buf, ocf.WithBlockLength(10)) + require.NoError(t, err) + for i := 0; i < total; i++ { + base.Record.Int = int32(i) + require.NoError(t, enc.Encode(base)) + } + require.NoError(t, enc.Close()) + + // decode header once + data := buf.Bytes() + r0 := avro.NewReader(bytes.NewReader(data), 1024) + hdr, err := ocf.DecodeHeader(r0) + require.NoError(t, err) + + // concurrency values to test; caller requirement: configurable concurrency + concs := []int64{1, 2, 3, 5} + + // split file into parts by size and let workers align to sync using SkipTo + headerEnd := r0.InputOffset() + for _, conc := range concs { + t.Run(fmt.Sprintf("concurrency=%d", conc), func(t *testing.T) { + totalData := int64(len(data)) - headerEnd + partSize := totalData / int64(conc) + + recCh := make(chan FullRecord, total) + var wg sync.WaitGroup + + errCh := make(chan error, 1) + sendErr := func(err error) { + select { + case errCh <- err: + default: + } + } + + for i := int64(0); i < conc; i++ { + start := headerEnd + i*partSize + end := headerEnd + (i+1)*partSize + if i == conc-1 { + end = int64(len(data)) + } + + wg.Add(1) + go func(start, end int64) { + defer wg.Done() + r := avro.NewReader(bytes.NewReader(data[start:]), 1024) + skipped := int64(0) + // align to next sync marker unless starting at header end + if start != headerEnd { + n, err := r.SkipTo(hdr.Sync[:]) + if err != nil && !errors.Is(err, io.EOF) { + sendErr(err) + return + } + // if SkipTo advanced past our partition end, nothing to do + skipped = int64(n) + if start+skipped >= end { + return + } + } + + dec, err := ocf.NewDecoderWithHeader(r, hdr) + if err != nil { + sendErr(err) + return + } + defer dec.Close() + + for dec.HasNext() { + var rec FullRecord + if err := dec.Decode(&rec); err != nil { + sendErr(err) + return + } + recCh <- rec + bs := dec.BlockStatus() + if bs.Current > bs.Count { + if start+bs.Offset > end { + return + } + } + } + if err := dec.Error(); err != nil { + sendErr(err) + return + } + }(start, end) + } + + go func() { wg.Wait(); close(recCh) }() + + var got []int32 + for r := range recCh { + got = append(got, r.Record.Int) + } + + select { + case e := <-errCh: + if e != nil { + t.Fatalf("worker error: %v", e) + } + default: + } + + require.Equal(t, total, len(got), "unexpected number of records read") + sort.Slice(got, func(i, j int) bool { return got[i] < got[j] }) + for i := 0; i < total; i++ { + require.Equal(t, int32(i), got[i]) + } + }) + } +} + // TestEncoder_Reset tests that Reset allows reusing encoder for multiple files. func TestEncoder_Reset(t *testing.T) { record1 := FullRecord{ @@ -1575,3 +1707,58 @@ func TestEncoder_ResetPreservesCodec(t *testing.T) { require.NoError(t, err) assert.Equal(t, []byte("deflate"), dec2.Metadata()["avro.codec"]) } + +type CustomTagTestObject struct { + StringField string `json:"string_field"` + IntField int `json:"int_field"` +} + +func TestCustomTagKey(t *testing.T) { + // Define schema matching the json tags + schemaStr := `{ + "type": "record", + "name": "CustomTagTestObject", + "fields": [ + {"name": "string_field", "type": "string"}, + {"name": "int_field", "type": "int"} + ] + }` + + // Create a Config with TagKey set to "json" + config := avro.Config{ + TagKey: "json", + }.Freeze() + + // Create a buffer to write the OCF file to + var buf bytes.Buffer + + // Create OCF encoder with custom encoding config + enc, err := ocf.NewEncoder(schemaStr, &buf, ocf.WithEncodingConfig(config)) + require.NoError(t, err) + + // Data to encode + data := CustomTagTestObject{ + StringField: "hello", + IntField: 42, + } + + // Encode using the OCF encoder + err = enc.Encode(data) + require.NoError(t, err) + + // Close the encoder to flush data + err = enc.Close() + require.NoError(t, err) + + // Verify the output by decoding + dec, err := ocf.NewDecoder(&buf, ocf.WithDecoderConfig(config)) + require.NoError(t, err) + + var result CustomTagTestObject + require.True(t, dec.HasNext()) + err = dec.Decode(&result) + require.NoError(t, err) + + assert.Equal(t, data.StringField, result.StringField) + assert.Equal(t, data.IntField, result.IntField) +} diff --git a/reader.go b/reader.go index 3c11b18..a256bd7 100644 --- a/reader.go +++ b/reader.go @@ -31,6 +31,7 @@ type Reader struct { buf []byte head int tail int + offset int64 Error error } @@ -42,6 +43,7 @@ func NewReader(r io.Reader, bufSize int, opts ...ReaderFunc) *Reader { buf: make([]byte, bufSize), head: 0, tail: 0, + offset: 0, } for _, opt := range opts { @@ -57,6 +59,8 @@ func (r *Reader) Reset(b []byte) *Reader { r.buf = b r.head = 0 r.tail = len(b) + r.offset = 0 + r.Error = nil return r } @@ -90,6 +94,7 @@ func (r *Reader) loadMore() bool { continue } + r.offset += int64(r.tail) r.head = 0 r.tail = n return true @@ -322,3 +327,8 @@ func (r *Reader) ReadBlockHeader() (int64, int64) { return length, 0 } + +// InputOffset returns the current input offset of the Reader. +func (r *Reader) InputOffset() int64 { + return r.offset + int64(r.head) +} diff --git a/reader_skip.go b/reader_skip.go index 94288c8..4cfdc90 100644 --- a/reader_skip.go +++ b/reader_skip.go @@ -1,5 +1,10 @@ package avro +import ( + "bytes" + "fmt" +) + // SkipNBytes skips the given number of bytes in the reader. func (r *Reader) SkipNBytes(n int) { read := 0 @@ -77,3 +82,70 @@ func (r *Reader) SkipBytes() { } r.SkipNBytes(int(size)) } + +// SkipTo skips to the given token in the reader. +func (r *Reader) SkipTo(token []byte) (int, error) { + tokenLen := len(token) + if tokenLen == 0 { + return 0, nil + } + if tokenLen > len(r.buf) { + return 0, fmt.Errorf("token length %d exceeds buffer size %d", tokenLen, len(r.buf)) + } + + var skipped int + var stash []byte + + for { + // Check boundary if we have stash from previous read + if len(stash) > 0 { + need := min(r.tail-r.head, tokenLen-1) + + // Construct boundary window: stash + beginning of new buffer + boundary := make([]byte, len(stash)+need) + copy(boundary, stash) + copy(boundary[len(stash):], r.buf[r.head:r.head+need]) + + if idx := bytes.Index(boundary, token); idx >= 0 { + // Found in boundary + bytesToEndOfToken := idx + tokenLen + skipped += bytesToEndOfToken + + // Advance r.head by the number of bytes used from r.buf + bufferBytesConsumed := bytesToEndOfToken - len(stash) + r.head += bufferBytesConsumed + return skipped, nil + } + + // Not found in boundary, stash is definitely skipped + skipped += len(stash) + stash = nil + } + + // Search in current buffer + idx := bytes.Index(r.buf[r.head:r.tail], token) + if idx >= 0 { + advance := idx + tokenLen + r.head += advance + skipped += advance + return skipped, nil + } + + // Prepare stash for next iteration + available := r.tail - r.head + keep := min(tokenLen-1, available) + + // Bytes that are definitely skipped (not kept in stash) + consumed := available - keep + skipped += consumed + + if keep > 0 { + stash = make([]byte, keep) + copy(stash, r.buf[r.tail-keep:r.tail]) + } + + if !r.loadMore() { + return skipped, r.Error + } + } +} diff --git a/reader_skip_test.go b/reader_skip_test.go index a821880..7103ddc 100644 --- a/reader_skip_test.go +++ b/reader_skip_test.go @@ -48,6 +48,7 @@ func TestReader_SkipInt(t *testing.T) { data: []byte{0x38, 0x36}, }, { + name: "Overflow", data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0xAD, 0x36}, }, @@ -76,6 +77,7 @@ func TestReader_SkipLong(t *testing.T) { data: []byte{0x38, 0x36}, }, { + name: "Overflow", data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0xAD, 0xAD, 0xE2, 0xA2, 0xF3, 0xAD, 0x36}, }, @@ -153,3 +155,78 @@ func TestReader_SkipBytesEmpty(t *testing.T) { require.NoError(t, r.Error) assert.Equal(t, int32(27), r.ReadInt()) } + +func TestReader_SkipTo(t *testing.T) { + tests := []struct { + name string + data []byte + bufSize int + token []byte + wantSkipped int + wantErr require.ErrorAssertionFunc + }{ + // TokenInFirstBuffer + { + name: "TokenInFirstBuffer", + data: []byte("abcdefgTOKENhij"), + bufSize: 1024, + token: []byte("TOKEN"), + wantSkipped: 12, // "abcdefgTOKEN" length + wantErr: require.NoError, + }, + // TokenSplitAcrossBuffers + { + name: "TokenSplitAcrossBuffers", + data: append(append(make([]byte, 10), []byte("TO")...), []byte("KEN")...), + bufSize: 12, // 10 filler + "TO" = 12 bytes. Split happens exactly after "TO". + token: []byte("TOKEN"), + wantSkipped: 15, // 10 filler + TOKEN + wantErr: require.NoError, + }, + // FalsePositiveSplit: XXKEN should NOT match TOKEN + { + name: "FalsePositiveSplit", + data: []byte("XXKEN"), + bufSize: 2, // Split "XX", "KEN" + token: []byte("TOKEN"), + wantSkipped: 0, + wantErr: require.Error, // Should fail to find TOKEN + }, + // TokenNotFound + { + name: "TokenNotFound", + data: []byte("abcdefg"), + bufSize: 1024, + token: []byte("XYZ"), + wantSkipped: 7, + wantErr: require.Error, // EOF causes error in SkipTo + }, + { + name: "EmptyToken", + data: []byte("abc"), + bufSize: 1024, + token: []byte{}, + wantSkipped: 0, + wantErr: require.NoError, + }, + { + name: "PartialMatchAtEndButNotComplete", + data: []byte("abcTO"), + bufSize: 10, + token: []byte("TOKEN"), + wantSkipped: 5, + wantErr: require.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := avro.NewReader(bytes.NewReader(tt.data), tt.bufSize) + skipped, err := r.SkipTo(tt.token) + tt.wantErr(t, err) + if err == nil { + assert.Equal(t, tt.wantSkipped, skipped) + } + }) + } +}