diff --git a/const.go b/const.go index 02321ef..2cece5d 100644 --- a/const.go +++ b/const.go @@ -38,24 +38,31 @@ const ( ) const ( - AMF3_UNDEFINED_MARKER = 0x00 - AMF3_NULL_MARKER = 0x01 - AMF3_FALSE_MARKER = 0x02 - AMF3_TRUE_MARKER = 0x03 - AMF3_INTEGER_MARKER = 0x04 - AMF3_DOUBLE_MARKER = 0x05 - AMF3_STRING_MARKER = 0x06 - AMF3_XMLDOC_MARKER = 0x07 - AMF3_DATE_MARKER = 0x08 - AMF3_ARRAY_MARKER = 0x09 - AMF3_OBJECT_MARKER = 0x0a - AMF3_XMLSTRING_MARKER = 0x0b - AMF3_BYTEARRAY_MARKER = 0x0c + AMF3_UNDEFINED_MARKER = 0x00 + AMF3_NULL_MARKER = 0x01 + AMF3_FALSE_MARKER = 0x02 + AMF3_TRUE_MARKER = 0x03 + AMF3_INTEGER_MARKER = 0x04 + AMF3_DOUBLE_MARKER = 0x05 + AMF3_STRING_MARKER = 0x06 + AMF3_XMLDOC_MARKER = 0x07 + AMF3_DATE_MARKER = 0x08 + AMF3_ARRAY_MARKER = 0x09 + AMF3_OBJECT_MARKER = 0x0a + AMF3_XMLSTRING_MARKER = 0x0b + AMF3_BYTEARRAY_MARKER = 0x0c + AMF3_VECTOR_INT_MARKER = 0x0d + AMF3_VECTOR_UINT_MARKER = 0x0e + AMF3_VECTOR_DOUBLE_MARKER = 0x0f + AMF3_VECTOR_OBJECT_MARKER = 0x10 + AMF3_DICTIONARY_MARKER = 0x11 ) type ExternalHandler func(*Decoder, io.Reader) (interface{}, error) type Decoder struct { + // If set to true, decoded doubles that are NaN or Inf will be stored as 0 + FilterNaNs bool refCache []interface{} stringRefs []string objectRefs []interface{} @@ -80,6 +87,11 @@ type Version uint8 type Array []interface{} type Object map[string]interface{} +type Dictionary map[interface{}]interface{} +type ObjectVector struct { + Type string + Data Array +} type TypedObject struct { Type string diff --git a/decoder_amf3.go b/decoder_amf3.go index 7c8260a..989406b 100644 --- a/decoder_amf3.go +++ b/decoder_amf3.go @@ -3,6 +3,7 @@ package amf import ( "encoding/binary" "io" + "math" "time" ) @@ -40,6 +41,16 @@ func (d *Decoder) DecodeAmf3(r io.Reader) (interface{}, error) { return d.DecodeAmf3Xml(r, false) case AMF3_BYTEARRAY_MARKER: return d.DecodeAmf3ByteArray(r, false) + case AMF3_VECTOR_INT_MARKER: + return d.DecodeAmf3VectorInt(r, false) + case AMF3_VECTOR_UINT_MARKER: + return d.DecodeAmf3VectorUint(r, false) + case AMF3_VECTOR_DOUBLE_MARKER: + return d.DecodeAmf3VectorDouble(r, false) + case AMF3_VECTOR_OBJECT_MARKER: + return d.DecodeAmf3VectorObject(r, false) + case AMF3_DICTIONARY_MARKER: + return d.DecodeAmf3Dictionary(r, false) } return nil, Error("decode amf3: unsupported type %d", marker) @@ -105,6 +116,11 @@ func (d *Decoder) DecodeAmf3Double(r io.Reader, decodeMarker bool) (result float if err != nil { return float64(0), Error("amf3 decode: unable to read double: %s", err) } + if d.FilterNaNs { + if math.IsNaN(result) || math.IsInf(result, 0) { + result = 0 + } + } return } @@ -124,7 +140,6 @@ func (d *Decoder) DecodeAmf3String(r io.Reader, decodeMarker bool) (result strin if err != nil { return "", Error("amf3 decode: unable to decode string reference and length: %s", err) } - if isRef { result = d.stringRefs[refVal] return @@ -459,6 +474,187 @@ func (d *Decoder) DecodeAmf3ByteArray(r io.Reader, decodeMarker bool) (result [] return } +func (d *Decoder) DecodeAmf3VectorInt(r io.Reader, decodeMarker bool) (result []int32, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_VECTOR_INT_MARKER); err != nil { + return + } + + isRef, refVal, length, _, err := d.decodeVectorInfo(r) + if err != nil { + return nil, err + } + if isRef { + if result, ok := d.objectRefs[refVal].([]int32); ok { + return result, nil + } else { + return nil, Error("amf3 decode: unable to convert object ref to vector") + } + } + result = make([]int32, length) + for i := uint32(0); i < length; i++ { + value := int32(0) + err = binary.Read(r, binary.BigEndian, &value) + if err != nil { + return nil, err + } + result[i] = value + } + + d.objectRefs = append(d.objectRefs, result) + + return +} + +func (d *Decoder) DecodeAmf3VectorUint(r io.Reader, decodeMarker bool) (result []uint32, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_VECTOR_UINT_MARKER); err != nil { + return + } + + isRef, refVal, length, _, err := d.decodeVectorInfo(r) + if err != nil { + return nil, err + } + if isRef { + if result, ok := d.objectRefs[refVal].([]uint32); ok { + return result, nil + } else { + return nil, Error("amf3 decode: unable to convert object ref to vector") + } + } + result = make([]uint32, length) + for i := uint32(0); i < length; i++ { + value := uint32(0) + err = binary.Read(r, binary.BigEndian, &value) + if err != nil { + return nil, err + } + result[i] = value + } + + d.objectRefs = append(d.objectRefs, result) + + return +} + +func (d *Decoder) DecodeAmf3VectorDouble(r io.Reader, decodeMarker bool) (result []float64, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_VECTOR_DOUBLE_MARKER); err != nil { + return + } + + isRef, refVal, length, _, err := d.decodeVectorInfo(r) + if err != nil { + return nil, err + } + if isRef { + if result, ok := d.objectRefs[refVal].([]float64); ok { + return result, nil + } else { + return nil, Error("amf3 decode: unable to convert object ref to vector") + } + } + result = make([]float64, length) + for i := uint32(0); i < length; i++ { + result[i], err = d.DecodeAmf3Double(r, false) + if err != nil { + return nil, err + } + } + + d.objectRefs = append(d.objectRefs, result) + + return +} + +func (d *Decoder) DecodeAmf3VectorObject(r io.Reader, decodeMarker bool) (result *ObjectVector, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_VECTOR_OBJECT_MARKER); err != nil { + return + } + + isRef, refVal, length, _, err := d.decodeVectorInfo(r) + if err != nil { + return nil, err + } + if isRef { + if result, ok := d.objectRefs[refVal].(*ObjectVector); ok { + return result, nil + } else { + return nil, Error("amf3 decode: unable to convert object ref to vector") + } + } + result = &ObjectVector{} + result.Data = make(Array, length) + typeName, err := d.DecodeAmf3String(r, false) + if err != nil { + return nil, Error("amf3 decode: unable to decode vector type name: %s", err) + } + result.Type = typeName + for i := uint32(0); i < length; i++ { + val, err := d.DecodeAmf3(r) + if err != nil { + return nil, err + } + result.Data[i] = val + } + + d.objectRefs = append(d.objectRefs, result) + + return +} + +func (d *Decoder) DecodeAmf3Dictionary(r io.Reader, decodeMarker bool) (result Dictionary, err error) { + if err = AssertMarker(r, decodeMarker, AMF3_DICTIONARY_MARKER); err != nil { + return + } + + isRef, refVal, err := d.decodeReferenceInt(r) + if err != nil { + return nil, Error("amf3 decode: unable to decode dictioanry reference and length: %s", err) + } + if isRef { + if result, ok := d.objectRefs[refVal].(Dictionary); ok { + return result, nil + } else { + return nil, Error("amf3 decode: unable to convert object ref to Dictionary") + } + } + _, err = ReadByte(r) + if err != nil { + return nil, Error("amf3 decode: unable to read weak-keys byte: %s", err) + } + result = make(Dictionary) + for i := uint32(0); i < refVal; i++ { + key, err := d.DecodeAmf3(r) + if err != nil { + return nil, Error("amf3 decode: unable to decode dictionary key: %s", err) + } + value, err := d.DecodeAmf3(r) + if err != nil { + return nil, Error("amf3 decode: unable to decode dictionary value: %s", err) + } + result[key] = value + } + + d.objectRefs = append(d.objectRefs, result) + + return +} + +func (d *Decoder) decodeVectorInfo(r io.Reader) (isRef bool, refVal uint32, length uint32, isFixed bool, err error) { + isRef, refVal, err = d.decodeReferenceInt(r) + if err != nil { + return isRef, refVal, 0, false, Error("amf3 decode: unable to decode vector reference and length: %s", err) + } + if isRef { + return isRef, refVal, 0, false, nil + } + fixedByte, err := ReadByte(r) + if err != nil { + return isRef, 0, refVal, false, Error("amf3 decode: unable to read vector fixed field: %s", err) + } + isFixed = fixedByte != 0 + return isRef, 0, refVal, isFixed, nil +} + func (d *Decoder) decodeU29(r io.Reader) (result uint32, err error) { var b byte diff --git a/util.go b/util.go index fd134a5..78c7c86 100644 --- a/util.go +++ b/util.go @@ -93,3 +93,13 @@ func AssertMarker(r io.Reader, checkMarker bool, m byte) error { return nil } + +func (d Dictionary) MarshalJSON() ([]byte, error) { + // This is inefficient and wrong, but works for most use cases + normalized := map[string]interface{}{} + for k, v := range d { + ks := fmt.Sprint(k) + normalized[ks] = v + } + return json.Marshal(normalized) +}