diff --git a/arrow/avro/loader.go b/arrow/avro/loader.go index a7199e661..fa97c426b 100644 --- a/arrow/avro/loader.go +++ b/arrow/avro/loader.go @@ -24,7 +24,7 @@ import ( func (r *OCFReader) decodeOCFToChan() { defer close(r.avroChan) - for r.r.HasNext() { + for { select { case <-r.readerCtx.Done(): r.err = fmt.Errorf("avro decoding cancelled, %d records read", r.avroDatumCount) @@ -34,7 +34,6 @@ func (r *OCFReader) decodeOCFToChan() { err := r.r.Decode(&datum) if err != nil { if errors.Is(err, io.EOF) { - r.err = nil return } r.err = err diff --git a/arrow/avro/reader.go b/arrow/avro/reader.go index db6de6275..a731a0621 100644 --- a/arrow/avro/reader.go +++ b/arrow/avro/reader.go @@ -27,10 +27,9 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/hamba/avro/v2/ocf" "github.com/tidwall/sjson" - - avro "github.com/hamba/avro/v2" + "github.com/twmb/avro" + "github.com/twmb/avro/ocf" ) var ErrMismatchFields = errors.New("arrow/avro: number of records mismatch") @@ -47,9 +46,9 @@ type schemaEdit struct { value any } -// Reader wraps goavro/OCFReader and creates array.RecordBatches from a schema. +// OCFReader reads Avro OCF files and exposes them as array.RecordBatches. type OCFReader struct { - r *ocf.Decoder + r *ocf.Reader avroSchema string avroSchemaEdits []schemaEdit schema *arrow.Schema @@ -82,7 +81,7 @@ type OCFReader struct { // NewReader returns a reader that reads from an Avro OCF file and creates // arrow.RecordBatches from the converted avro data. 
func NewOCFReader(r io.Reader, opts ...Option) (*OCFReader, error) { - ocfr, err := ocf.NewDecoder(r) + ocfr, err := ocf.NewReader(r) if err != nil { return nil, fmt.Errorf("%w: could not create avro ocfreader", arrow.ErrInvalid) } @@ -108,22 +107,20 @@ func NewOCFReader(r io.Reader, opts ...Option) (*OCFReader, error) { } rr.avroSchema = schema.String() if len(rr.avroSchemaEdits) > 0 { - // execute schema edits for _, e := range rr.avroSchemaEdits { err := rr.editAvroSchema(e) if err != nil { return nil, fmt.Errorf("%w: could not edit avro schema", arrow.ErrInvalid) } } - // validate edited schema - schema, err = avro.Parse(rr.avroSchema) - if err != nil { - return nil, fmt.Errorf("%w: could not parse modified avro schema", arrow.ErrInvalid) - } } - rr.schema, err = ArrowSchemaFromAvro(schema) + rr.schema, err = ArrowSchemaFromAvroJSON(rr.avroSchema) if err != nil { - return nil, fmt.Errorf("%w: could not convert avro schema", arrow.ErrInvalid) + msg := "could not convert avro schema" + if len(rr.avroSchemaEdits) > 0 { + msg = "could not parse modified avro schema" + } + return nil, fmt.Errorf("%w: %s: %w", arrow.ErrInvalid, msg, err) } if rr.mem == nil { rr.mem = memory.DefaultAllocator @@ -147,7 +144,7 @@ func NewOCFReader(r io.Reader, opts ...Option) (*OCFReader, error) { func (rr *OCFReader) Reuse(r io.Reader, opts ...Option) error { rr.Close() rr.err = nil - ocfr, err := ocf.NewDecoder(r) + ocfr, err := ocf.NewReader(r) if err != nil { return fmt.Errorf("%w: could not create avro ocfreader", arrow.ErrInvalid) } diff --git a/arrow/avro/reader_test.go b/arrow/avro/reader_test.go index 4aaac675e..d920e22a8 100644 --- a/arrow/avro/reader_test.go +++ b/arrow/avro/reader_test.go @@ -19,14 +19,13 @@ package avro import ( "bytes" "encoding/json" - "fmt" "os" "path/filepath" "testing" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/avro/testdata" - hamba "github.com/hamba/avro/v2" + "github.com/apache/arrow-go/v18/arrow/extensions" 
"github.com/stretchr/testify/assert" ) @@ -127,6 +126,10 @@ func TestReader(t *testing.T) { Name: "uuidField", Type: arrow.BinaryTypes.String, }, + { + Name: "fixedUuidField", + Type: extensions.NewUUIDType(), + }, { Name: "timemillis", Type: arrow.FixedWidthTypes.Time32ms, @@ -167,20 +170,13 @@ func TestReader(t *testing.T) { t.Fatal(err) } r := new(OCFReader) - r.avroSchema = schema.String() + r.avroSchema = schema r.editAvroSchema(schemaEdit{method: "delete", path: "fields.0"}) - schema, err = hamba.Parse(r.avroSchema) + got, err := ArrowSchemaFromAvroJSON(r.avroSchema) if err != nil { t.Fatalf("%v: could not parse modified avro schema", arrow.ErrInvalid) } - got, err := ArrowSchemaFromAvro(schema) - if err != nil { - t.Fatalf("%v", err) - } assert.Equal(t, want.String(), got.String()) - if fmt.Sprintf("%+v", want.String()) != fmt.Sprintf("%+v", got.String()) { - t.Fatalf("got=%v,\n want=%v", got.String(), want.String()) - } }) t.Run("ShouldLoadExpectedRecords", func(t *testing.T) { @@ -200,7 +196,7 @@ func TestReader(t *testing.T) { exists := ar.Next() if ar.Err() != nil { - t.Error("failed to read next record: %w", ar.Err()) + t.Errorf("failed to read next record: %v", ar.Err()) } if !exists { t.Error("no record exists") diff --git a/arrow/avro/reader_types.go b/arrow/avro/reader_types.go index ff21b5aa0..da13b03d6 100644 --- a/arrow/avro/reader_types.go +++ b/arrow/avro/reader_types.go @@ -17,8 +17,6 @@ package avro import ( - "bytes" - "encoding/binary" "errors" "fmt" "math/big" @@ -31,7 +29,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/arrow-go/v18/arrow/memory" - hamba "github.com/hamba/avro/v2" + avro "github.com/twmb/avro" ) type dataLoader struct { @@ -397,14 +395,12 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { switch bt := b.(type) { case *array.BinaryBuilder: f.appendFunc = func(data interface{}) error { - appendBinaryData(bt, 
data) - return nil + return appendBinaryData(bt, data) } case *array.BinaryDictionaryBuilder: // has metadata for Avro enum symbols f.appendFunc = func(data interface{}) error { - appendBinaryDictData(bt, data) - return nil + return appendBinaryDictData(bt, data) } // add Avro enum symbols to builder sb := array.NewStringBuilder(memory.DefaultAllocator) @@ -415,13 +411,11 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { bt.InsertStringDictValues(sa) case *array.BooleanBuilder: f.appendFunc = func(data interface{}) error { - appendBoolData(bt, data) - return nil + return appendBoolData(bt, data) } case *array.Date32Builder: f.appendFunc = func(data interface{}) error { - appendDate32Data(bt, data) - return nil + return appendDate32Data(bt, data) } case *array.Decimal128Builder: f.appendFunc = func(data interface{}) error { @@ -429,11 +423,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { if !ok { return nil } - err := appendDecimal128Data(bt, data, typ) - if err != nil { - return err - } - return nil + return appendDecimal128Data(bt, data, typ) } case *array.Decimal256Builder: f.appendFunc = func(data interface{}) error { @@ -441,54 +431,31 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { if !ok { return nil } - err := appendDecimal256Data(bt, data, typ) - if err != nil { - return err - } - return nil + return appendDecimal256Data(bt, data, typ) } case *extensions.UUIDBuilder: f.appendFunc = func(data interface{}) error { - switch dt := data.(type) { - case nil: - bt.AppendNull() - case string: - err := bt.AppendValueFromString(dt) - if err != nil { - return err - } - case []byte: - err := bt.AppendValueFromString(string(dt)) - if err != nil { - return err - } - } - return nil + return appendUUIDData(bt, data, field.Name) } case *array.FixedSizeBinaryBuilder: f.appendFunc = func(data interface{}) error { - appendFixedSizeBinaryData(bt, data) - return nil + return 
appendFixedSizeBinaryData(bt, data) } case *array.Float32Builder: f.appendFunc = func(data interface{}) error { - appendFloat32Data(bt, data) - return nil + return appendFloat32Data(bt, data) } case *array.Float64Builder: f.appendFunc = func(data interface{}) error { - appendFloat64Data(bt, data) - return nil + return appendFloat64Data(bt, data) } case *array.Int32Builder: f.appendFunc = func(data interface{}) error { - appendInt32Data(bt, data) - return nil + return appendInt32Data(bt, data) } case *array.Int64Builder: f.appendFunc = func(data interface{}) error { - appendInt64Data(bt, data) - return nil + return appendInt64Data(bt, data) } case *array.LargeListBuilder: vb := bt.ValueBuilder() @@ -546,13 +513,11 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { } case *array.MonthDayNanoIntervalBuilder: f.appendFunc = func(data interface{}) error { - appendDurationData(bt, data) - return nil + return appendDurationData(bt, data) } case *array.StringBuilder: f.appendFunc = func(data interface{}) error { - appendStringData(bt, data) - return nil + return appendStringData(bt, data) } case *array.StructBuilder: // has metadata for Avro Union named types @@ -574,126 +539,108 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { } case *array.Time32Builder: f.appendFunc = func(data interface{}) error { - appendTime32Data(bt, data) - return nil + return appendTime32Data(bt, data) } case *array.Time64Builder: f.appendFunc = func(data interface{}) error { - appendTime64Data(bt, data) - return nil + return appendTime64Data(bt, data) } case *array.TimestampBuilder: f.appendFunc = func(data interface{}) error { - appendTimestampData(bt, data) - return nil + return appendTimestampData(bt, data) } } } -func appendBinaryData(b *array.BinaryBuilder, data interface{}) { +// appendUUIDData accepts the two shapes a UUID may arrive as: a [16]byte +// (fixed(16)+uuid) or a hex-dash string (string+uuid). 
Other byte lengths +// are rejected rather than re-interpreted. +func appendUUIDData(b *extensions.UUIDBuilder, data any, fieldName string) error { switch dt := data.(type) { case nil: b.AppendNull() - case map[string]any: - switch ct := dt["bytes"].(type) { - case nil: - b.AppendNull() + case string: + return b.AppendValueFromString(dt) + case [16]byte: + b.AppendBytes(dt) + case []byte: + switch len(dt) { + case 16: + b.AppendBytes([16]byte(dt)) + case 36: + return b.AppendValueFromString(string(dt)) default: - b.Append(ct.([]byte)) + return fmt.Errorf("avro: %d-byte value cannot be a UUID for column %q", len(dt), fieldName) } default: - b.Append(fmt.Append([]byte{}, data)) + return fmt.Errorf("avro: unsupported value of type %T for UUID column %q", data, fieldName) } + return nil } -func appendBinaryDictData(b *array.BinaryDictionaryBuilder, data interface{}) { +func appendBinaryData(b *array.BinaryBuilder, data interface{}) error { + switch dt := data.(type) { + case nil: + b.AppendNull() + case []byte: + b.Append(dt) + default: + return fmt.Errorf("avro: unsupported value of type %T for Binary column", data) + } + return nil +} + +func appendBinaryDictData(b *array.BinaryDictionaryBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case string: - b.AppendString(dt) - case map[string]any: - switch v := dt["string"].(type) { - case nil: - b.AppendNull() - case string: - b.AppendString(v) + if err := b.AppendString(dt); err != nil { + return fmt.Errorf("avro: enum symbol %q is not in the dictionary (schema/data mismatch?): %w", dt, err) } + default: + return fmt.Errorf("avro: unsupported value of type %T for Dictionary column", data) } + return nil } -func appendBoolData(b *array.BooleanBuilder, data interface{}) { +func appendBoolData(b *array.BooleanBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case bool: b.Append(dt) - case map[string]any: - switch v := dt["boolean"].(type) { - case 
nil: - b.AppendNull() - case bool: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Boolean column", data) } + return nil } -func appendDate32Data(b *array.Date32Builder, data interface{}) { +func appendDate32Data(b *array.Date32Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int32: - b.Append(arrow.Date32(dt)) - case map[string]any: - switch v := dt["int"].(type) { - case nil: - b.AppendNull() - case int32: - b.Append(arrow.Date32(v)) - } case time.Time: b.Append(arrow.Date32FromTime(dt)) + default: + return fmt.Errorf("avro: unsupported value of type %T for Date32 column", data) } + return nil } func appendDecimal128Data(b *array.Decimal128Builder, data interface{}, typ arrow.DecimalType) error { switch dt := data.(type) { case nil: b.AppendNull() - case []byte: - buf := bytes.NewBuffer(dt) - if len(dt) <= 38 { - var intData int64 - err := binary.Read(buf, binary.BigEndian, &intData) - if err != nil { - return err - } - b.Append(decimal128.FromI64(intData)) - } else { - var bigIntData big.Int - b.Append(decimal128.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) - } - case map[string]any: - buf := bytes.NewBuffer(dt["bytes"].([]byte)) - if len(dt["bytes"].([]byte)) <= 38 { - var intData int64 - err := binary.Read(buf, binary.BigEndian, &intData) - if err != nil { - return err - } - b.Append(decimal128.FromI64(intData)) - } else { - var bigIntData big.Int - b.Append(decimal128.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) - } case *big.Rat: v := bigRatToBigInt(dt, typ) - if v.IsInt64() { b.Append(decimal128.FromI64(v.Int64())) } else { b.Append(decimal128.FromBigInt(v)) } + default: + return fmt.Errorf("avro: unsupported value of type %T for Decimal128 column", data) } return nil } @@ -702,16 +649,10 @@ func appendDecimal256Data(b *array.Decimal256Builder, data interface{}, typ arro switch dt := data.(type) { case nil: b.AppendNull() - case []byte: - var bigIntData big.Int - buf := 
bytes.NewBuffer(dt) - b.Append(decimal256.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) - case map[string]any: - var bigIntData big.Int - buf := bytes.NewBuffer(dt["bytes"].([]byte)) - b.Append(decimal256.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) case *big.Rat: b.Append(decimal256.FromBigInt(bigRatToBigInt(dt, typ))) + default: + return fmt.Errorf("avro: unsupported value of type %T for Decimal256 column", data) } return nil } @@ -728,203 +669,138 @@ func bigRatToBigInt(dt *big.Rat, typ arrow.DecimalType) *big.Int { // Avro duration logical type annotates Avro fixed type of size 12, which stores three little-endian // unsigned integers that represent durations at different granularities of time. The first stores // a number in months, the second stores a number in days, and the third stores a number in milliseconds. -func appendDurationData(b *array.MonthDayNanoIntervalBuilder, data interface{}) { +func appendDurationData(b *array.MonthDayNanoIntervalBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case []byte: - dur := new(arrow.MonthDayNanoInterval) - dur.Months = int32(binary.LittleEndian.Uint16(dt[:3])) - dur.Days = int32(binary.LittleEndian.Uint16(dt[4:7])) - dur.Nanoseconds = int64(binary.LittleEndian.Uint32(dt[8:]) * 1000000) - b.Append(*dur) - case map[string]any: - switch dtb := dt["bytes"].(type) { - case nil: - b.AppendNull() - case []byte: - dur := new(arrow.MonthDayNanoInterval) - dur.Months = int32(binary.LittleEndian.Uint16(dtb[:3])) - dur.Days = int32(binary.LittleEndian.Uint16(dtb[4:7])) - dur.Nanoseconds = int64(binary.LittleEndian.Uint32(dtb[8:]) * 1000000) - b.Append(*dur) - } - case hamba.LogicalDuration: + case avro.Duration: b.Append(arrow.MonthDayNanoInterval{ Months: int32(dt.Months), Days: int32(dt.Days), Nanoseconds: int64(dt.Milliseconds) * int64(time.Millisecond), }) + default: + return fmt.Errorf("avro: unsupported value of type %T for Duration column", data) } + return nil } -func 
appendFixedSizeBinaryData(b *array.FixedSizeBinaryBuilder, data interface{}) { +func appendFixedSizeBinaryData(b *array.FixedSizeBinaryBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case []byte: b.Append(dt) - case map[string]any: - switch v := dt["bytes"].(type) { - case nil: - b.AppendNull() - case []byte: - b.Append(v) - } default: + // fixed(N) may arrive as a Go [N]byte; accept any byte-array via reflection. v := reflect.ValueOf(data) if v.Kind() == reflect.Array && v.Type().Elem().Kind() == reflect.Uint8 { - bytes := make([]byte, v.Len()) - reflect.Copy(reflect.ValueOf(bytes), v) - b.Append(bytes) + buf := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(buf), v) + b.Append(buf) + return nil } + return fmt.Errorf("avro: unsupported value of type %T for FixedSizeBinary column", data) } + return nil } -func appendFloat32Data(b *array.Float32Builder, data interface{}) { +func appendFloat32Data(b *array.Float32Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case float32: b.Append(dt) - case map[string]any: - switch v := dt["float"].(type) { - case nil: - b.AppendNull() - case float32: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Float32 column", data) } + return nil } -func appendFloat64Data(b *array.Float64Builder, data interface{}) { +func appendFloat64Data(b *array.Float64Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case float64: b.Append(dt) - case map[string]any: - switch v := dt["double"].(type) { - case nil: - b.AppendNull() - case float64: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Float64 column", data) } + return nil } -func appendInt32Data(b *array.Int32Builder, data interface{}) { +func appendInt32Data(b *array.Int32Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int: - b.Append(int32(dt)) case 
int32: b.Append(dt) - case map[string]any: - switch v := dt["int"].(type) { - case nil: - b.AppendNull() - case int: - b.Append(int32(v)) - case int32: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Int32 column", data) } + return nil } -func appendInt64Data(b *array.Int64Builder, data interface{}) { +func appendInt64Data(b *array.Int64Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int: - b.Append(int64(dt)) case int64: b.Append(dt) - case map[string]any: - switch v := dt["long"].(type) { - case nil: - b.AppendNull() - case int: - b.Append(int64(v)) - case int64: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Int64 column", data) } + return nil } -func appendStringData(b *array.StringBuilder, data interface{}) { +func appendStringData(b *array.StringBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case string: b.Append(dt) - case map[string]any: - switch v := dt["string"].(type) { - case nil: - b.AppendNull() - case string: - b.Append(v) - } default: - b.Append(fmt.Sprint(data)) + return fmt.Errorf("avro: unsupported value of type %T for String column", data) } + return nil } -func appendTime32Data(b *array.Time32Builder, data interface{}) { +func appendTime32Data(b *array.Time32Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int32: - b.Append(arrow.Time32(dt)) - case map[string]any: - switch v := dt["int"].(type) { - case nil: - b.AppendNull() - case int32: - b.Append(arrow.Time32(v)) - } case time.Duration: b.Append(arrow.Time32(dt.Milliseconds())) + default: + return fmt.Errorf("avro: unsupported value of type %T for Time32 column", data) } + return nil } -func appendTime64Data(b *array.Time64Builder, data interface{}) { +func appendTime64Data(b *array.Time64Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case 
int64: - b.Append(arrow.Time64(dt)) - case map[string]any: - switch v := dt["long"].(type) { - case nil: - b.AppendNull() - case int64: - b.Append(arrow.Time64(v)) - } case time.Duration: b.Append(arrow.Time64(dt.Microseconds())) + default: + return fmt.Errorf("avro: unsupported value of type %T for Time64 column", data) } + return nil } -func appendTimestampData(b *array.TimestampBuilder, data interface{}) { +func appendTimestampData(b *array.TimestampBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int64: - b.Append(arrow.Timestamp(dt)) - case map[string]any: - switch v := dt["long"].(type) { - case nil: - b.AppendNull() - case int64: - b.Append(arrow.Timestamp(v)) - } case time.Time: v, err := arrow.TimestampFromTime(dt, b.Type().(*arrow.TimestampType).Unit) if err != nil { - panic(err) + return err } b.Append(v) + default: + return fmt.Errorf("avro: unsupported value of type %T for Timestamp column", data) } + return nil } diff --git a/arrow/avro/schema.go b/arrow/avro/schema.go index 4d9e76707..13214ca23 100644 --- a/arrow/avro/schema.go +++ b/arrow/avro/schema.go @@ -18,6 +18,7 @@ package avro import ( + "encoding/json" "fmt" "math" "strconv" @@ -26,24 +27,34 @@ import ( "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/arrow-go/v18/internal/utils" - avro "github.com/hamba/avro/v2" + hambaAvro "github.com/hamba/avro/v2" + avro "github.com/twmb/avro" ) +// builtinAvroTypes is the set of Type field values that mean "this SchemaNode +// is the inline definition of an Avro type." Anything else in node.Type is +// treated as a named-type reference to a previously-seen record/enum/fixed. 
+var builtinAvroTypes = map[string]struct{}{ + "null": {}, "boolean": {}, "int": {}, "long": {}, + "float": {}, "double": {}, "bytes": {}, "string": {}, + "record": {}, "enum": {}, "array": {}, "map": {}, + "fixed": {}, "union": {}, +} + type schemaNode struct { - name string - parent *schemaNode - schema avro.Schema - union bool - nullable bool - childrens []*schemaNode - arrowField arrow.Field - schemaCache *avro.SchemaCache - index, depth int32 + name string + parent *schemaNode + node avro.SchemaNode + union bool + nullable bool + childrens []*schemaNode + arrowField arrow.Field + namedCache map[string]avro.SchemaNode + index int32 } func newSchemaNode() *schemaNode { - var schemaCache avro.SchemaCache - return &schemaNode{name: "", index: -1, schemaCache: &schemaCache} + return &schemaNode{index: -1, namedCache: map[string]avro.SchemaNode{}} } func (node *schemaNode) schemaPath() string { @@ -56,33 +67,84 @@ func (node *schemaNode) schemaPath() string { return path } -func (node *schemaNode) newChild(n string, s avro.Schema) *schemaNode { +func (node *schemaNode) newChild(n string, s avro.SchemaNode) *schemaNode { child := &schemaNode{ - name: n, - parent: node, - schema: s, - schemaCache: node.schemaCache, - index: int32(len(node.childrens)), - depth: node.depth + 1, + name: n, + parent: node, + node: s, + namedCache: node.namedCache, + index: int32(len(node.childrens)), } node.childrens = append(node.childrens, child) return child } func (node *schemaNode) children() []*schemaNode { return node.childrens } -// func (node *schemaNode) nodeName() string { return node.name } +// rememberNamed adds a record/enum/fixed SchemaNode to the named-type cache +// under both its short name and (if a namespace is present) its full name, +// so later references like {"type": "Address"} or {"type": "ns.Address"} +// resolve back to the original definition. 
+func (node *schemaNode) rememberNamed(s avro.SchemaNode) { + if s.Name == "" { + return + } + node.namedCache[s.Name] = s + if s.Namespace != "" { + node.namedCache[s.Namespace+"."+s.Name] = s + } +} + +// resolveRef replaces s with its inline definition if s.Type is a named-type +// reference rather than a builtin Avro type. atField, when non-empty, names +// the field this reference appears in and is included in the panic so the +// user can locate the offending entry. +func (node *schemaNode) resolveRef(s avro.SchemaNode, atField string) avro.SchemaNode { + if _, ok := builtinAvroTypes[s.Type]; ok { + return s + } + if def, ok := node.namedCache[s.Type]; ok { + return def + } + loc := node.schemaPath() + if atField != "" { + loc += "." + atField + } + panic(fmt.Errorf("unknown named type %q referenced at %s", s.Type, loc)) +} + +// ArrowSchemaFromAvroJSON parses an Avro schema given as JSON text and returns +// the equivalent Arrow schema. +func ArrowSchemaFromAvroJSON(schemaJSON string) (*arrow.Schema, error) { + schema, err := avro.Parse(schemaJSON) + if err != nil { + return nil, err + } + return arrowSchemaFromAvroInternal(schema) +} -// ArrowSchemaFromAvro returns a new Arrow schema from an Avro schema -func ArrowSchemaFromAvro(schema avro.Schema) (s *arrow.Schema, err error) { +// ArrowSchemaFromAvro returns a new Arrow schema from a parsed Avro schema. +// +// Deprecated: Use [ArrowSchemaFromAvroJSON] instead — it does not couple +// callers to a particular Avro library through its signature. 
+func ArrowSchemaFromAvro(schema hambaAvro.Schema) (*arrow.Schema, error) { + js, err := json.Marshal(schema) + if err != nil { + return nil, fmt.Errorf("%w: could not serialize hamba avro schema: %w", arrow.ErrInvalid, err) + } + return ArrowSchemaFromAvroJSON(string(js)) +} + +func arrowSchemaFromAvroInternal(schema *avro.Schema) (s *arrow.Schema, err error) { defer func() { if r := recover(); r != nil { s = nil err = utils.FormatRecoveredError("invalid avro schema", r) } }() + root := schema.Root() n := newSchemaNode() - n.schema = schema - c := n.newChild(n.schema.(avro.NamedSchema).Name(), n.schema) + n.node = root + c := n.newChild(root.Name, root) arrowSchemafromAvro(c) var fields []arrow.Field for _, g := range c.children() { @@ -93,16 +155,16 @@ func ArrowSchemaFromAvro(schema avro.Schema) (s *arrow.Schema, err error) { } func arrowSchemafromAvro(n *schemaNode) { - if ns, ok := n.schema.(avro.NamedSchema); ok { - n.schemaCache.Add(ns.Name(), ns) + n.node = n.resolveRef(n.node, "") + if n.node.Name != "" { + n.rememberNamed(n.node) } - switch st := n.schema.Type(); st { + switch st := n.node.Type; st { case "record": iterateFields(n) case "enum": - n.schemaCache.Add(n.schema.(avro.NamedSchema).Name(), n.schema.(*avro.EnumSchema)) symbols := make(map[string]string) - for index, symbol := range n.schema.(avro.PropertySchema).(*avro.EnumSchema).Symbols() { + for index, symbol := range n.node.Symbols { k := strconv.FormatInt(int64(index), 10) symbols[k] = symbol } @@ -118,9 +180,12 @@ func arrowSchemafromAvro(n *schemaNode) { } n.arrowField = buildArrowField(n, &dt, arrow.MetadataFrom(symbols)) case "array": - // logical items type - c := n.newChild(n.name, n.schema.(*avro.ArraySchema).Items()) - if isLogicalSchemaType(n.schema.(*avro.ArraySchema).Items()) { + if n.node.Items == nil { + panic(fmt.Errorf("avro array schema at %s has no 'items'", n.schemaPath())) + } + items := *n.node.Items + c := n.newChild(n.name, items) + if isLogicalSchemaType(items) { 
avroLogicalToArrowField(c) } else { arrowSchemafromAvro(c) @@ -134,59 +199,58 @@ func arrowSchemafromAvro(n *schemaNode) { } n.arrowField = buildArrowField(n, typ, c.arrowField.Metadata) case "map": - n.schemaCache.Add(n.schema.(*avro.MapSchema).Values().(avro.NamedSchema).Name(), n.schema.(*avro.MapSchema).Values()) - c := n.newChild(n.name, n.schema.(*avro.MapSchema).Values()) + if n.node.Values == nil { + panic(fmt.Errorf("avro map schema at %s has no 'values'", n.schemaPath())) + } + values := *n.node.Values + c := n.newChild(n.name, values) arrowSchemafromAvro(c) n.arrowField = buildArrowField(n, arrow.MapOf(arrow.BinaryTypes.String, c.arrowField.Type), c.arrowField.Metadata) case "union": - if n.schema.(*avro.UnionSchema).Nullable() { - if len(n.schema.(*avro.UnionSchema).Types()) > 1 { - n.schema = n.schema.(*avro.UnionSchema).Types()[1] - n.union = true - n.nullable = true - arrowSchemafromAvro(n) - } + branch, ok := nullableBranch(n.node) + if !ok { + panic(fmt.Errorf("unsupported avro union at %s: only ['null', T] unions with exactly one non-null branch are supported", n.schemaPath())) } + n.node = branch + n.union = true + n.nullable = true + arrowSchemafromAvro(n) // Avro "fixed" field type = Arrow FixedSize Primitive BinaryType case "fixed": - n.schemaCache.Add(n.schema.(avro.NamedSchema).Name(), n.schema.(*avro.FixedSchema)) - if isLogicalSchemaType(n.schema) { + if isLogicalSchemaType(n.node) { avroLogicalToArrowField(n) } else { - n.arrowField = buildArrowField(n, &arrow.FixedSizeBinaryType{ByteWidth: n.schema.(*avro.FixedSchema).Size()}, arrow.Metadata{}) + n.arrowField = buildArrowField(n, &arrow.FixedSizeBinaryType{ByteWidth: n.node.Size}, arrow.Metadata{}) } case "string", "bytes", "int", "long": - if isLogicalSchemaType(n.schema) { + if isLogicalSchemaType(n.node) { avroLogicalToArrowField(n) } else { n.arrowField = buildArrowField(n, avroPrimitiveToArrowType(string(st)), arrow.Metadata{}) } case "float", "double", "boolean": n.arrowField = 
buildArrowField(n, avroPrimitiveToArrowType(string(st)), arrow.Metadata{}) - case "": - refSchema := n.schemaCache.Get(string(n.schema.(*avro.RefSchema).Schema().Name())) - if refSchema == nil { - panic(fmt.Errorf("could not find schema for '%v' in schema cache - %v", n.schemaPath(), n.schema.(*avro.RefSchema).Schema().Name())) - } - n.schema = refSchema - arrowSchemafromAvro(n) case "null": - n.schemaCache.Add(n.schema.(*avro.MapSchema).Values().(avro.NamedSchema).Name(), &avro.NullSchema{}) n.nullable = true n.arrowField = buildArrowField(n, arrow.Null, arrow.Metadata{}) + default: + panic(fmt.Errorf("unhandled avro type %q at %s", st, n.schemaPath())) } } -// iterate record Fields() +// iterate record Fields func iterateFields(n *schemaNode) { - for _, f := range n.schema.(*avro.RecordSchema).Fields() { - switch ft := f.Type().(type) { + for _, f := range n.node.Fields { + ft := n.resolveRef(f.Type, f.Name) + switch ft.Type { // Avro "array" field type - case *avro.ArraySchema: - n.schemaCache.Add(f.Name(), ft.Items()) - // logical items type - c := n.newChild(f.Name(), ft.Items()) - if isLogicalSchemaType(ft.Items()) { + case "array": + if ft.Items == nil { + panic(fmt.Errorf("avro array field %s.%s has no 'items'", n.schemaPath(), f.Name)) + } + items := *ft.Items + c := n.newChild(f.Name, items) + if isLogicalSchemaType(items) { avroLogicalToArrowField(c) } else { arrowSchemafromAvro(c) @@ -198,11 +262,11 @@ func iterateFields(n *schemaNode) { c.arrowField = arrow.Field{Name: c.name, Type: arrow.ListOfNonNullable(c.arrowField.Type), Metadata: c.arrowField.Metadata} } // Avro "enum" field type = Arrow dictionary type - case *avro.EnumSchema: - n.schemaCache.Add(f.Type().(*avro.EnumSchema).Name(), f.Type()) - c := n.newChild(f.Name(), f.Type()) + case "enum": + n.rememberNamed(ft) + c := n.newChild(f.Name, ft) symbols := make(map[string]string) - for index, symbol := range ft.Symbols() { + for index, symbol := range ft.Symbols { k := 
strconv.FormatInt(int64(index), 10) symbols[k] = symbol } @@ -218,44 +282,43 @@ func iterateFields(n *schemaNode) { } c.arrowField = buildArrowField(c, &dt, arrow.MetadataFrom(symbols)) // Avro "fixed" field type = Arrow FixedSize Primitive BinaryType - case *avro.FixedSchema: - n.schemaCache.Add(f.Name(), f.Type()) - c := n.newChild(f.Name(), f.Type()) - if isLogicalSchemaType(f.Type()) { + case "fixed": + n.rememberNamed(ft) + c := n.newChild(f.Name, ft) + if isLogicalSchemaType(ft) { avroLogicalToArrowField(c) } else { arrowSchemafromAvro(c) } - case *avro.RecordSchema: - n.schemaCache.Add(f.Name(), f.Type()) - c := n.newChild(f.Name(), f.Type()) + case "record": + n.rememberNamed(ft) + c := n.newChild(f.Name, ft) iterateFields(c) - // Avro "map" field type - KVP with value of one type - keys are strings - case *avro.MapSchema: - n.schemaCache.Add(f.Name(), ft.Values()) - c := n.newChild(f.Name(), ft.Values()) + // Avro "map" field type - KVP with value of one type - keys are strings + case "map": + if ft.Values == nil { + panic(fmt.Errorf("avro map field %s.%s has no 'values'", n.schemaPath(), f.Name)) + } + values := *ft.Values + c := n.newChild(f.Name, values) arrowSchemafromAvro(c) c.arrowField = buildArrowField(c, arrow.MapOf(arrow.BinaryTypes.String, c.arrowField.Type), c.arrowField.Metadata) - case *avro.UnionSchema: - if ft.Nullable() { - if len(ft.Types()) > 1 { - n.schemaCache.Add(f.Name(), ft.Types()[1]) - c := n.newChild(f.Name(), ft.Types()[1]) - c.union = true - c.nullable = true - arrowSchemafromAvro(c) - } + case "union": + branch, ok := nullableBranch(ft) + if !ok { + panic(fmt.Errorf("unsupported avro union at %s.%s: only ['null', T] unions with exactly one non-null branch are supported", n.schemaPath(), f.Name)) } + c := n.newChild(f.Name, branch) + c.union = true + c.nullable = true + arrowSchemafromAvro(c) default: - n.schemaCache.Add(f.Name(), f.Type()) - if isLogicalSchemaType(f.Type()) { - c := n.newChild(f.Name(), f.Type()) + c := 
n.newChild(f.Name, ft) + if isLogicalSchemaType(ft) { avroLogicalToArrowField(c) } else { - c := n.newChild(f.Name(), f.Type()) arrowSchemafromAvro(c) } - } } var fields []arrow.Field @@ -263,7 +326,7 @@ func iterateFields(n *schemaNode) { fields = append(fields, child.arrowField) } - namedSchema, ok := isNamedSchema(n.schema) + namedSchema, ok := isNamedSchema(n.node) var md arrow.Metadata if ok && namedSchema != n.name+"_data" && n.union { @@ -272,22 +335,46 @@ func iterateFields(n *schemaNode) { n.arrowField = buildArrowField(n, arrow.StructOf(fields...), md) } -func isLogicalSchemaType(s avro.Schema) bool { - lts, ok := s.(avro.LogicalTypeSchema) - if !ok { - return false +// nullableBranch returns the non-null branch of a two-element ["null", T] +// union, plus true if the union is in that nullable shape. If the union has +// more than two branches or no null branch, ok is false. +// +// Unions with more than one non-null branch (e.g. ["null", "int", "string"] +// or ["int", "string"]) are not supported and callers panic on them rather +// than silently picking one arm. +func nullableBranch(s avro.SchemaNode) (avro.SchemaNode, bool) { + if s.Type != "union" || len(s.Branches) < 2 { + return avro.SchemaNode{}, false + } + var nonNull *avro.SchemaNode + for i := range s.Branches { + b := s.Branches[i] + if b.Type == "null" { + continue + } + if nonNull != nil { + return avro.SchemaNode{}, false + } + nonNull = &b } - if lts.Logical() != nil { - return true + if nonNull == nil { + return avro.SchemaNode{}, false } - return false + return *nonNull, true } -func isNamedSchema(s avro.Schema) (string, bool) { - if ns, ok := s.(avro.NamedSchema); ok { - return ns.FullName(), ok +func isLogicalSchemaType(s avro.SchemaNode) bool { + return s.LogicalType != "" +} + +func isNamedSchema(s avro.SchemaNode) (string, bool) { + if s.Name == "" { + return "", false } - return "", false + if s.Namespace != "" { + return s.Namespace + "."
+ s.Name, true + } + return s.Name, true } func buildArrowField(n *schemaNode, t arrow.DataType, m arrow.Metadata) arrow.Field { @@ -332,7 +419,7 @@ func avroPrimitiveToArrowType(avroFieldType string) arrow.DataType { func avroLogicalToArrowField(n *schemaNode) { var dt arrow.DataType // Avro logical types - switch lt := n.schema.(avro.LogicalTypeSchema).Logical(); lt.Type() { + switch n.node.LogicalType { // The decimal logical type represents an arbitrary-precision signed decimal number of the form unscaled × 10-scale. // A decimal logical type annotates Avro bytes or fixed types. The byte array must contain the two’s-complement // representation of the unscaled integer value in big-endian byte order. The scale is fixed, and is specified @@ -343,13 +430,13 @@ func avroLogicalToArrowField(n *schemaNode) { // precision, a JSON integer representing the (maximum) precision of decimals stored in this type (required). case "decimal": id := arrow.DECIMAL128 - if lt.(*avro.DecimalLogicalSchema).Precision() > decimal128.MaxPrecision { + if n.node.Precision > decimal128.MaxPrecision { id = arrow.DECIMAL256 } - dt, _ = arrow.NewDecimalType(id, int32(lt.(*avro.DecimalLogicalSchema).Precision()), int32(lt.(*avro.DecimalLogicalSchema).Scale())) + dt, _ = arrow.NewDecimalType(id, int32(n.node.Precision), int32(n.node.Scale)) - // The uuid logical type represents a random generated universally unique identifier (UUID). - // A uuid logical type annotates an Avro string. The string has to conform with RFC-4122 + // The uuid logical type represents a random generated universally unique identifier (UUID). + // A uuid logical type annotates an Avro string. 
The string has to conform with RFC-4122 case "uuid": dt = extensions.NewUUIDType() @@ -394,21 +481,19 @@ func avroLogicalToArrowField(n *schemaNode) { case "timestamp-micros": dt = arrow.FixedWidthTypes.Timestamp_us - // The local-timestamp-millis logical type represents a timestamp in a local timezone, regardless of - // what specific time zone is considered local, with a precision of one millisecond. - // A local-timestamp-millis logical type annotates an Avro long, where the long stores the number of - // milliseconds, from 1 January 1970 00:00:00.000. - // Note: not implemented in hamba/avro - // case "local-timestamp-millis": - // dt = &arrow.TimestampType{Unit: arrow.Millisecond} - - // The local-timestamp-micros logical type represents a timestamp in a local timezone, regardless of - // what specific time zone is considered local, with a precision of one microsecond. - // A local-timestamp-micros logical type annotates an Avro long, where the long stores the number of - // microseconds, from 1 January 1970 00:00:00.000000. - // case "local-timestamp-micros": - // Note: not implemented in hamba/avro - // dt = &arrow.TimestampType{Unit: arrow.Microsecond} + // The timestamp-nanos logical type represents an instant on the global timeline with nanosecond + // precision. twmb/avro decodes it to time.Time (UTC). + case "timestamp-nanos": + dt = arrow.FixedWidthTypes.Timestamp_ns + + // The local-timestamp-millis/micros/nanos logical types represent a timestamp in a local timezone. + // Arrow models that as a TimestampType with no time zone set. + case "local-timestamp-millis": + dt = &arrow.TimestampType{Unit: arrow.Millisecond} + case "local-timestamp-micros": + dt = &arrow.TimestampType{Unit: arrow.Microsecond} + case "local-timestamp-nanos": + dt = &arrow.TimestampType{Unit: arrow.Nanosecond} // The duration logical type represents an amount of time defined by a number of months, days and milliseconds. 
// This is not equivalent to a number of milliseconds, because, depending on the moment in time from which the diff --git a/arrow/avro/schema_test.go b/arrow/avro/schema_test.go index 33b6d2a05..52a42809f 100644 --- a/arrow/avro/schema_test.go +++ b/arrow/avro/schema_test.go @@ -22,6 +22,8 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/avro/testdata" + "github.com/apache/arrow-go/v18/arrow/extensions" + hambaAvro "github.com/hamba/avro/v2" ) func TestSchemaStringEqual(t *testing.T) { @@ -126,6 +128,10 @@ func TestSchemaStringEqual(t *testing.T) { Name: "uuidField", Type: arrow.BinaryTypes.String, }, + { + Name: "fixedUuidField", + Type: extensions.NewUUIDType(), + }, { Name: "timemillis", Type: arrow.FixedWidthTypes.Time32ms, @@ -162,7 +168,7 @@ func TestSchemaStringEqual(t *testing.T) { if err != nil { t.Fatalf("%v", err) } - got, err := ArrowSchemaFromAvro(schema) + got, err := ArrowSchemaFromAvroJSON(schema) if err != nil { t.Fatalf("%v", err) } @@ -174,3 +180,37 @@ func TestSchemaStringEqual(t *testing.T) { }) } } + +// Remove together with [ArrowSchemaFromAvro] at the next major release. 
+func TestArrowSchemaFromAvro_Deprecated_PreservesLogicalTypesOnFixed(t *testing.T) { + const schemaJSON = `{ + "type": "record", + "name": "Sample", + "fields": [ + {"name": "id", "type": "int"}, + {"name": "name", "type": "string"}, + {"name": "nullable_double", "type": ["null", "double"]}, + {"name": "uuid_string", "type": {"type": "string", "logicalType": "uuid"}}, + {"name": "ts_millis", "type": {"type": "long", "logicalType": "timestamp-millis"}}, + {"name": "fixed_uuid", "type": {"type": "fixed", "name": "FUUID", "size": 16, "logicalType": "uuid"}}, + {"name": "fixed_decimal", "type": {"type": "fixed", "name": "FDec", "size": 16, "logicalType": "decimal", "precision": 20, "scale": 4}}, + {"name": "fixed_duration", "type": {"type": "fixed", "name": "FDur", "size": 12, "logicalType": "duration"}} + ] + }` + hambaSchema, err := hambaAvro.Parse(schemaJSON) + if err != nil { + t.Fatalf("hamba parse: %v", err) + } + + got, err := ArrowSchemaFromAvro(hambaSchema) + if err != nil { + t.Fatalf("ArrowSchemaFromAvro: %v", err) + } + want, err := ArrowSchemaFromAvroJSON(schemaJSON) + if err != nil { + t.Fatalf("ArrowSchemaFromAvroJSON: %v", err) + } + if got.String() != want.String() { + t.Fatalf("schema mismatch:\n got = %s\nwant = %s", got.String(), want.String()) + } +} diff --git a/arrow/avro/testdata/alltypes.avsc b/arrow/avro/testdata/alltypes.avsc index 29a72e56d..1c2f01647 100644 --- a/arrow/avro/testdata/alltypes.avsc +++ b/arrow/avro/testdata/alltypes.avsc @@ -164,6 +164,15 @@ "name": "uuidField", "type": "string" }, + { + "name": "fixedUuidField", + "type": { + "type": "fixed", + "name": "FixedUUID", + "size": 16, + "logicalType": "uuid" + } + }, { "name": "timemillis", "type": { diff --git a/arrow/avro/testdata/testdata.go b/arrow/avro/testdata/testdata.go index 235231dab..70c694627 100644 --- a/arrow/avro/testdata/testdata.go +++ b/arrow/avro/testdata/testdata.go @@ -17,8 +17,6 @@ package testdata import ( - "encoding/base64" - "encoding/binary" 
"encoding/json" "fmt" "log" @@ -28,8 +26,9 @@ import ( "strings" "time" - avro "github.com/hamba/avro/v2" - "github.com/hamba/avro/v2/ocf" + "github.com/google/uuid" + "github.com/twmb/avro" + "github.com/twmb/avro/ocf" ) const ( @@ -42,107 +41,169 @@ const ( type ByteArray []byte func (b ByteArray) MarshalJSON() ([]byte, error) { - s := fmt.Sprint(b) - encoded := base64.StdEncoding.EncodeToString([]byte(s)) - return json.Marshal(encoded) + return json.Marshal([]byte(b)) } -type TimestampMicros int64 +type TimestampJSON time.Time -func (t TimestampMicros) MarshalJSON() ([]byte, error) { - ts := time.Unix(0, int64(t)*int64(time.Microsecond)).UTC().Format(time.RFC3339Nano) - return json.Marshal(ts) +func (t TimestampJSON) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Time(t).UTC().Format(time.RFC3339Nano)) } -type TimestampMillis int64 +type TimeMillisJSON time.Duration -func (t TimestampMillis) MarshalJSON() ([]byte, error) { - ts := time.Unix(0, int64(t)*int64(time.Millisecond)).UTC().Format(time.RFC3339Nano) - return json.Marshal(ts) -} - -type TimeMillis time.Duration - -func (t TimeMillis) MarshalJSON() ([]byte, error) { +func (t TimeMillisJSON) MarshalJSON() ([]byte, error) { ts := time.Unix(0, int64(t)).UTC().Format("15:04:05.000") return json.Marshal(strings.TrimRight(ts, "0.")) } -type TimeMicros time.Duration +type TimeMicrosJSON time.Duration -func (t TimeMicros) MarshalJSON() ([]byte, error) { +func (t TimeMicrosJSON) MarshalJSON() ([]byte, error) { ts := time.Unix(0, int64(t)).UTC().Format("15:04:05.000000") return json.Marshal(strings.TrimRight(ts, "0.")) } -type ExplicitNamespace [12]byte +type FixedJSON []byte -func (t ExplicitNamespace) MarshalJSON() ([]byte, error) { - return json.Marshal(t[:]) +func (t FixedJSON) MarshalJSON() ([]byte, error) { + return json.Marshal([]byte(t)) } -type MD5 [16]byte +type FixedUUIDJSON [16]byte -func (t MD5) MarshalJSON() ([]byte, error) { - return json.Marshal(t[:]) +func (t FixedUUIDJSON) MarshalJSON() 
([]byte, error) { + return json.Marshal(uuid.UUID(t).String()) } -type DecimalType []byte +type DecimalJSON struct { + Rat *big.Rat +} -func (t DecimalType) MarshalJSON() ([]byte, error) { - v := new(big.Int).SetBytes(t) - s := fmt.Sprintf("%0*s", decimalTypeScale+1, v.String()) +func (t DecimalJSON) MarshalJSON() ([]byte, error) { + num := new(big.Int).Set(t.Rat.Num()) + den := new(big.Int).Set(t.Rat.Denom()) + scaleFactor := new(big.Int).Exp(big.NewInt(10), big.NewInt(decimalTypeScale), nil) + num.Mul(num, scaleFactor) + num.Quo(num, den) + s := fmt.Sprintf("%0*s", decimalTypeScale+1, num.String()) point := len(s) - decimalTypeScale return json.Marshal(s[:point] + "." + s[point:]) } -type Duration [12]byte - -func (t Duration) MarshalJSON() ([]byte, error) { - milliseconds := int32(binary.LittleEndian.Uint32(t[8:12])) +type DurationJSON avro.Duration - m := map[string]interface{}{ - "months": int32(binary.LittleEndian.Uint32(t[0:4])), - "days": int32(binary.LittleEndian.Uint32(t[4:8])), - "nanoseconds": int64(milliseconds) * int64(time.Millisecond), +func (t DurationJSON) MarshalJSON() ([]byte, error) { + m := map[string]any{ + "months": int32(t.Months), + "days": int32(t.Days), + "nanoseconds": int64(t.Milliseconds) * int64(time.Millisecond), } return json.Marshal(m) } -type Date int32 +type DateJSON time.Time -func (t Date) MarshalJSON() ([]byte, error) { - v := time.Unix(int64(t)*86400, 0).UTC().Format("2006-01-02") - return json.Marshal(v) +func (t DateJSON) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Time(t).UTC().Format("2006-01-02")) } type Example struct { - InheritNull string `avro:"inheritNull" json:"inheritNull"` - ExplicitNamespace ExplicitNamespace `avro:"explicitNamespace" json:"explicitNamespace"` - FullName FullNameData `avro:"fullName" json:"fullName"` - ID int32 `avro:"id" json:"id"` - BigID int64 `avro:"bigId" json:"bigId"` - Temperature *float32 `avro:"temperature" json:"temperature"` - Fraction *float64 `avro:"fraction" 
json:"fraction"` - IsEmergency bool `avro:"is_emergency" json:"is_emergency"` - RemoteIP *ByteArray `avro:"remote_ip" json:"remote_ip"` - NullableRemoteIPS *[]ByteArray `avro:"nullable_remote_ips" json:"nullable_remote_ips"` - Person PersonData `avro:"person" json:"person"` - DecimalField DecimalType `avro:"decimalField" json:"decimalField"` - Decimal256Field DecimalType `avro:"decimal256Field" json:"decimal256Field"` - UUIDField string `avro:"uuidField" json:"uuidField"` - TimeMillis TimeMillis `avro:"timemillis" json:"timemillis"` - TimeMicros TimeMicros `avro:"timemicros" json:"timemicros"` - TimestampMillis TimestampMillis `avro:"timestampmillis" json:"timestampmillis"` - TimestampMicros TimestampMicros `avro:"timestampmicros" json:"timestampmicros"` - Duration Duration `avro:"duration" json:"duration"` - Date Date `avro:"date" json:"date"` + InheritNull string `avro:"inheritNull"` + ExplicitNamespace [12]byte `avro:"explicitNamespace"` + FullName FullNameData `avro:"fullName"` + ID int32 `avro:"id"` + BigID int64 `avro:"bigId"` + Temperature *float32 `avro:"temperature"` + Fraction *float64 `avro:"fraction"` + IsEmergency bool `avro:"is_emergency"` + RemoteIP *[]byte `avro:"remote_ip"` + NullableRemoteIPS *[][]byte `avro:"nullable_remote_ips"` + Person PersonData `avro:"person"` + DecimalField *big.Rat `avro:"decimalField"` + Decimal256Field *big.Rat `avro:"decimal256Field"` + UUIDField string `avro:"uuidField"` + FixedUUIDField [16]byte `avro:"fixedUuidField"` + TimeMillis time.Duration `avro:"timemillis"` + TimeMicros time.Duration `avro:"timemicros"` + TimestampMillis time.Time `avro:"timestampmillis"` + TimestampMicros time.Time `avro:"timestampmicros"` + Duration avro.Duration `avro:"duration"` + Date time.Time `avro:"date"` +} + +func (e Example) MarshalJSON() ([]byte, error) { + var remoteIP *ByteArray + if e.RemoteIP != nil { + v := ByteArray(*e.RemoteIP) + remoteIP = &v + } + var nullableRemoteIPs *[]ByteArray + if e.NullableRemoteIPS != nil { + arr 
:= make([]ByteArray, len(*e.NullableRemoteIPS)) + for i, b := range *e.NullableRemoteIPS { + arr[i] = ByteArray(b) + } + nullableRemoteIPs = &arr + } + out := struct { + InheritNull string `json:"inheritNull"` + ExplicitNamespace FixedJSON `json:"explicitNamespace"` + FullName fullNameJSON `json:"fullName"` + ID int32 `json:"id"` + BigID int64 `json:"bigId"` + Temperature *float32 `json:"temperature"` + Fraction *float64 `json:"fraction"` + IsEmergency bool `json:"is_emergency"` + RemoteIP *ByteArray `json:"remote_ip"` + NullableRemoteIPS *[]ByteArray `json:"nullable_remote_ips"` + Person PersonData `json:"person"` + DecimalField DecimalJSON `json:"decimalField"` + Decimal256Field DecimalJSON `json:"decimal256Field"` + UUIDField string `json:"uuidField"` + FixedUUIDField FixedUUIDJSON `json:"fixedUuidField"` + TimeMillis TimeMillisJSON `json:"timemillis"` + TimeMicros TimeMicrosJSON `json:"timemicros"` + TimestampMillis TimestampJSON `json:"timestampmillis"` + TimestampMicros TimestampJSON `json:"timestampmicros"` + Duration DurationJSON `json:"duration"` + Date DateJSON `json:"date"` + }{ + InheritNull: e.InheritNull, + ExplicitNamespace: FixedJSON(e.ExplicitNamespace[:]), + FullName: fullNameJSON{InheritNamespace: e.FullName.InheritNamespace, Md5: FixedJSON(e.FullName.Md5[:])}, + ID: e.ID, + BigID: e.BigID, + Temperature: e.Temperature, + Fraction: e.Fraction, + IsEmergency: e.IsEmergency, + RemoteIP: remoteIP, + NullableRemoteIPS: nullableRemoteIPs, + Person: e.Person, + DecimalField: DecimalJSON{Rat: e.DecimalField}, + Decimal256Field: DecimalJSON{Rat: e.Decimal256Field}, + UUIDField: e.UUIDField, + FixedUUIDField: FixedUUIDJSON(e.FixedUUIDField), + TimeMillis: TimeMillisJSON(e.TimeMillis), + TimeMicros: TimeMicrosJSON(e.TimeMicros), + TimestampMillis: TimestampJSON(e.TimestampMillis), + TimestampMicros: TimestampJSON(e.TimestampMicros), + Duration: DurationJSON(e.Duration), + Date: DateJSON(e.Date), + } + return json.Marshal(out) } type FullNameData struct { - 
InheritNamespace string `avro:"inheritNamespace" json:"inheritNamespace"` - Md5 MD5 `avro:"md5" json:"md5"` + InheritNamespace string `avro:"inheritNamespace"` + Md5 [16]byte `avro:"md5"` } + +type fullNameJSON struct { + InheritNamespace string `json:"inheritNamespace"` + Md5 FixedJSON `json:"md5"` +} + type MapField map[string]int64 func (t MapField) MarshalJSON() ([]byte, error) { @@ -199,29 +260,43 @@ func TestdataDir() string { return "" } -func AllTypesAvroSchema() (avro.Schema, error) { +// AllTypesAvroSchema returns the raw JSON of the bundled `alltypes.avsc` +// testdata schema. +func AllTypesAvroSchema() (string, error) { sp := filepath.Join(TestdataDir(), SchemaFileName) avroSchemaBytes, err := os.ReadFile(sp) if err != nil { - return nil, err + return "", err } - return avro.ParseBytes(avroSchemaBytes) + return string(avroSchemaBytes), nil } func sampleData() Example { + now := time.Now().UTC() + // Truncate to each logical type's precision so timestamp-millis/-micros round-trip exactly. + tsMillis := now.Truncate(time.Millisecond) + tsMicros := now.Truncate(time.Microsecond) + date := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + + decimal := new(big.Rat).SetFrac(big.NewInt(9876), big.NewInt(100)) // 98.76 + decimal256, ok := new(big.Rat).SetString("12345678901234567890123456789012345678901234567890123456.78") + if !ok { + log.Fatal("bad decimal256 literal in sampleData") + } + return Example{ InheritNull: "a", - ExplicitNamespace: ExplicitNamespace{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + ExplicitNamespace: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, FullName: FullNameData{ InheritNamespace: "d", - Md5: MD5{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + Md5: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, }, - ID: 42, - BigID: 42000000000, - Temperature: func() *float32 { v := float32(36.6); return &v }(), - Fraction: func() *float64 { v := float64(0.75); return &v }(), - IsEmergency: true, - RemoteIP: func()
*ByteArray { v := ByteArray{192, 168, 1, 1}; return &v }(), + ID: 42, + BigID: 42000000000, + Temperature: func() *float32 { v := float32(36.6); return &v }(), + Fraction: func() *float64 { v := float64(0.75); return &v }(), + IsEmergency: true, + RemoteIP: func() *[]byte { v := []byte{192, 168, 1, 1}; return &v }(), Person: PersonData{ Lastname: "Doe", Address: AddressUSRecord{ @@ -231,19 +306,16 @@ func sampleData() Example { Mapfield: MapField{"foo": 123}, ArrayField: []string{"one", "two"}, }, - DecimalField: DecimalType{0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0x94}, - Decimal256Field: DecimalType{ - 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, - 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, - 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x01, - }, + DecimalField: decimal, + Decimal256Field: decimal256, UUIDField: "123e4567-e89b-12d3-a456-426614174000", - TimeMillis: TimeMillis(50412345 * time.Millisecond), - TimeMicros: TimeMicros(50412345678 * time.Microsecond), - TimestampMillis: TimestampMillis(time.Now().UnixNano() / int64(time.Millisecond)), - TimestampMicros: TimestampMicros(time.Now().UnixNano() / int64(time.Microsecond)), - Duration: Duration{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Date: Date(time.Now().Unix() / 86400), + FixedUUIDField: [16]byte{0x55, 0x0e, 0x84, 0x00, 0xe2, 0x9b, 0x41, 0xd4, 0xa7, 0x16, 0x44, 0x66, 0x55, 0x44, 0x00, 0x00}, + TimeMillis: 50412345 * time.Millisecond, + TimeMicros: 50412345678 * time.Microsecond, + TimestampMillis: tsMillis, + TimestampMicros: tsMicros, + Duration: avro.Duration{Months: 1, Days: 2, Milliseconds: 3}, + Date: date, } } @@ -254,11 +326,16 @@ func writeOCFSampleData(td string, data Example) string { log.Fatal(err) } defer ocfFile.Close() - schema, err := AllTypesAvroSchema() + schemaJSON, err := AllTypesAvroSchema() + if err != nil { + log.Fatal(err) + } + schema, err := avro.Parse(schemaJSON) if err != nil { log.Fatal(err) } - encoder, err := ocf.NewEncoder(schema.String(), ocfFile) + // Pass the original 
JSON so logical-type annotations survive in the OCF header. + encoder, err := ocf.NewWriter(ocfFile, schema, ocf.WithSchema(schemaJSON)) if err != nil { log.Fatal(err) } diff --git a/go.mod b/go.mod index 12b52ca34..e07c4980f 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/substrait-io/substrait-go/v8 v8.1.0 github.com/substrait-io/substrait-protobuf/go v0.85.0 github.com/tidwall/sjson v1.2.5 + github.com/twmb/avro v1.5.0 github.com/zeebo/xxh3 v1.1.0 golang.org/x/exp v0.0.0-20260112195511-716be5621a96 golang.org/x/sync v0.20.0 @@ -67,7 +68,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/goccy/go-yaml v1.17.1 // indirect - github.com/golang/snappy v1.0.0 // indirect github.com/gookit/color v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect diff --git a/go.sum b/go.sum index 3732628f9..1bfdf54ca 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,6 @@ github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= -github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v25.12.19+incompatible h1:haMV2JRRJCe1998HeW/p0X9UaMTK6SDo0ffLn2+DbLs= github.com/google/flatbuffers v25.12.19+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -164,6 +162,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod 
h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twmb/avro v1.5.0 h1:9jmbvVQQBcyWHv/6zS+q5+nmASiR8/GwhKF/sU7u71c= +github.com/twmb/avro v1.5.0/go.mod h1:X0fT1dY2xcbV4YuCE4mYro+qljHl4kUF5uA/2z1rgSk= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=