diff --git a/CHANGELOG.md b/CHANGELOG.md index 9018890dc6..09881225e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,7 @@ a huge list of updates and fixes. - SN no longer accepts objects with invocation or verification script bigger than 1KiB (#3887) - Every object with non-zero payload is now paid, not only regular ones (#3856) - Separate policer placement state from replica shortage (#3901) +- SN now forwards remote SN's response to the client as is (#3877) ### Removed - `node.persistent_sessions.path` config option from SN config (#3846) diff --git a/internal/protobuf/api.go b/internal/protobuf/api.go index 3165e41053..f8c7a277da 100644 --- a/internal/protobuf/api.go +++ b/internal/protobuf/api.go @@ -34,6 +34,13 @@ const ( 1 + 3 + object.MaxHeaderLen ) +// Common response field numbers. +const ( + FieldResponseBody = 1 + FieldResponseMetaHeader = 2 + FieldResponseVerificationHeader = 3 +) + // ParseAPIVersionField parses version.Version from the next field with known // number and type at given offset. Also returns field length. func ParseAPIVersionField(buf []byte, fNum protowire.Number, fTyp protowire.Type) (version.Version, int, error) { @@ -254,3 +261,21 @@ func ParseAttribute(buf []byte, fNum protowire.Number, fTyp protowire.Type) ([]b return k, v, nf + lnf, nil } + +// // VerifyObjectSplitInfo checks whether buf is a valid object split info +// // protobuf. +// // +// // Absense of any fields is ignored. Unknown fields are allowed and checked. +// // Repeating fields is allowed. +// func VerifyObjectSplitInfo(buf []byte) error { +// return verifyMessage(buf, objectSplitInfoMessageScheme) +// } +// +// // VerifyObjectHeaderWithOrder checks whether buf is a valid object header +// // protobuf. If so, direct field order flag is returned. +// // +// // Absense of any fields is ignored. Unknown fields are allowed and checked. +// // Repeating fields is allowed. 
+// func VerifyObjectHeaderWithOrder(buf []byte) (bool, error) { +// return verifyMessageWithOrder(buf, objectHeaderScheme, true, interceptors{}) +// } diff --git a/internal/protobuf/buffers.go b/internal/protobuf/buffers.go index 6ed1dc53a7..ffabbe315c 100644 --- a/internal/protobuf/buffers.go +++ b/internal/protobuf/buffers.go @@ -1,6 +1,8 @@ package protobuf import ( + "hash" + "io" "sync" "sync/atomic" @@ -95,3 +97,188 @@ func (x *MemBufferPool) Get() *MemBuffer { item.refs.Store(1) return item } + +// TODO: docs. +type BuffersSlice struct { + buffers mem.BufferSlice + curOff int + lastTo int +} + +func (x *BuffersSlice) Reset(buffers mem.BufferSlice) { + x.buffers = buffers + x.curOff = 0 + x.lastTo = buffers[len(buffers)-1].Len() +} + +// TODO: docs. +func NewBuffersSlice(buffers mem.BufferSlice) BuffersSlice { + if len(buffers) == 0 { + return BuffersSlice{} + } + + return BuffersSlice{ + buffers: buffers, + curOff: 0, + lastTo: buffers[len(buffers)-1].Len(), + } +} + +func (x BuffersSlice) IsEmpty() bool { + return x.Len() == 0 +} + +func (x *BuffersSlice) buffersSeq(yield func([]byte) bool) { + if len(x.buffers) == 0 { + return + } + + if len(x.buffers) == 1 { + yield(x.buffers[0].ReadOnlyData()[x.curOff:x.lastTo]) + return + } + + if !yield(x.buffers[0].ReadOnlyData()[x.curOff:]) { + return + } + + for i := range len(x.buffers) - 2 { + if !yield(x.buffers[i+1].ReadOnlyData()) { + return + } + } + + yield(x.buffers[len(x.buffers)-1].ReadOnlyData()[:x.lastTo]) +} + +func (x *BuffersSlice) bytesSeq(yield func(byte) bool) { + var buf []byte + for { + buf = x.buffers[0].ReadOnlyData() + if len(x.buffers) == 1 { + buf = buf[:x.lastTo] + } + + for x.curOff < len(buf) { + cnt := yield(buf[x.curOff]) + x.curOff++ + if !cnt { + if x.curOff == len(buf) { + x.buffers = x.buffers[1:] + x.curOff = 0 + } + return + } + } + + x.buffers = x.buffers[1:] + x.curOff = 0 + + if x.IsEmpty() { + return + } + } +} + +func (x *BuffersSlice) MoveNext(n int) (BuffersSlice, error) { 
+ if len(x.buffers) == 0 { + if n > 0 { + return BuffersSlice{}, io.ErrUnexpectedEOF + } + return BuffersSlice{}, nil + } + + sub := *x + var ln int + + for i := 0; ; i++ { + if i == 0 { + if len(x.buffers) == 1 { + ln = x.lastTo - x.curOff + } else { + ln = x.buffers[0].Len() - x.curOff + } + } else if i < len(x.buffers)-1 { + ln = x.buffers[i].Len() + } else { + ln = x.lastTo + } + + if n > ln { + if i == len(x.buffers)-1 { + break + } + n -= ln + continue + } + + if n < ln { + x.buffers = x.buffers[i:] + if i == 0 { + x.curOff += n + } else { + x.curOff = 0 + } + } else { + x.buffers = x.buffers[i+1:] + x.curOff = 0 + } + + sub.buffers = sub.buffers[:i+1] + sub.lastTo = n + if i == 0 { + sub.lastTo += sub.curOff + } + return sub, nil + } + + return BuffersSlice{}, io.ErrUnexpectedEOF +} + +// TODO: docs. +func (x *BuffersSlice) ReadOnlyData() []byte { + if len(x.buffers) == 0 { + return nil + } + + if len(x.buffers) == 1 { + return x.buffers[0].ReadOnlyData()[x.curOff:x.lastTo] + } + + buf := make([]byte, x.Len()) + return buf[:x.CopyTo(buf)] +} + +func (x BuffersSlice) Len() int { + if len(x.buffers) == 0 { + return 0 + } + + var ln int + for buf := range x.buffersSeq { + ln += len(buf) + } + + return ln +} + +func (x BuffersSlice) HashTo(h hash.Hash) { + if len(x.buffers) == 0 { + return + } + + for buf := range x.buffersSeq { + h.Write(buf) + } +} + +func (x BuffersSlice) CopyTo(dst []byte) int { + if len(x.buffers) == 0 { + return 0 + } + var n int + for buf := range x.buffersSeq { + n += copy(dst[n:], buf) + } + return n +} diff --git a/internal/protobuf/codecs.go b/internal/protobuf/codecs.go index e55b7a4cb6..99dba8804e 100644 --- a/internal/protobuf/codecs.go +++ b/internal/protobuf/codecs.go @@ -12,15 +12,26 @@ type BufferedCodec struct{} // Marshal implements [encoding.CodecV2]. 
func (BufferedCodec) Marshal(msg any) (mem.BufferSlice, error) { - if bs, ok := msg.(mem.Buffer); ok { - return mem.BufferSlice{bs}, nil + switch v := msg.(type) { + case mem.BufferSlice: + return v, nil + case mem.Buffer: + return mem.BufferSlice{v}, nil + default: + return encoding.GetCodecV2(proto.Name).Marshal(msg) } - return encoding.GetCodecV2(proto.Name).Marshal(msg) } // Unmarshal implements [encoding.CodecV2]. func (BufferedCodec) Unmarshal(data mem.BufferSlice, msg any) error { - return encoding.GetCodecV2(proto.Name).Unmarshal(data, msg) + switch v := msg.(type) { + case *mem.BufferSlice: + data.Ref() + *v = data + return nil + default: + return encoding.GetCodecV2(proto.Name).Unmarshal(data, msg) + } } // Name implements [encoding.CodecV2]. diff --git a/internal/protobuf/errors.go b/internal/protobuf/errors.go index 0c2341abf2..553a94a92d 100644 --- a/internal/protobuf/errors.go +++ b/internal/protobuf/errors.go @@ -23,3 +23,9 @@ func NewRepeatedFieldError(n protowire.Number) error { func NewUnsupportedFieldError(n protowire.Number, t protowire.Type) error { return fmt.Errorf("unsupported field #%d of type %v", n, t) } + +// NewInvalidUTF8Error returns common error for string field #n containing +// invalid UTF-8. +func NewInvalidUTF8Error(n protowire.Number) error { + return fmt.Errorf("string field #%d contains invalid UTF-8", n) +} diff --git a/internal/protobuf/parsers.go b/internal/protobuf/parsers.go index b11f6c57c2..1435f058ad 100644 --- a/internal/protobuf/parsers.go +++ b/internal/protobuf/parsers.go @@ -1,13 +1,19 @@ package protobuf import ( + "encoding/binary" "errors" "fmt" + "io" "math" + "unicode/utf8" "google.golang.org/protobuf/encoding/protowire" ) +// same as in [protowire] package. +var errVarintOverflow = errors.New("variable length integer overflow") + // ParseVarint parses varint-encoded uint64 from buf. Returns parsed value and // number of bytes read. 
func ParseVarint(buf []byte) (uint64, int, error) { @@ -20,39 +26,86 @@ func ParseVarint(buf []byte) (uint64, int, error) { return u, n, nil } -// ParseTag parses field tag from buf. Returns field number, type and number of -// bytes read. -func ParseTag(buf []byte) (protowire.Number, protowire.Type, int, error) { - u, n, err := ParseVarint(buf) +// [ParseVarint] analogue for scanning. +func (x *BuffersSlice) parseVarint() (uint64, error) { + var u uint64 + var s uint + var i int + for b := range x.bytesSeq { + if i == binary.MaxVarintLen64 { + // Catch byte reads past MaxVarintLen64. + // See issue https://golang.org/issues/41185 + return 0, errVarintOverflow // TODO: have something better + } + if b < 0x80 { + if i == binary.MaxVarintLen64-1 && b > 1 { + return 0, errVarintOverflow // overflow + } + return u | uint64(b)< math.MaxInt { + return 0, fmt.Errorf("value %d overflows int", u) + } + + return int(u), nil +} + // ParseLEN parses varint-encoded length from buf and check its overflow. Returns // parsed value and number of bytes read. func ParseLEN(buf []byte) (int, int, error) { - ln, n, err := ParseVarint(buf) + u, n, err := ParseVarint(buf) + ln, err := _checkLEN(u, err) if err != nil { - return 0, 0, fmt.Errorf("parse varint: %w", err) - } - - if ln > math.MaxInt { - return 0, 0, fmt.Errorf("value %d overflows int", ln) + return 0, 0, err } - if rem := len(buf) - n; int(ln) > rem { - return 0, 0, newTruncatedBufferError(int(ln), rem) + if rem := len(buf) - n; ln > rem { + return 0, 0, newTruncatedBufferError(ln, rem) } - return int(ln), n, nil + return ln, n, nil } // ParseLENField parses length of LEN field with preread number and type from @@ -73,6 +126,26 @@ func ParseLENField(buf []byte, num protowire.Number, typ protowire.Type) (int, i return ln, n, nil } +// [ParseLENField] analogue for scanning. 
+func (x *BuffersSlice) ParseLENField(num protowire.Number, typ protowire.Type) (BuffersSlice, error) { + err := checkFieldType(num, protowire.BytesType, typ) + if err != nil { + return BuffersSlice{}, err + } + + ln, err := _checkLEN(x.parseVarint()) + if err != nil { + return BuffersSlice{}, wrapParseFieldError(num, protowire.BytesType, err) + } + + sub, err := x.MoveNext(ln) + if err != nil { + return BuffersSlice{}, wrapParseFieldError(num, protowire.BytesType, err) + } + + return sub, nil +} + // ParseLENFieldBounds parses boundaries of LEN field with preread tag length, // number and type at given offset from buf. // @@ -91,19 +164,65 @@ func ParseLENFieldBounds(buf []byte, off int, tagLn int, num protowire.Number, t return f, nil } +// ParseStringField parses string field with preread number and type from buf. +// Returns parsed value and number of bytes read. +// +// If there is an error, its text contains num and typ. +func ParseStringField(buf []byte, num protowire.Number, typ protowire.Type) (int, int, error) { + ln, n, err := ParseLENField(buf, num, typ) + if err != nil { + return 0, 0, err + } + + if !utf8.Valid(buf[n:][:ln]) { + return 0, 0, NewInvalidUTF8Error(num) + } + + return ln, n, nil +} + +// [ParseStringField] analogue for scanning. +func (x *BuffersSlice) ParseStringField(num protowire.Number, typ protowire.Type) ([]byte, error) { + sub, err := x.ParseLENField(num, typ) + if err != nil { + return nil, err + } + + s := sub.ReadOnlyData() + if !utf8.Valid(s) { + return nil, NewInvalidUTF8Error(num) + } + + return s, nil +} + +func _checkEnum[T ~int32](u uint64, err error) (T, error) { + if err != nil { + return 0, fmt.Errorf("parse varint: %w", err) + } + + if u > math.MaxInt32 { + return 0, fmt.Errorf("value %d overflows int32", u) + } + + return T(u), nil +} + // ParseEnum parses enum value from buf. Returns parsed value and number of // bytes read. 
func ParseEnum[T ~int32](buf []byte) (T, int, error) { u, n, err := ParseVarint(buf) + t, err := _checkEnum[T](u, err) if err != nil { - return 0, 0, fmt.Errorf("parse varint: %w", err) + return 0, 0, err } - if u > math.MaxInt32 { - return 0, 0, fmt.Errorf("value %d overflows int32", u) - } + return t, n, nil +} - return T(u), n, nil +// [ParseEnum] analogue for scanning. +func (x *BuffersSlice) parseEnum() (int32, error) { + return _checkEnum[int32](x.parseVarint()) } // ParseEnumField parses value of enum field with preread number and type from @@ -124,19 +243,43 @@ func ParseEnumField[T ~int32](buf []byte, num protowire.Number, typ protowire.Ty return e, n, nil } +// [ParseEnumField] analogue for scanning. +func (x *BuffersSlice) ParseEnumField(num protowire.Number, typ protowire.Type) (int32, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return 0, err + } + + e, err := x.parseEnum() + if err != nil { + return 0, wrapParseFieldError(num, protowire.VarintType, err) + } + + return e, nil +} + +func _checkUint32(u uint64, err error) (uint32, error) { + if err != nil { + return 0, fmt.Errorf("parse varint: %w", err) + } + + if u > math.MaxUint32 { + return 0, fmt.Errorf("value %d overflows uint32", u) + } + + return uint32(u), nil +} + // ParseUint32 parses varint-encoded uint32 from buf. Returns parsed value and // number of bytes read. func ParseUint32(buf []byte) (uint32, int, error) { u, n, err := ParseVarint(buf) + u32, err := _checkUint32(u, err) if err != nil { - return 0, 0, fmt.Errorf("parse varint: %w", err) - } - - if u > math.MaxUint32 { - return 0, 0, fmt.Errorf("value %d overflows uint32", u) + return 0, 0, err } - return uint32(u), n, nil + return u32, n, nil } // ParseUint32Field parses value of uint32 field from buf. Returns parsed value @@ -157,6 +300,21 @@ func ParseUint32Field(buf []byte, num protowire.Number, typ protowire.Type) (uin return u, n, nil } +// [ParseUint32Field] analogue for scanning. 
+func (x *BuffersSlice) ParseUint32Field(num protowire.Number, typ protowire.Type) (uint32, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return 0, err + } + + u, err := _checkUint32(x.parseVarint()) + if err != nil { + return 0, wrapParseFieldError(num, protowire.VarintType, err) + } + + return u, nil +} + // ParseUint64Field parses value of uint64 field with preread number and type // from buf. Returns value and its length. // @@ -175,6 +333,67 @@ func ParseUint64Field(buf []byte, num protowire.Number, typ protowire.Type) (uin return u, n, nil } +// [ParseUint64Field] analogue for scanning. +func (x *BuffersSlice) ParseUint64Field(num protowire.Number, typ protowire.Type) (uint64, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return 0, err + } + + u, err := x.parseVarint() + if err != nil { + return 0, wrapParseFieldError(num, protowire.VarintType, fmt.Errorf("parse varint: %w", err)) + } + + return u, nil +} + +func _checkBool(u uint64, err error) (bool, error) { + if err != nil { + return false, err + } + + if u > 1 { + return false, fmt.Errorf("unexpected varint value for bool field %d", u) + } + + return u == 1, nil +} + +// ParseBoolField parses value of bool field with preread number and type from +// buf. Returns parsed value. +// +// If there is an error, its text contains num and typ. +func ParseBoolField(buf []byte, num protowire.Number, typ protowire.Type) (bool, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return false, err + } + + u, _, err := ParseVarint(buf) + b, err := _checkBool(u, err) + if err != nil { + return false, wrapParseFieldError(num, protowire.VarintType, err) + } + + return b, nil +} + +// [ParseBoolField] analogue for scanning. 
+func (x *BuffersSlice) ParseBoolField(num protowire.Number, typ protowire.Type) (bool, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return false, err + } + + b, err := _checkBool(x.parseVarint()) + if err != nil { + return false, wrapParseFieldError(num, protowire.VarintType, err) + } + + return b, nil +} + // SkipField parses length of skipped field with preread number and type from // buf and checks its overflow. Returns number of bytes read. // @@ -211,3 +430,45 @@ func SkipField(buf []byte, num protowire.Number, typ protowire.Type) (int, error return 0, wrapParseFieldError(num, typ, err) } + +// SkipRepeatedVarint parses repeated enum field with preread number and type +// from buf, checks its overflow and verifies each element is a non-negative +// int32. Returns number of bytes read. +// +// If verification of each element is not needed, use [SkipField]. +// +// If there is an error, its text contains num and typ. +func SkipRepeatedEnum(buf []byte, num protowire.Number, typ protowire.Type) (int, error) { + ln, n, err := ParseLENField(buf, num, typ) + if err != nil { + return n, err + } + + buf = buf[n:][:ln] + + var off int + for len(buf[off:]) > 0 { + if _, n, err = ParseEnum[int32](buf[off:]); err != nil { + return 0, fmt.Errorf("parse next element: %w", err) + } + off += n + } + + return n + ln, nil +} + +// [SkipRepeatedEnum] analogue for scanning. 
+func (x *BuffersSlice) SkipRepeatedEnum(num protowire.Number, typ protowire.Type) error { + sub, err := x.ParseLENField(num, typ) + if err != nil { + return err + } + + for !sub.IsEmpty() { + if _, err = sub.parseEnum(); err != nil { + return fmt.Errorf("parse next element: %w", err) + } + } + + return nil +} diff --git a/internal/protobuf/parsers_test.go b/internal/protobuf/parsers_test.go index 9af32cc41b..853f8c0aa8 100644 --- a/internal/protobuf/parsers_test.go +++ b/internal/protobuf/parsers_test.go @@ -17,6 +17,8 @@ var ( uint32OverflowVarint = []byte{128, 128, 128, 128, 16} // 4294967296 uint64OverflowVarint = []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 2} + + invalidUTF8 = []byte("\xF4\x90\x80\x80") ) var varintTestcases = []struct { @@ -253,6 +255,61 @@ func TestParseLENFieldBounds(t *testing.T) { } } +func TestParseStringField(t *testing.T) { + t.Run("wrong type", func(t *testing.T) { + _, _, err := iprotobuf.ParseStringField([]byte{}, 42, protowire.VarintType) + require.EqualError(t, err, "wrong type of field #42: expected LEN, got VARINT") + }) + + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseStringField(tc.buf, 42, protowire.BytesType) + require.ErrorContains(t, err, "parse field #42 of LEN type: parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, uint64(math.MaxInt+1)) + + _, _, err := iprotobuf.ParseStringField(buf[:n], 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(tc.buf, make([]byte, tc.val-1)) + _, _, err := 
iprotobuf.ParseStringField(buf, 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + + t.Run("invalid UTF-8", func(t *testing.T) { + buf := protowire.AppendBytes(nil, invalidUTF8) + _, _, err := iprotobuf.ParseStringField(buf, 42, protowire.BytesType) + require.EqualError(t, err, "string field #42 contains invalid UTF-8") + }) + + for i, tc := range varintTestcases { + if tc.val > 1<<20 { + continue + } + + buf := make([]byte, tc.val) + u, n, err := iprotobuf.ParseStringField(slices.Concat(tc.buf, buf), 42, protowire.BytesType) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + func TestParseEnum(t *testing.T) { t.Run("invalid varint", func(t *testing.T) { for _, tc := range invalidVarintTestcases { diff --git a/internal/protobuf/protoscan/messages.go b/internal/protobuf/protoscan/messages.go new file mode 100644 index 0000000000..58deabed94 --- /dev/null +++ b/internal/protobuf/protoscan/messages.go @@ -0,0 +1,362 @@ +package protoscan + +import ( + "fmt" + + "github.com/nspcc-dev/neofs-node/internal/protobuf" + protoobject "github.com/nspcc-dev/neofs-sdk-go/proto/object" + protorefs "github.com/nspcc-dev/neofs-sdk-go/proto/refs" + protosession "github.com/nspcc-dev/neofs-sdk-go/proto/session" + protostatus "github.com/nspcc-dev/neofs-sdk-go/proto/status" + "google.golang.org/protobuf/encoding/protowire" +) + +type simpleField struct { + num protowire.Number + namedField +} + +func newSimpleField(num protowire.Number, name string, typ fieldType) simpleField { + return simpleField{num: num, namedField: newNamedField(name, typ)} +} + +func newSimpleFieldsScheme(fs ...simpleField) MessageScheme { + m := make(map[protowire.Number]namedField, len(fs)) + + for _, f := range fs { + if _, ok := m[f.num]; ok { + panic(fmt.Sprintf("duplicated field 
with number %d", f.num)) + } + + m[f.num] = f.namedField + } + + return MessageScheme{fields: m} +} + +// Simple messages. +var ( + versionScheme = newSimpleFieldsScheme( + newSimpleField(protorefs.FieldVersionMajor, "major", fieldTypeUint32), + newSimpleField(protorefs.FieldVersionMinor, "minor", fieldTypeUint32), + ) + ChecksumScheme = newSimpleFieldsScheme( + newSimpleField(protorefs.FieldChecksumType, "type", fieldTypeEnum), + newSimpleField(protorefs.FieldChecksumValue, "value", fieldTypeBytes), + ) + signatureScheme = newSimpleFieldsScheme( + newSimpleField(protorefs.FieldSignatureKey, "key", fieldTypeBytes), + newSimpleField(protorefs.FieldSignatureValue, "value", fieldTypeBytes), + newSimpleField(protorefs.FieldSignatureScheme, "scheme", fieldTypeEnum), + ) + attributeScheme = newSimpleFieldsScheme( + newSimpleField(1, "key", fieldTypeString), + newSimpleField(2, "value", fieldTypeString), + ) + tokenLifetimeScheme = newSimpleFieldsScheme( + newSimpleField(1, "exp", fieldTypeUint64), + newSimpleField(2, "nbf", fieldTypeUint64), + newSimpleField(3, "iat", fieldTypeUint64), + ) +) + +func newIDScheme(kind binaryFieldKind) MessageScheme { + return MessageScheme{ + fields: map[protowire.Number]namedField{1: {name: "value", typ: fieldTypeBytes}}, + binaryKindFields: map[protowire.Number]binaryFieldKind{1: kind}, + } +} + +// IDs. +var ( + containerIDScheme = newIDScheme(binaryFieldSHA256) + ObjectIDScheme = newIDScheme(binaryFieldSHA256) + userIDScheme = newIDScheme(binaryFieldN3Address) +) + +// Session token. 
+var ( + sessionSubjectScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldTargetOwnerID: newNamedField("owner", fieldTypeNestedMessage), + protosession.FieldTargetNNSName: newNamedField("NNS name", fieldTypeString), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldTargetOwnerID: userIDScheme, + }, + } + sessionContextScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldSessionContextV2Container: newNamedField("container", fieldTypeNestedMessage), + protosession.FieldSessionContextV2Verbs: newNamedField("verbs", fieldTypeRepeatedEnum), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldSessionContextV2Container: containerIDScheme, + }, + } + sessionTokenBodyScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldSessionTokenV2BodyVersion: newNamedField("version", fieldTypeUint32), + protosession.FieldSessionTokenV2BodyAppdata: newNamedField("appdata", fieldTypeBytes), + protosession.FieldSessionTokenV2BodyIssuer: newNamedField("issuer", fieldTypeNestedMessage), + protosession.FieldSessionTokenV2BodySubjects: newNamedField("subject", fieldTypeNestedMessage), + protosession.FieldSessionTokenV2BodyLifetime: newNamedField("lifetime", fieldTypeNestedMessage), + protosession.FieldSessionTokenV2BodyContexts: newNamedField("contexts", fieldTypeNestedMessage), + protosession.FieldSessionTokenV2BodyFinal: newNamedField("final", fieldTypeBool), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldSessionTokenV2BodyIssuer: userIDScheme, + protosession.FieldSessionTokenV2BodySubjects: sessionSubjectScheme, + protosession.FieldSessionTokenV2BodyLifetime: tokenLifetimeScheme, + protosession.FieldSessionTokenV2BodyContexts: sessionContextScheme, + }, + } + sessionTokenScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldSessionTokenV2Body: newNamedField("body", 
fieldTypeNestedMessage), + protosession.FieldSessionTokenV2Signature: newNamedField("signature", fieldTypeNestedMessage), + protosession.FieldSessionTokenV2Origin: newNamedField("origin", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldSessionTokenV2Body: sessionTokenBodyScheme, + protosession.FieldSessionTokenV2Signature: signatureScheme, + }, + recursionFields: []protowire.Number{protosession.FieldSessionTokenV2Origin}, + } +) + +// Session V1. +var ( + sessionV1ObjectTargetScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldObjectSessionContextTargetContainer: newNamedField("container", fieldTypeNestedMessage), + protosession.FieldObjectSessionContextTargetObjects: newNamedField("objects", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldObjectSessionContextTargetContainer: containerIDScheme, + protosession.FieldObjectSessionContextTargetObjects: ObjectIDScheme, + }, + } + sessionV1ObjectContextScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldObjectSessionContextVerb: newNamedField("verb", fieldTypeEnum), + protosession.FieldObjectSessionContextTarget: newNamedField("target", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldObjectSessionContextTargetObjects: sessionV1ObjectTargetScheme, + }, + } + sessionV1ContainerContextScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldContainerSessionContextVerb: newNamedField("verb", fieldTypeEnum), + protosession.FieldContainerSessionContextWildcard: newNamedField("wildcard", fieldTypeBool), + protosession.FieldContainerSessionContextContainerID: newNamedField("container", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldContainerSessionContextContainerID: containerIDScheme, + }, + } + 
sessionV1TokenBodyScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldSessionTokenBodyID: newNamedField("id", fieldTypeBytes), + protosession.FieldSessionTokenBodyOwnerID: newNamedField("owner", fieldTypeNestedMessage), + protosession.FieldSessionTokenBodyLifetime: newNamedField("lifetime", fieldTypeNestedMessage), + protosession.FieldSessionTokenBodySessionKey: newNamedField("session key", fieldTypeBytes), + protosession.FieldSessionTokenBodyObject: newNamedField("object context", fieldTypeNestedMessage), + protosession.FieldSessionTokenBodyContainer: newNamedField("container context", fieldTypeNestedMessage), + }, + binaryKindFields: map[protowire.Number]binaryFieldKind{ + protosession.FieldSessionTokenBodyID: binaryFieldKindUUIDV4, + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldSessionTokenBodyOwnerID: userIDScheme, + protosession.FieldSessionTokenBodyLifetime: tokenLifetimeScheme, + protosession.FieldSessionTokenBodyObject: sessionV1ObjectContextScheme, + protosession.FieldSessionTokenBodyContainer: sessionV1ContainerContextScheme, + }, + } + sessionV1TokenScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldSessionTokenBody: newNamedField("body", fieldTypeNestedMessage), + protosession.FieldSessionTokenSignature: newNamedField("signature", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldSessionTokenBody: sessionV1TokenBodyScheme, + protosession.FieldSessionTokenSignature: signatureScheme, + }, + } +) + +// Object. 
+var ( + objectSplitHeaderScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protoobject.FieldHeaderSplitParent: newNamedField("parent", fieldTypeNestedMessage), + protoobject.FieldHeaderSplitPrevious: newNamedField("previous", fieldTypeNestedMessage), + protoobject.FieldHeaderSplitParentSignature: newNamedField("parent signature", fieldTypeNestedMessage), + protoobject.FieldHeaderSplitParentHeader: newNamedField("parent header", fieldTypeNestedMessage), + protoobject.FieldHeaderSplitChildren: newNamedField("children", fieldTypeNestedMessage), + protoobject.FieldHeaderSplitSplitID: newNamedField("split ID", fieldTypeBytes), + protoobject.FieldHeaderSplitFirst: newNamedField("first", fieldTypeNestedMessage), + }, + binaryKindFields: map[protowire.Number]binaryFieldKind{ + protoobject.FieldHeaderSplitSplitID: binaryFieldKindUUIDV4, + }, + nestedFields: map[protowire.Number]MessageScheme{ + protoobject.FieldHeaderSplitParent: ObjectIDScheme, + protoobject.FieldHeaderSplitPrevious: ObjectIDScheme, + protoobject.FieldHeaderSplitParentSignature: signatureScheme, + protoobject.FieldHeaderSplitChildren: ObjectIDScheme, + protoobject.FieldHeaderSplitFirst: ObjectIDScheme, + }, + nestedAliases: map[protowire.Number]schemeAlias{ + protoobject.FieldHeaderSplitParentHeader: schemeAliasObjectHeader, + }, + } + ObjectHeaderScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protoobject.FieldHeaderVersion: newNamedField("version", fieldTypeNestedMessage), + protoobject.FieldHeaderContainerID: newNamedField("container ID", fieldTypeNestedMessage), + protoobject.FieldHeaderOwnerID: newNamedField("owner ID", fieldTypeNestedMessage), + protoobject.FieldHeaderCreationEpoch: newNamedField("creation epoch", fieldTypeUint64), + protoobject.FieldHeaderPayloadLength: newNamedField("payload length", fieldTypeUint64), + protoobject.FieldHeaderPayloadHash: newNamedField("payload hash", fieldTypeNestedMessage), + protoobject.FieldHeaderObjectType: 
newNamedField("type", fieldTypeEnum), + protoobject.FieldHeaderHomomorphicHash: newNamedField("homomorphic hash", fieldTypeNestedMessage), + protoobject.FieldHeaderSessionToken: newNamedField("session V1 token", fieldTypeNestedMessage), + protoobject.FieldHeaderAttributes: newNamedField("attribute", fieldTypeNestedMessage), + protoobject.FieldHeaderSplit: newNamedField("split", fieldTypeNestedMessage), + protoobject.FieldHeaderSessionV2: newNamedField("session token", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protoobject.FieldHeaderVersion: versionScheme, + protoobject.FieldHeaderContainerID: containerIDScheme, + protoobject.FieldHeaderOwnerID: userIDScheme, + protoobject.FieldHeaderPayloadHash: ChecksumScheme, + protoobject.FieldHeaderHomomorphicHash: ChecksumScheme, + protoobject.FieldHeaderAttributes: attributeScheme, + protoobject.FieldHeaderSessionToken: sessionV1TokenScheme, + protoobject.FieldHeaderSplit: objectSplitHeaderScheme, + protoobject.FieldHeaderSessionV2: sessionTokenScheme, + }, + } + ObjectHeaderWithSignatureScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protoobject.FieldHeaderWithSignatureHeader: newNamedField("header", fieldTypeNestedMessage), + protoobject.FieldHeaderWithSignatureSignature: newNamedField("signature", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protoobject.FieldHeaderWithSignatureHeader: ObjectHeaderScheme, + protoobject.FieldHeaderWithSignatureSignature: signatureScheme, + }, + } + ObjectSplitInfoScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protoobject.FieldSplitInfoSplitID: newNamedField("split ID", fieldTypeBytes), + protoobject.FieldSplitInfoLastPart: newNamedField("last part", fieldTypeNestedMessage), + protoobject.FieldSplitInfoLink: newNamedField("link", fieldTypeNestedMessage), + protoobject.FieldSplitInfoFirstPart: newNamedField("first part", fieldTypeNestedMessage), + }, + binaryKindFields: 
map[protowire.Number]binaryFieldKind{ + protoobject.FieldSplitInfoSplitID: binaryFieldKindUUIDV4, + }, + nestedFields: map[protowire.Number]MessageScheme{ + protoobject.FieldSplitInfoLastPart: ObjectIDScheme, + protoobject.FieldSplitInfoLink: ObjectIDScheme, + protoobject.FieldSplitInfoFirstPart: ObjectIDScheme, + }, + } +) + +// Responses. +var ( + responseStatusDetailScheme = newSimpleFieldsScheme( + newSimpleField(protostatus.FieldStatusDetailID, "ID", fieldTypeUint32), + newSimpleField(protostatus.FieldStatusDetailValue, "value", fieldTypeBytes), + ) + ResponseStatusScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protostatus.FieldStatusCode: newNamedField("code", fieldTypeUint32), + protostatus.FieldStatusMessage: newNamedField("message", fieldTypeString), + protostatus.FieldStatusDetails: newNamedField("details", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protostatus.FieldStatusDetails: responseStatusDetailScheme, + }, + } + ResponseMetaHeaderScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldResponseMetaHeaderVersion: newNamedField("version", fieldTypeNestedMessage), + protosession.FieldResponseMetaHeaderEpoch: newNamedField("epoch", fieldTypeUint64), + protosession.FieldResponseMetaHeaderTTL: newNamedField("TTL", fieldTypeUint32), + protosession.FieldResponseMetaHeaderXHeaders: newNamedField("X-headers", fieldTypeNestedMessage), + protosession.FieldResponseMetaHeaderOrigin: newNamedField("origin", fieldTypeNestedMessage), + protosession.FieldResponseMetaHeaderStatus: newNamedField("status", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldResponseMetaHeaderVersion: versionScheme, + protosession.FieldResponseMetaHeaderXHeaders: attributeScheme, + protosession.FieldResponseMetaHeaderStatus: ResponseStatusScheme, + }, + recursionFields: []protowire.Number{protosession.FieldResponseMetaHeaderOrigin}, + } + 
ResponseVerificationHeaderScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protosession.FieldResponseVerificationHeaderBodySignature: newNamedField("body signature", fieldTypeNestedMessage), + protosession.FieldResponseVerificationHeaderMetaSignature: newNamedField("meta signature", fieldTypeNestedMessage), + protosession.FieldResponseVerificationHeaderOriginSignature: newNamedField("origin signature", fieldTypeNestedMessage), + protosession.FieldResponseVerificationHeaderOrigin: newNamedField("origin", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protosession.FieldResponseVerificationHeaderBodySignature: signatureScheme, + protosession.FieldResponseVerificationHeaderMetaSignature: signatureScheme, + protosession.FieldResponseVerificationHeaderOriginSignature: signatureScheme, + }, + recursionFields: []protowire.Number{protosession.FieldResponseVerificationHeaderOrigin}, + } + ResponseScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protobuf.FieldResponseBody: newNamedField("body", fieldTypeNestedMessage), + protobuf.FieldResponseMetaHeader: newNamedField("meta header", fieldTypeNestedMessage), + protobuf.FieldResponseVerificationHeader: newNamedField("verification header", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protobuf.FieldResponseMetaHeader: ResponseMetaHeaderScheme, + protobuf.FieldResponseVerificationHeader: ResponseVerificationHeaderScheme, + }, + } + ObjectHeadResponseBodyScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protoobject.FieldHeadResponseBodyHeader: newNamedField("header", fieldTypeNestedMessage), + protoobject.FieldHeadResponseBodyShortHeader: newNamedField("short header", fieldTypeNestedMessage), + protoobject.FieldHeadResponseBodySplitInfo: newNamedField("split info", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protoobject.FieldHeadResponseBodyHeader: 
ObjectHeaderWithSignatureScheme, + protoobject.FieldHeadResponseBodySplitInfo: ObjectSplitInfoScheme, + }, + } + ObjectGetResponseInitScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protoobject.FieldGetResponseBodyInitObjectID: newNamedField("ID", fieldTypeNestedMessage), + protoobject.FieldGetResponseBodyInitSignature: newNamedField("signature", fieldTypeNestedMessage), + protoobject.FieldGetResponseBodyInitHeader: newNamedField("header", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protoobject.FieldGetResponseBodyInitObjectID: ObjectIDScheme, + protoobject.FieldGetResponseBodyInitSignature: signatureScheme, + protoobject.FieldGetResponseBodyInitHeader: ObjectHeaderScheme, + }, + } + ObjectGetResponseBodyScheme = MessageScheme{ + fields: map[protowire.Number]namedField{ + protoobject.FieldGetResponseBodyInit: newNamedField("init", fieldTypeNestedMessage), + protoobject.FieldGetResponseBodyChunk: newNamedField("chunk", fieldTypeBytes), + protoobject.FieldGetResponseBodySplitInfo: newNamedField("split info", fieldTypeNestedMessage), + }, + nestedFields: map[protowire.Number]MessageScheme{ + protoobject.FieldGetResponseBodyInit: ObjectGetResponseInitScheme, + protoobject.FieldGetResponseBodySplitInfo: ObjectSplitInfoScheme, + }, + } +) diff --git a/internal/protobuf/protoscan/scan.go b/internal/protobuf/protoscan/scan.go new file mode 100644 index 0000000000..ad9578c049 --- /dev/null +++ b/internal/protobuf/protoscan/scan.go @@ -0,0 +1,220 @@ +package protoscan + +import ( + "errors" + "fmt" + "slices" + + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "google.golang.org/protobuf/encoding/protowire" +) + +// ErrContinue is a continuation error. 
+var ErrContinue = errors.New("continue") + +type commonOptions = struct { + InterceptUint32 func(protowire.Number, uint32) error + InterceptUint64 func(protowire.Number, uint64) error + InterceptUintEnum func(protowire.Number, int32) error + InterceptBool func(protowire.Number, bool) error + InterceptString func(protowire.Number, []byte) error + // If InterceptBytes returns nil, [ScanMessageOrdered] does not verify its + // specified format. In this case, the caller is responsible for processing the + // field itself. To continue scanning the field as usual, the function should + // return [ErrContinue]. + InterceptBytes func(protowire.Number, iprotobuf.BuffersSlice) error +} + +// ScanMessageOrderedOptions groups optional [ScanMessageOrdered] parameters. +// +// Interceptors allow to intercept typed field values of fields declared in +// message scheme to handle them specifically. If the function returns an error, +// [ScanMessageOrdered] immediately returns it as is generally. +type ScanMessageOrderedOptions struct { + commonOptions + // If InterceptNested returns nil, [ScanMessageOrdered] does not scan its + // argument. In this case, the caller is responsible for processing the field + // itself. To continue scanning the field as usual, the function should return + // [ErrContinue]. + InterceptNested func(protowire.Number, iprotobuf.BuffersSlice, bool) (bool, error) +} + +// ScanMessageOrdered checks whether buf contains a complete and valid Protocol +// Buffers V3 message according to the given scheme. It goes over each field +// one-by-one and checks that it is encoded correctly. If the field format is +// declared in the schema, ScanMessageOrdered additionally checks for compliance +// with it. If the field is unknown, ScanMessageOrdered simply checks the +// encoding and skips it. Field repetition is not checked. +// +// Package provides NeoFS API protocol schemes. 
With a zero scheme, +// ScanMessageOrdered verifies that buf is a valid Protocol Buffers V3 message +// overall. +// +// Boolean returns is a flag of direct field order: it states whether fields are +// arranged in ascending numerical order in all messages at all nesting levels. +func ScanMessageOrdered(buffers iprotobuf.BuffersSlice, scheme MessageScheme, opts ScanMessageOrderedOptions) (bool, error) { + return scanMessageOrdered(buffers, scheme, true, opts) +} + +// ScanMessageOptions groups optional [ScanMessage] parameters. +// +// Interceptors allow to intercept typed field values of fields declared in +// message scheme to handle them specifically. If the function returns an error, +// [ScanMessage] immediately returns it as is generally. +type ScanMessageOptions struct { + commonOptions + // If InterceptNested returns nil, [ScanMessage] does not scan its argument. In + // this case, the caller is responsible for processing the field itself. To + // continue scanning the field as usual, the function should return + // [ErrContinue]. + InterceptNested func(protowire.Number, iprotobuf.BuffersSlice) error +} + +// ScanMessage is an alternative for [ScanMessageOrdered] when field order does +// not matter. 
+func ScanMessage(buffers iprotobuf.BuffersSlice, scheme MessageScheme, opts ScanMessageOptions) error { + var orderedOpts ScanMessageOrderedOptions + orderedOpts.commonOptions = opts.commonOptions + if opts.InterceptNested != nil { + orderedOpts.InterceptNested = func(num protowire.Number, buffers iprotobuf.BuffersSlice, _ bool) (bool, error) { + return false, opts.InterceptNested(num, buffers) + } + } + + _, err := scanMessageOrdered(buffers, scheme, false, orderedOpts) + return err +} + +func scanMessageOrdered(buffers iprotobuf.BuffersSlice, scheme MessageScheme, checkOrder bool, opts ScanMessageOrderedOptions) (bool, error) { + var prevNum protowire.Number + + for !buffers.IsEmpty() { + num, wireTyp, err := buffers.ParseTag() + if err != nil { + return false, fmt.Errorf("parse next tag: %w", err) + } + + if checkOrder { + if num < prevNum { + checkOrder = false + } else { + prevNum = num + } + } + + switch f, _ := scheme.fields[num]; f.typ { + default: + return false, iprotobuf.NewUnsupportedFieldError(num, wireTyp) + case fieldTypeUint32: + v, err := buffers.ParseUint32Field(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + if opts.InterceptUint32 != nil { + if err = opts.InterceptUint32(num, v); err != nil { + return false, err + } + } + case fieldTypeUint64: + v, err := buffers.ParseUint64Field(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + if opts.InterceptUint64 != nil { + if err = opts.InterceptUint64(num, v); err != nil { + return false, err + } + } + case fieldTypeEnum: + e, err := buffers.ParseEnumField(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + if opts.InterceptUintEnum != nil { + if err = opts.InterceptUintEnum(num, e); err != nil { + return false, err + } + } + case fieldTypeBool: + v, err := buffers.ParseBoolField(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + if opts.InterceptBool != nil { + if err = 
opts.InterceptBool(num, v); err != nil { + return false, err + } + } + case fieldTypeRepeatedEnum: + err := buffers.SkipRepeatedEnum(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + case fieldTypeString: + v, err := buffers.ParseStringField(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + if opts.InterceptString != nil { + if err = opts.InterceptString(num, v); err != nil { + return false, err + } + } + case fieldTypeBytes: + v, err := buffers.ParseLENField(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + if opts.InterceptBytes != nil { + if err = opts.InterceptBytes(num, v); err == nil { + break + } else if !errors.Is(err, ErrContinue) { + return false, err + } + } + if kind, ok := scheme.binaryKindFields[num]; ok { + if err = verifyBinaryField(kind, v); err != nil { + return false, newParseFieldError(f, fmt.Errorf("invalid binary field of %s kind: %w", kind, err)) + } + } + case fieldTypeNestedMessage: + v, err := buffers.ParseLENField(num, wireTyp) + if err != nil { + return false, newParseFieldError(f, err) + } + + if opts.InterceptNested != nil { + if checkOrder, err = opts.InterceptNested(num, v, checkOrder); err == nil { + break + } else if !errors.Is(err, ErrContinue) { + return false, err + } + } + + if slices.Contains(scheme.recursionFields, num) { + if checkOrder, err = scanMessageOrdered(v, scheme, checkOrder, ScanMessageOrderedOptions{}); err != nil { + return false, newParseFieldError(f, err) + } + break + } + + if nestedScheme, ok := scheme.nestedFields[num]; ok { + if checkOrder, err = scanMessageOrdered(v, nestedScheme, checkOrder, ScanMessageOrderedOptions{}); err != nil { + return false, newParseFieldError(f, err) + } + break + } + + if alias, ok := scheme.nestedAliases[num]; ok { + if checkOrder, err = scanMessageOrdered(v, resolveScheme(alias), checkOrder, ScanMessageOrderedOptions{}); err != nil { + return false, newParseFieldError(f, err) + } + 
break + } + + panic(fmt.Sprintf("format of nested message field %s is not specified", f)) + } + } + + return checkOrder, nil +} diff --git a/internal/protobuf/protoscan/scheme.go b/internal/protobuf/protoscan/scheme.go new file mode 100644 index 0000000000..f6d80221fc --- /dev/null +++ b/internal/protobuf/protoscan/scheme.go @@ -0,0 +1,188 @@ +package protoscan + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "strconv" + + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/encoding/address" + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + islices "github.com/nspcc-dev/neofs-node/internal/slices" + "github.com/nspcc-dev/neofs-sdk-go/user" + "google.golang.org/protobuf/encoding/protowire" +) + +const ( + uuidLen = 16 +) + +// fieldType is an enumeration of Protocol Buffers V3 field types used in NeoFS +// API protocol. +type fieldType uint8 + +// All [fieldType] values supported by this package. +const ( + _ = iota + fieldTypeUint32 + fieldTypeUint64 + fieldTypeEnum + fieldTypeBool + fieldTypeRepeatedEnum + fieldTypeString + fieldTypeBytes + fieldTypeNestedMessage +) + +// String implements [fmt.Stringer]. +func (x fieldType) String() string { + switch x { + case fieldTypeEnum: + return "enum" + case fieldTypeUint32: + return "uint32" + case fieldTypeUint64: + return "uint64" + case fieldTypeBool: + return "bool" + case fieldTypeRepeatedEnum: + return "repeated enum" + case fieldTypeString: + return "string" + case fieldTypeBytes: + return "bytes" + case fieldTypeNestedMessage: + return "nested message" + default: + return "unknown#" + strconv.Itoa(int(x)) + } +} + +// namedField pairs field name and number. +type namedField struct { + name string + typ fieldType +} + +// String implements [fmt.Stringer]. +func (x namedField) String() string { + return x.name + " (" + x.typ.String() + ")" +} + +// newNamedField constructs new namedField instance. 
+func newNamedField(name string, typ fieldType) namedField { + return namedField{name: name, typ: typ} +} + +// newParseFieldError returns common error for failed f's parsing. +func newParseFieldError(f namedField, cause error) error { + return fmt.Errorf("parse %s field: %w", f, cause) +} + +// binaryFieldKind enumerates binary fields of specific format. +type binaryFieldKind uint8 + +const ( + _ = iota + binaryFieldSHA256 + binaryFieldN3Address + binaryFieldKindUUIDV4 +) + +// String implements [fmt.Stringer]. +func (x binaryFieldKind) String() string { + switch x { + case binaryFieldSHA256: + return "SHA256" + case binaryFieldN3Address: + return "N3Address" + case binaryFieldKindUUIDV4: + return "UUIDV4" + default: + return "unknown#" + strconv.Itoa(int(x)) + } +} + +// verifySHA256 checks whether buffers contain a valid non-zero SHA-256 hash. +func verifySHA256(buffers iprotobuf.BuffersSlice) error { + b := buffers.ReadOnlyData() + if len(b) != sha256.Size { + return fmt.Errorf("len is %d, expected %d", len(b), sha256.Size) + } + if islices.AllZeros(b) { + return errors.New("all bytes are zero") + } + return nil +} + +// checks whether buffers contain a valid Neo3 address. +func verifyN3Address(buffers iprotobuf.BuffersSlice) error { + b := buffers.ReadOnlyData() + if len(b) != user.IDSize { + return fmt.Errorf("len is %d, expected %d", len(b), user.IDSize) + } + if b[0] != address.NEO3Prefix { + return fmt.Errorf("prefix byte is 0x%X, expected 0x%X", b[0], address.NEO3Prefix) + } + if !bytes.Equal(b[21:], hash.Checksum(b[:21])) { + return errors.New("checksum mismatch") + } + return nil +} + +// verifyUUIDV4 checks whether buffers contain a valid UUID V4. 
+func verifyUUIDV4(buffers iprotobuf.BuffersSlice) error { + b := buffers.ReadOnlyData() + if len(b) != uuidLen { + return fmt.Errorf("invalid len: %d instead of %d", len(b), uuidLen) + } + if ver := b[6] >> 4; ver != 4 { + return fmt.Errorf("wrong UUID version %d, expected 4", ver) + } + return nil +} + +// verifies b according to its kind. Panics if kind is unknown. +func verifyBinaryField(kind binaryFieldKind, buffers iprotobuf.BuffersSlice) error { + // TODO: consider optimization to use a byte stream as is instead of slicing + switch kind { + case binaryFieldSHA256: + return verifySHA256(buffers) + case binaryFieldN3Address: + return verifyN3Address(buffers) + case binaryFieldKindUUIDV4: + return verifyUUIDV4(buffers) + default: + panic(fmt.Sprintf("unexpected kind %d", kind)) + } +} + +// Scheme alias allows to resolve cross-dependency of messages. +type schemeAlias = uint8 + +const ( + _ = iota + schemeAliasObjectHeader +) + +// resolves scheme by its alias. Panics if alias is unknown. +func resolveScheme(alias schemeAlias) MessageScheme { + switch alias { + default: + panic(fmt.Sprintf("unexpected alias %d", alias)) + case schemeAliasObjectHeader: + return ObjectHeaderScheme + } +} + +// MessageScheme describes scheme of particular NeoFS API protocol message for +// proper scanning. 
+type MessageScheme struct { + fields map[protowire.Number]namedField + binaryKindFields map[protowire.Number]binaryFieldKind + nestedFields map[protowire.Number]MessageScheme + nestedAliases map[protowire.Number]schemeAlias + recursionFields []protowire.Number +} diff --git a/pkg/services/object/acl/eacl/v2/object.go b/pkg/services/object/acl/eacl/v2/object.go index 167b3f944e..4ffc02b40e 100644 --- a/pkg/services/object/acl/eacl/v2/object.go +++ b/pkg/services/object/acl/eacl/v2/object.go @@ -12,7 +12,6 @@ import ( oid "github.com/nspcc-dev/neofs-sdk-go/object/id" protoobject "github.com/nspcc-dev/neofs-sdk-go/proto/object" "github.com/nspcc-dev/neofs-sdk-go/version" - "google.golang.org/protobuf/encoding/protowire" ) type sysObjHdr struct { @@ -108,21 +107,12 @@ func headersFromBinaryObjectHeader(buf []byte, cnr cid.ID, id *oid.ID) ([]eaclSD res := make([]eaclSDK.Header, 0, 10) var off int - var prevNum protowire.Number for { num, typ, n, err := iprotobuf.ParseTag(buf[off:]) if err != nil { return nil, err } - if num < prevNum { - return nil, iprotobuf.NewUnorderedFieldsError(prevNum, num) - } - if num == prevNum && num != protoobject.FieldHeaderAttributes { - return nil, iprotobuf.NewRepeatedFieldError(num) - } - prevNum = num - off += n switch num { diff --git a/pkg/services/object/get.go b/pkg/services/object/get.go new file mode 100644 index 0000000000..693a2c28ae --- /dev/null +++ b/pkg/services/object/get.go @@ -0,0 +1,372 @@ +package object + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "io" + + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/nspcc-dev/neofs-node/internal/protobuf/protoscan" + aclsvc "github.com/nspcc-dev/neofs-node/pkg/services/object/acl/v2" + apistatus "github.com/nspcc-dev/neofs-sdk-go/client/status" + "github.com/nspcc-dev/neofs-sdk-go/object" + oid "github.com/nspcc-dev/neofs-sdk-go/object/id" + protoobject "github.com/nspcc-dev/neofs-sdk-go/proto/object" + 
protorefs "github.com/nspcc-dev/neofs-sdk-go/proto/refs" + protostatus "github.com/nspcc-dev/neofs-sdk-go/proto/status" + "google.golang.org/grpc" + "google.golang.org/grpc/mem" + "google.golang.org/protobuf/encoding/protowire" +) + +type getStreamProgress struct { + headWas bool + readPayload int +} + +// returns: +// - nil on completed object transmission +// - [object.SplitInfoError]/nil on split info response and unset/set raw flag in request +// - [apistatus.ErrObjectNotFound] on 404 status +// - nil on other API statuses +// - any other transport/protocol error otherwise +func (x *getProxyContext) continueWithConn(ctx context.Context, conn *grpc.ClientConn) error { + stream, err := conn.NewStream(ctx, &protoobject.ObjectService_ServiceDesc.Streams[0], protoobject.ObjectService_Get_FullMethodName, + grpc.StaticMethod(), + grpc.ForceCodecV2(iprotobuf.BufferedCodec{}), + ) + if err != nil { + return fmt.Errorf("stream opening failed: %w", err) + } + if err = stream.SendMsg(x.req); err != nil { + return fmt.Errorf("send request: %w", err) + } + if err = stream.CloseSend(); err != nil { + return fmt.Errorf("close send: %w", err) + } + + // TODO: dont forget to free buffers when responses ignored + + var prog getStreamProgress + for { + var respBuf mem.BufferSlice + if err = stream.RecvMsg(&respBuf); err != nil { + if errors.Is(err, io.EOF) { + if !prog.headWas { + return io.ErrUnexpectedEOF + } + return nil + } + return fmt.Errorf("reading the response failed: %w", err) + } + + fin, sent, err := x.handleGetResponse(&prog, respBuf) + if !sent { + respBuf.Free() + } + if err != nil { + return fmt.Errorf("handle next stream message: %w", err) + } + if fin { + return nil + } + } +} + +func (x *getProxyContext) handleGetResponse(streamProg *getStreamProgress, respBuf mem.BufferSlice) (bool, bool, error) { + var code uint32 + var body iprotobuf.BuffersSlice + var opts protoscan.ScanMessageOptions + + opts.InterceptNested = func(num protowire.Number, buffers 
iprotobuf.BuffersSlice) error { + switch num { + case iprotobuf.FieldResponseBody: + body = buffers + return nil + case iprotobuf.FieldResponseMetaHeader: + var err error + code, err = getStatusCodeFromResponseMetaHeader(buffers) + if err != nil { + return fmt.Errorf("handle meta header: %w", err) + } + return nil + } + return protoscan.ErrContinue + } + + err := protoscan.ScanMessage(iprotobuf.NewBuffersSlice(respBuf), protoscan.ResponseScheme, opts) + if err != nil { + return false, false, err + } + + if code == protostatus.ObjectNotFound { + return false, false, apistatus.ErrObjectNotFound + } + + if code != protostatus.OK { + return true, true, x.respStream.base.SendMsg(respBuf) + } + + // TODO: forbid body if code != OK? + + sent, err := x.handleResponseBody(streamProg, respBuf, body) + if err != nil { + return false, sent, fmt.Errorf("handle body: %w", err) + } + + return false, sent, nil +} + +func (x *getProxyContext) handleResponseBody(streamProg *getStreamProgress, respBuf mem.BufferSlice, buffers iprotobuf.BuffersSlice) (bool, error) { + var oneofNum protowire.Number + var oneofFld iprotobuf.BuffersSlice + var opts protoscan.ScanMessageOptions + + opts.InterceptBytes = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + if num == protoobject.FieldGetResponseBodyChunk { + if !streamProg.headWas { + return errors.New("incorrect message sequence") + } + oneofNum, oneofFld = num, buffers + } + return nil + } + opts.InterceptNested = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + switch num { + case protoobject.FieldGetResponseBodyInit: + if streamProg.headWas { + return errors.New("incorrect message sequence") + } + streamProg.headWas = true + oneofNum, oneofFld = num, buffers + return nil + case protoobject.FieldGetResponseBodySplitInfo: + oneofNum, oneofFld = num, buffers + return nil + } + return protoscan.ErrContinue + } + + err := protoscan.ScanMessage(buffers, protoscan.ObjectGetResponseBodyScheme, opts) + if err 
!= nil { + return false, err + } + + switch oneofNum { + default: + return false, errors.New("none of the supported oneof fields are specified") + case protoobject.FieldGetResponseBodyInit: + return x.handleInitResponse(respBuf, oneofFld) + case protoobject.FieldGetResponseBodyChunk: + return x.handleChunkResponse(streamProg, respBuf, oneofFld) + case protoobject.FieldGetResponseBodySplitInfo: + return x.handleSplitInfo(respBuf, oneofFld) + } +} + +func (x *getProxyContext) handleInitResponse(respBuf mem.BufferSlice, buffers iprotobuf.BuffersSlice) (bool, error) { + var hdr iprotobuf.BuffersSlice + var opts protoscan.ScanMessageOptions + + opts.InterceptNested = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + if num != protoobject.FieldGetResponseBodyInitHeader { + return protoscan.ErrContinue + } + + var opts protoscan.ScanMessageOrderedOptions + opts.InterceptUint64 = func(num protowire.Number, u uint64) error { + if num == protoobject.FieldHeaderPayloadLength { + x.payloadLenCheck = u + } + return nil + } + opts.InterceptNested = func(num protowire.Number, buffers iprotobuf.BuffersSlice, checkOrder bool) (bool, error) { + if num != protoobject.FieldHeaderPayloadHash { + return checkOrder, protoscan.ErrContinue + } + var opts protoscan.ScanMessageOrderedOptions + opts.InterceptBytes = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + if num == protorefs.FieldChecksumValue { + x.payloadHashCheck = buffers.ReadOnlyData() + } + return nil + } + return protoscan.ScanMessageOrdered(buffers, protoscan.ChecksumScheme, opts) + } + + hdrOrdered, err := protoscan.ScanMessageOrdered(buffers, protoscan.ObjectHeaderScheme, protoscan.ScanMessageOrderedOptions{}) + if err != nil { + return fmt.Errorf("handle header with signature field: %w", err) + } + + if err = checkHeaderProtobufAgainstID(buffers, x.reqOID, hdrOrdered); err != nil { + return err + } + + hdr = buffers + return nil + } + + err := protoscan.ScanMessage(buffers, 
protoscan.ObjectGetResponseInitScheme, opts) + if err != nil { + return false, err + } + + var sent bool + x.onceHdr.Do(func() { + if x.respStream.recheckEACL { + err = x.respStream.srv.aclChecker.CheckEACL(hdr.ReadOnlyData(), x.respStream.reqInfo) + if err != nil && !errors.Is(err, aclsvc.ErrNotMatched) { // Not matched -> follow basic ACL. + err = eACLErr(x.respStream.reqInfo, err) + return + } + } + err = x.respStream.base.SendMsg(respBuf) + sent = true + }) + + return sent, err +} + +func (x *getProxyContext) handleChunkResponse(streamProg *getStreamProgress, respBuf mem.BufferSlice, chunkBuffers iprotobuf.BuffersSlice) (bool, error) { + chunkLen := chunkBuffers.Len() + + from, to := chunkBoundsToSend(x.respondedPayload, streamProg.readPayload, chunkLen) + if from == to { + streamProg.readPayload += chunkLen + return false, nil + } + + _, err := chunkBuffers.MoveNext(from) + if err != nil { + return false, fmt.Errorf("seek chunk left bound in response buffers: %w", err) + } + + chunkBuffers, err = chunkBuffers.MoveNext(to - from) + if err != nil { + return false, fmt.Errorf("seek chunk right bound in response buffers: %w", err) + } + + if x.payloadHashGot == nil { + x.payloadHashGot = sha256.New() + } + chunkBuffers.HashTo(x.payloadHashGot) + + respChunkLen := to - from + + if uint64(x.respondedPayload+respChunkLen) == x.payloadLenCheck { + if !bytes.Equal(x.payloadHashGot.Sum(nil), x.payloadHashCheck) { // not merged via && for readability + return false, errors.New("received payload mismatches checksum from header") + } + } + + remoteSent := respChunkLen == chunkLen + if !remoteSent { + if respChunkLen <= maxGetResponseChunkLen { + localRespBuf, _ := getBufferForChunkGetResponse() + + chunkBuffers.CopyTo(localRespBuf.SliceBuffer[maxChunkOffsetInGetResponse:]) + + bodyf := shiftPayloadChunkInGetResponseBuffer(localRespBuf.SliceBuffer, maxChunkOffsetInGetResponse, respChunkLen) + + if x.respStream.signResponse { + n, err := 
x.respStream.srv.signResponse(localRespBuf.SliceBuffer[bodyf.To:], localRespBuf.SliceBuffer[bodyf.ValueFrom:bodyf.To], nil) + if err != nil { + return false, fmt.Errorf("sign chunk response: %w", err) + } + bodyf.To += n + } + + localRespBuf.SetBounds(bodyf.From, bodyf.To) + respBuf = mem.BufferSlice{localRespBuf} + } else { + // TODO: in this case we could make respBuf = mem.BufferSlice{prefix, chunkBuffers}, + // but then we'd have to provide mem.Buffer from iprotobuf.BuffersSlice + bodyFldLen := 1 + protowire.SizeBytes(respChunkLen) + fullLen := 1 + protowire.SizeBytes(bodyFldLen) + if x.respStream.signResponse { + fullLen += maxResponseVerificationHeaderLen + } + + b := make(mem.SliceBuffer, fullLen) + b[0] = iprotobuf.TagBytes1 // body field + off := 1 + binary.PutUvarint(b[1:], uint64(bodyFldLen)) + b[off] = iprotobuf.TagBytes2 // chunk field + off += 1 + binary.PutUvarint(b[off+1:], uint64(respChunkLen)) + off += chunkBuffers.CopyTo(b[off:]) + if x.respStream.signResponse { + n, err := x.respStream.srv.signResponse(b[off:], b[:off], nil) + if err != nil { + return false, fmt.Errorf("sign chunk response: %w", err) + } + b = b[:off+n] + } + + respBuf = mem.BufferSlice{b} + } + } + + if err := x.respStream.base.SendMsg(respBuf); err != nil { + return remoteSent, err + } + + streamProg.readPayload += chunkLen + x.respondedPayload += to - from + + return remoteSent, nil +} + +func (x *getProxyContext) handleSplitInfo(respBuf mem.BufferSlice, buffers iprotobuf.BuffersSlice) (bool, error) { + var si object.SplitInfo + var opts protoscan.ScanMessageOptions + + compose := !x.req.GetBody().GetRaw() + if compose { + opts.InterceptBytes = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + if num == protoobject.FieldSplitInfoSplitID { + id := object.NewSplitIDFromV2(buffers.ReadOnlyData()) + if id == nil { + return errors.New("invalid split ID") + } + si.SetSplitID(id) + } + return nil + } + opts.InterceptNested = func(num protowire.Number, buffers 
iprotobuf.BuffersSlice) error { + if num != protoobject.FieldSplitInfoLastPart && num != protoobject.FieldSplitInfoLink && num != protoobject.FieldSplitInfoFirstPart { + return protoscan.ErrContinue + } + + var opts protoscan.ScanMessageOptions + opts.InterceptBytes = func(num2 protowire.Number, buffers iprotobuf.BuffersSlice) error { + if num2 == protorefs.FieldObjectIDValue { + switch num { + case protoobject.FieldSplitInfoLastPart: + si.SetLastPart(oid.ID(buffers.ReadOnlyData())) + case protoobject.FieldSplitInfoLink: + si.SetLink(oid.ID(buffers.ReadOnlyData())) + case protoobject.FieldSplitInfoFirstPart: + si.SetFirstPart(oid.ID(buffers.ReadOnlyData())) + } + } + return nil + } + return protoscan.ScanMessage(buffers, protoscan.ObjectIDScheme, opts) + } + } + + err := protoscan.ScanMessage(buffers, protoscan.ObjectSplitInfoScheme, opts) + if err != nil { + return false, fmt.Errorf("handle split info field: %w", err) + } + + if compose { + return false, object.NewSplitInfoError(&si) + } + + return true, x.respStream.base.SendMsg(respBuf) +} diff --git a/pkg/services/object/get/exec.go b/pkg/services/object/get/exec.go index 36debaf679..29b29ebbc4 100644 --- a/pkg/services/object/get/exec.go +++ b/pkg/services/object/get/exec.go @@ -64,6 +64,12 @@ type execCtx struct { localGetBuffer []byte submitLocalGetStreamFn SubmitStreamFunc + + forwardHeadRequestFn ForwardHeadRequestFunc + + submitHeadResponseFn SubmitHeadResponseFunc + + forwardGetRequestFn ForwardGetRequestFunc } type execOption func(*execCtx) @@ -76,9 +82,11 @@ const ( statusNotFound ) -func headOnly() execOption { +func headOnly(forwardRequestFn ForwardHeadRequestFunc, submitResponseFn SubmitHeadResponseFunc) execOption { return func(c *execCtx) { c.head = true + c.forwardHeadRequestFn = forwardRequestFn + c.submitHeadResponseFn = submitResponseFn } } @@ -114,6 +122,12 @@ func withLocalGetBuffer(buf []byte, submitStreamFn SubmitStreamFunc) execOption } } +func withForwardGetRequestFunc(f 
ForwardGetRequestFunc) execOption { + return func(ctx *execCtx) { + ctx.forwardGetRequestFn = f + } +} + func (exec *execCtx) setLogger(l *zap.Logger) { if l.Level() != zap.DebugLevel { exec.log = l @@ -453,4 +467,5 @@ func (exec execCtx) isRangeHashForwardingEnabled() bool { func (exec *execCtx) disableForwarding() { exec.prm.SetRequestForwarder(nil) exec.prm.SetRangeHashRequestForwarder(nil) + exec.forwardGetRequestFn = nil } diff --git a/pkg/services/object/get/get.go b/pkg/services/object/get/get.go index 4564ef6aa9..141669389e 100644 --- a/pkg/services/object/get/get.go +++ b/pkg/services/object/get/get.go @@ -47,7 +47,8 @@ func (s *Service) Get(ctx context.Context, prm Prm) error { if len(repRules) > 0 { // REP format does not require encoding bufOpt := withLocalGetBuffer(prm.localGetBuffer, prm.submitLocalGetStreamFn) - err := s.get(ctx, prm.commonPrm, withPreSortedContainerNodes(nodeLists[:len(repRules)], repRules), bufOpt).err + forwardOpt := withForwardGetRequestFunc(prm.forwardRequestFn) + err := s.get(ctx, prm.commonPrm, withPreSortedContainerNodes(nodeLists[:len(repRules)], repRules), bufOpt, forwardOpt).err if len(ecRules) == 0 || !errors.Is(err, apistatus.ErrObjectNotFound) { return err } @@ -266,6 +267,7 @@ func (s *Service) Head(ctx context.Context, prm HeadPrm) error { } if prm.common.LocalOnly() { + if prm.buffer != nil { n, err := s.localObjects.ReadHeader(prm.addr, prm.raw, prm.buffer) if err == nil { @@ -278,7 +280,8 @@ func (s *Service) Head(ctx context.Context, prm HeadPrm) error { } if len(repRules) > 0 { - err := s.get(ctx, prm.commonPrm, headOnly(), withPreSortedContainerNodes(nodeLists[:len(repRules)], repRules)).err + headOpt := headOnly(prm.forwardHeadRequestFn, prm.submitHeadResponseFn) + err := s.get(ctx, prm.commonPrm, headOpt, withPreSortedContainerNodes(nodeLists[:len(repRules)], repRules)).err if len(ecRules) == 0 || !errors.Is(err, apistatus.ErrObjectNotFound) { return err } @@ -294,7 +297,8 @@ func (s *Service) Head(ctx 
context.Context, prm HeadPrm) error { for i := range ecRules { repRules[i] = uint(ecRules[i].DataPartNum + ecRules[i].ParityPartNum) } - return s.get(ctx, prm.commonPrm, headOnly(), withPreSortedContainerNodes(ecNodeLists, repRules)).err + headOpt := headOnly(prm.forwardHeadRequestFn, prm.submitHeadResponseFn) + return s.get(ctx, prm.commonPrm, headOpt, withPreSortedContainerNodes(ecNodeLists, repRules)).err } return s.copyECObjectHeader(ctx, prm.objWriter, prm.addr.Container(), prm.addr.Object(), prm.common.SessionToken(), diff --git a/pkg/services/object/get/prm.go b/pkg/services/object/get/prm.go index a100b1a0a8..e3f11e8666 100644 --- a/pkg/services/object/get/prm.go +++ b/pkg/services/object/get/prm.go @@ -6,11 +6,13 @@ import ( "hash" "io" + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" coreclient "github.com/nspcc-dev/neofs-node/pkg/core/client" "github.com/nspcc-dev/neofs-node/pkg/services/object/internal" "github.com/nspcc-dev/neofs-node/pkg/services/object/util" "github.com/nspcc-dev/neofs-sdk-go/object" oid "github.com/nspcc-dev/neofs-sdk-go/object/id" + "google.golang.org/grpc/mem" ) // SubmitStreamFunc is a callback for partially read object stream. @@ -22,6 +24,8 @@ type Prm struct { localGetBuffer []byte submitLocalGetStreamFn SubmitStreamFunc + + forwardRequestFn ForwardGetRequestFunc } // RangePrm groups parameters of GetRange service call. @@ -47,12 +51,28 @@ type RangeHashPrm struct { type RequestForwarder func(context.Context, coreclient.MultiAddressClient) (*object.Object, error) type RangeRequestForwarder func(context.Context, coreclient.MultiAddressClient) ([][]byte, error) +// ForwardHeadRequestFunc sends currently served HEAD request to remote node +// through passed connection and returns buffered response with requested +// object's header binary in it. 
+type ForwardHeadRequestFunc = func(context.Context, coreclient.MultiAddressClient) (mem.BufferSlice, iprotobuf.BuffersSlice, error) + +// SubmitHeadResponseFunc accepts result of [ForwardHeadRequestFunc]. +type SubmitHeadResponseFunc = func(mem.BufferSlice, iprotobuf.BuffersSlice) + +// ForwardGetRequestFunc continues to serve current GET request from remote node +// through passed connection. +type ForwardGetRequestFunc = func(context.Context, coreclient.MultiAddressClient) error + // HeadPrm groups parameters of Head service call. type HeadPrm struct { commonPrm buffer []byte submitLenFn func(int) + + forwardHeadRequestFn ForwardHeadRequestFunc + + submitHeadResponseFn SubmitHeadResponseFunc } type commonPrm struct { @@ -174,3 +194,38 @@ func (p *Prm) WithBuffer(buffer []byte, submitStreamFn SubmitStreamFunc) { func (p Prm) GetBuffer() ([]byte, SubmitStreamFunc) { return p.localGetBuffer, p.submitLocalGetStreamFn } + +// SetRequestForwarder specifies request transport callback to use for receiving +// response from remote node. +// +// The f should return: +// - response buffer and object header protobuf without an error on OK +// - [object.SplitInfoError] on OK with corresponding body field +// - [apistatus.ErrObjectNotFound] on 404 status +// - (respBuf, iprotobuf.BuffersSlice{}, nil) on other API statuses +// - any transport error +// +// Once results successfully received, it is forwarded untouched to handler +// which must be set via [HeadPrm.SetSubmitHeadResponseFunc]. +func (p *HeadPrm) SetRequestForwarder(f ForwardHeadRequestFunc) { + p.forwardHeadRequestFn = f +} + +// SetSubmitHeadResponseFunc specifies handler to pass results of +// [HeadPrm.SetRequestForwarder] argument into. +func (p *HeadPrm) SetSubmitHeadResponseFunc(f SubmitHeadResponseFunc) { + p.submitHeadResponseFn = f +} + +// SetRequestForwarder specifies request transport callback to use for streaming +// responses from remote node. 
+// +// The f should return: +// - nil on completed object transmission +// - [object.SplitInfoError]/nil on split info response and unset/set raw flag in request +// - [apistatus.ErrObjectNotFound] on 404 status +// - nil on other API statuses +// - any other transport/protocol error otherwise +func (p *Prm) SetRequestForwarder(f ForwardGetRequestFunc) { + p.forwardRequestFn = f +} diff --git a/pkg/services/object/get/util.go b/pkg/services/object/get/util.go index b67b8a58c4..906699afcb 100644 --- a/pkg/services/object/get/util.go +++ b/pkg/services/object/get/util.go @@ -177,6 +177,18 @@ func (c *clientCacheWrapper) get(ctx context.Context, info coreclient.NodeInfo) } func (c *clientWrapper) getObject(exec *execCtx) (*object.Object, io.ReadCloser, error) { + if exec.forwardHeadRequestFn != nil { + respBuf, hdr, err := exec.forwardHeadRequestFn(exec.ctx, c.client) + if err == nil { + exec.submitHeadResponseFn(respBuf, hdr) + } + return nil, nil, err + } + + if exec.forwardGetRequestFn != nil { + return nil, nil, exec.forwardGetRequestFn(exec.ctx, c.client) + } + if exec.isForwardingEnabled() { obj, err := exec.prm.forwarder(exec.ctx, c.client) return obj, nil, err diff --git a/pkg/services/object/head.go b/pkg/services/object/head.go new file mode 100644 index 0000000000..21a99e8c98 --- /dev/null +++ b/pkg/services/object/head.go @@ -0,0 +1,170 @@ +package object + +import ( + "context" + "errors" + "fmt" + + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/nspcc-dev/neofs-node/internal/protobuf/protoscan" + apistatus "github.com/nspcc-dev/neofs-sdk-go/client/status" + oid "github.com/nspcc-dev/neofs-sdk-go/object/id" + protoobject "github.com/nspcc-dev/neofs-sdk-go/proto/object" + protostatus "github.com/nspcc-dev/neofs-sdk-go/proto/status" + "google.golang.org/grpc" + "google.golang.org/grpc/mem" + "google.golang.org/protobuf/encoding/protowire" +) + +// returns: +// - response buffer and object header protobuf without an error on OK 
+// - (nil, nil, [object.SplitInfoError]) on OK with corresponding body field +// - (nil, nil, [apistatus.ErrObjectNotFound]) on 404 status +// - (respBuf, nil, nil) on other API statuses +// - (nil, nil, err) on any transport err +func getHeaderFromRemoteNode(ctx context.Context, conn *grpc.ClientConn, req *protoobject.HeadRequest, reqOID oid.ID) (mem.BufferSlice, iprotobuf.BuffersSlice, error) { + var respBuf mem.BufferSlice + + err := conn.Invoke(ctx, protoobject.ObjectService_Head_FullMethodName, req, &respBuf, + grpc.StaticMethod(), + grpc.ForceCodecV2(iprotobuf.BufferedCodec{}), + ) + if err != nil { + return nil, iprotobuf.BuffersSlice{}, fmt.Errorf("sending the request failed: %w", err) + } + + hdrBuffers, err := handleHeadResponse(respBuf, reqOID) + if err != nil { + respBuf.Free() + return nil, iprotobuf.BuffersSlice{}, fmt.Errorf("handle response: %w", err) + } + + return respBuf, hdrBuffers, nil +} + +func handleHeadResponse(respBuf mem.BufferSlice, reqOID oid.ID) (iprotobuf.BuffersSlice, error) { + var code uint32 + var body iprotobuf.BuffersSlice + var opts protoscan.ScanMessageOptions + + opts.InterceptNested = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + switch num { + case iprotobuf.FieldResponseBody: + body = buffers + return nil + case iprotobuf.FieldResponseMetaHeader: + var err error + code, err = getStatusCodeFromResponseMetaHeader(buffers) + if err != nil { + return fmt.Errorf("handle meta header: %w", err) + } + return nil + } + return protoscan.ErrContinue + } + + err := protoscan.ScanMessage(iprotobuf.NewBuffersSlice(respBuf), protoscan.ResponseScheme, opts) + if err != nil { + return iprotobuf.BuffersSlice{}, err + } + + if code == protostatus.ObjectNotFound { + return iprotobuf.BuffersSlice{}, apistatus.ErrObjectNotFound + } + + if code != protostatus.OK { + return iprotobuf.BuffersSlice{}, nil + } + + // TODO: forbid body if code != OK? 
+ + hdrBuffers, err := handleHeadResponseBody(body, reqOID) + if err != nil { + return iprotobuf.BuffersSlice{}, fmt.Errorf("handle body: %w", err) + } + + return hdrBuffers, nil +} + +func handleHeadResponseBody(buffers iprotobuf.BuffersSlice, reqOID oid.ID) (iprotobuf.BuffersSlice, error) { + var oneofNum protowire.Number + var oneofFld iprotobuf.BuffersSlice + var opts protoscan.ScanMessageOptions + + opts.InterceptNested = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + switch num { + case protoobject.FieldHeadResponseBodyHeader, + protoobject.FieldHeadResponseBodyShortHeader, + protoobject.FieldHeadResponseBodySplitInfo: + oneofNum, oneofFld = num, buffers + return nil + } + return protoscan.ErrContinue + } + + err := protoscan.ScanMessage(buffers, protoscan.ObjectHeadResponseBodyScheme, opts) + if err != nil { + return iprotobuf.BuffersSlice{}, err + } + + switch oneofNum { + default: + return iprotobuf.BuffersSlice{}, errors.New("none of the supported oneof fields are specified") + case protoobject.FieldHeadResponseBodyHeader: + hdrBuffers, err := handleHeaderWithSignature(oneofFld, reqOID) + if err != nil { + return iprotobuf.BuffersSlice{}, fmt.Errorf("handle header with signature field: %w", err) + } + return hdrBuffers, nil + case protoobject.FieldHeadResponseBodyShortHeader: + return iprotobuf.BuffersSlice{}, errors.New("unsupported short header") + case protoobject.FieldHeadResponseBodySplitInfo: + err := protoscan.ScanMessage(oneofFld, protoscan.ObjectSplitInfoScheme, protoscan.ScanMessageOptions{}) + if err != nil { + return iprotobuf.BuffersSlice{}, fmt.Errorf("handle split info field: %w", err) + } + return iprotobuf.BuffersSlice{}, nil + } +} + +func handleHeaderWithSignature(buffers iprotobuf.BuffersSlice, reqOID oid.ID) (iprotobuf.BuffersSlice, error) { + var withHdr bool + var hdrBuffers iprotobuf.BuffersSlice + var hdrOrdered bool + var withSig bool + var opts protoscan.ScanMessageOptions + + opts.InterceptNested = 
func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + if num != protoobject.FieldHeaderWithSignatureHeader { + if num == protoobject.FieldHeaderWithSignatureSignature { + withSig = true + } + return protoscan.ErrContinue + } + + var err error + hdrOrdered, err = protoscan.ScanMessageOrdered(buffers, protoscan.ObjectHeaderScheme, protoscan.ScanMessageOrderedOptions{}) + if err != nil { + return fmt.Errorf("handle header with signature field: %w", err) + } + + hdrBuffers = buffers + withHdr = true + return nil + } + + err := protoscan.ScanMessage(buffers, protoscan.ObjectHeaderWithSignatureScheme, opts) + if err != nil { + return hdrBuffers, err + } + + if !withHdr { + return hdrBuffers, errors.New("missing header") + } + if !withSig { + // TODO(@cthulhu-rider): #1387 use "const" error + return hdrBuffers, errors.New("missing signature") + } + + return hdrBuffers, checkHeaderProtobufAgainstID(hdrBuffers, reqOID, hdrOrdered) +} diff --git a/pkg/services/object/head_internal_test.go b/pkg/services/object/head_internal_test.go new file mode 100644 index 0000000000..da7e82d7f1 --- /dev/null +++ b/pkg/services/object/head_internal_test.go @@ -0,0 +1,133 @@ +package object + +import ( + "bytes" + "testing" + + apistatus "github.com/nspcc-dev/neofs-sdk-go/client/status" + oid "github.com/nspcc-dev/neofs-sdk-go/object/id" + protoobject "github.com/nspcc-dev/neofs-sdk-go/proto/object" + protosession "github.com/nspcc-dev/neofs-sdk-go/proto/session" + protostatus "github.com/nspcc-dev/neofs-sdk-go/proto/status" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func BenchmarkHandleHeadResponse(b *testing.B) { + b.Run("404", func(b *testing.B) { + metaHdr := newBlankMetaHeader() + metaHdr.Status = &protostatus.Status{ + Code: 2049, + Message: "object not found", + Details: []*protostatus.Status_Detail{ + {Id: 1869765515, Value: []byte("foo")}, + {Id: 2095463591, Value: []byte("bar")}, + }, + } + + bench := func(b *testing.B, 
metaHdr *protosession.ResponseMetaHeader) { + respBuf := messageToSingleMemBuffer(b, &protoobject.HeadResponse{ + MetaHeader: metaHdr, + VerifyHeader: newAnyVerificationHeader(), + }) + b.ReportAllocs() // FIXME: drop + for b.Loop() { + _, err := handleHeadResponse(respBuf, oid.ID{}) + require.ErrorIs(b, err, apistatus.ErrObjectNotFound) + // Originally it shows 2 allocs. Inspection showed that one is return + // apistatus.ErrObjectNotFound, the other require.ErrorIs. + } + } + + b.Run("root", func(b *testing.B) { + bench(b, metaHdr) + }) + + b.Run("nestedX5", func(b *testing.B) { + bench(b, nestMetaHeader(metaHdr, 5)) + }) + }) + + b.Run("non-404 failure", func(b *testing.B) { + metaHdr := newBlankMetaHeader() + metaHdr.Status = &protostatus.Status{ + Code: 2052, // already removed + Message: "object already removed", + Details: []*protostatus.Status_Detail{ + {Id: 1869765515, Value: []byte("foo")}, + {Id: 2095463591, Value: []byte("bar")}, + }, + } + + bench := func(b *testing.B, metaHdr *protosession.ResponseMetaHeader) { + respBuf := messageToSingleMemBuffer(b, &protoobject.HeadResponse{ + MetaHeader: metaHdr, + VerifyHeader: newAnyVerificationHeader(), + }) + b.ReportAllocs() // FIXME: drop + for b.Loop() { + hdr, err := handleHeadResponse(respBuf, oid.ID{}) + require.NoError(b, err) + require.True(b, hdr.IsEmpty()) + // NOTE(review): comment copied from the 404 case above; no error is + // returned here, so re-inspect what actually allocates. 
+ } + } + + b.Run("root", func(b *testing.B) { + bench(b, metaHdr) + }) + + b.Run("nestedX5", func(b *testing.B) { + bench(b, nestMetaHeader(metaHdr, 5)) + }) + }) + + b.Run("body", func(b *testing.B) { + b.Run("header", func(b *testing.B) { + obj := newTestObject(b) + + objMsg := obj.ProtoMessage() + + hdr, err := proto.Marshal(objMsg.Header) + require.NoError(b, err) + + respBuf := messageToSingleMemBuffer(b, &protoobject.HeadResponse{ + Body: &protoobject.HeadResponse_Body{Head: &protoobject.HeadResponse_Body_Header{ + Header: &protoobject.HeaderWithSignature{ + Header: objMsg.Header, + Signature: objMsg.Signature, + }, + }}, + MetaHeader: newBlankMetaHeader(), + VerifyHeader: newAnyVerificationHeader(), + }) + + id := obj.GetID() + + b.ReportAllocs() // FIXME: drop + + for b.Loop() { + gotHdr, err := handleHeadResponse(respBuf, id) + require.NoError(b, err) + require.True(b, bytes.Equal(hdr, gotHdr.ReadOnlyData())) + } + }) + + b.Run("split info", func(b *testing.B) { + respBuf := messageToSingleMemBuffer(b, &protoobject.HeadResponse{ + Body: &protoobject.HeadResponse_Body{Head: &protoobject.HeadResponse_Body_SplitInfo{ + SplitInfo: newTestSplitInfo(), + }}, + MetaHeader: newBlankMetaHeader(), + VerifyHeader: newAnyVerificationHeader(), + }) + b.ReportAllocs() // FIXME: drop + for b.Loop() { + hdr, err := handleHeadResponse(respBuf, oid.ID{}) + require.NoError(b, err) + require.True(b, hdr.IsEmpty()) + } + }) + }) +} diff --git a/pkg/services/object/head_test.go b/pkg/services/object/head_test.go index df110b58c7..930431ecfd 100644 --- a/pkg/services/object/head_test.go +++ b/pkg/services/object/head_test.go @@ -10,6 +10,7 @@ import ( iec "github.com/nspcc-dev/neofs-node/internal/ec" iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" "github.com/nspcc-dev/neofs-node/internal/testutil" + clientcore "github.com/nspcc-dev/neofs-node/pkg/core/client" corenetmap "github.com/nspcc-dev/neofs-node/pkg/core/netmap" . 
"github.com/nspcc-dev/neofs-node/pkg/services/object" getsvc "github.com/nspcc-dev/neofs-node/pkg/services/object/get" @@ -27,6 +28,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/proto" ) func TestServer_Head_Local(t *testing.T) { @@ -163,6 +165,97 @@ func TestServer_Head_Remote(t *testing.T) { assertHeadRequestOK(t, srv, fsChain, req, *obj.CutPayload()) }) + + t.Run("REP forwarded", func(t *testing.T) { + var handlerFSChain mockHandlerFSChain + mockConns := newMockConnections() + + nodes := make([]netmap.NodeInfo, 2) + for i := range nodes { + nodes[i].SetPublicKey([]byte("pub_" + strconv.Itoa(i))) + nodes[i].SetNetworkEndpoints("localhost:" + strconv.Itoa(9090+i)) // any + } + + mockConns.setConn(nodes[0], emptyRemoteNode{}) + + handlerFSChain.repRules = []uint{uint(len(nodes))} + handlerFSChain.nodeLists = [][]netmap.NodeInfo{nodes} + + handler := getsvc.New(&handlerFSChain, + getsvc.WithLocalStorageEngine(newSimpleStorage(t, fsChain)), + getsvc.WithClientConstructor(mockConns), + getsvc.WithKeyStorage(keyStorage), + ) + handlers := headOnlyHandler{svc: handler} + + srv := New(handlers, 0, nil, fsChain, nil, nil, signer.ECDSAPrivateKey, mtrc, aclChecker, reqInfoExt, nil) + + t.Run("header", func(t *testing.T) { + const payloadLen = 100 << 10 + obj := object.New(cnr, signer.ID) + obj.SetAttributes( + object.NewAttribute("k1", "v1"), + object.NewAttribute("k2", "v2"), + ) + obj.SetPayloadSize(payloadLen) + obj.SetPayload(testutil.RandByteSlice(payloadLen)) + require.NoError(t, obj.SetVerificationFields(signer)) + + req := newUnsignedLocalHeadRequest(version.Current(), obj.Address()) + req.MetaHeader.Ttl = 2 + signHeadRequest(t, req, signer) + + objMsg := obj.ProtoMessage() + + metaHdr := newBlankMetaHeader() + nestMetaHeader(metaHdr, 5) + + resp := &protoobject.HeadResponse{ + Body: &protoobject.HeadResponse_Body{ + Head: &protoobject.HeadResponse_Body_Header{ 
+ Header: &protoobject.HeaderWithSignature{ + Header: objMsg.Header, + Signature: objMsg.Signature, + }, + }, + }, + MetaHeader: metaHdr, + VerifyHeader: newAnyVerificationHeader(), + } + + mockConns.setConn(nodes[1], newFixedHeadResponseConn(t, resp)) + + gotResp, err := callHead(t, srv, req) + require.NoError(t, err) + require.True(t, proto.Equal(resp, gotResp)) + }) + + t.Run("split info", func(t *testing.T) { + req := newUnsignedLocalHeadRequest(version.Current(), obj.Address()) + req.MetaHeader.Ttl = 2 + signHeadRequest(t, req, signer) + + // TODO: share + metaHdr := newBlankMetaHeader() + nestMetaHeader(metaHdr, 5) + + resp := &protoobject.HeadResponse{ + Body: &protoobject.HeadResponse_Body{ + Head: &protoobject.HeadResponse_Body_SplitInfo{ + SplitInfo: newTestSplitInfo(), + }, + }, + MetaHeader: metaHdr, + VerifyHeader: newAnyVerificationHeader(), + } + + mockConns.setConn(nodes[1], newFixedHeadResponseConn(t, resp)) + + gotResp, err := callHead(t, srv, req) + require.NoError(t, err) + require.True(t, proto.Equal(resp, gotResp)) + }) + }) } type headOnlyHandler struct { @@ -261,3 +354,38 @@ func assertHeadRequestOK(t *testing.T, srv *Server, fsChain corenetmap.State, re return resp } + +func newFixedHeadResponseConn(t *testing.T, resp *protoobject.HeadResponse) clientcore.MultiAddressClient { + // TODO: try share code with + // simulating a full gRPC request lifecycle starting from the client + lis := bufconn.Listen(32 << 10) + + grpcSrv := grpc.NewServer( + grpc.ForceServerCodecV2(iprotobuf.BufferedCodec{}), + ) + t.Cleanup(grpcSrv.Stop) + + grpcSrv.RegisterService(&grpc.ServiceDesc{ + ServiceName: protoobject.ObjectService_ServiceDesc.ServiceName, + Methods: []grpc.MethodDesc{ + { + MethodName: "Head", + Handler: func(_ any, _ context.Context, _ func(any) error, _ grpc.UnaryServerInterceptor) (any, error) { + return resp, nil + }, + }, + }, + }, nil) + + go func() { _ = grpcSrv.Serve(lis) }() + + c, err := grpc.NewClient("localhost:8080", + 
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { return lis.DialContext(ctx) }), + grpc.WithTransportCredentials(insecure.NewCredentials()), // error otherwise + ) + require.NoError(t, err) // lib misuse, not a request error + + return &mockGRPCConn{ + conn: c, + } +} diff --git a/pkg/services/object/server.go b/pkg/services/object/server.go index 9d4c221652..aeaa50c233 100644 --- a/pkg/services/object/server.go +++ b/pkg/services/object/server.go @@ -19,6 +19,7 @@ import ( icrypto "github.com/nspcc-dev/neofs-node/internal/crypto" iobject "github.com/nspcc-dev/neofs-node/internal/object" iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/nspcc-dev/neofs-node/internal/protobuf/protoscan" "github.com/nspcc-dev/neofs-node/pkg/core/client" "github.com/nspcc-dev/neofs-node/pkg/core/container" "github.com/nspcc-dev/neofs-node/pkg/core/netmap" @@ -51,6 +52,9 @@ import ( "github.com/nspcc-dev/tzhash/tz" "github.com/panjf2000/ants/v2" "google.golang.org/grpc" + "google.golang.org/grpc/mem" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" ) // Handlers represents storage node's internal handler Object service op @@ -610,8 +614,8 @@ func (s *Server) Head(context.Context, *protoobject.HeadRequest) (*protoobject.H } // HeadBuffered serves req and returns response as either -// [*protoobject.HeadResponse] or [mem.Buffer]. The buffer must be freed -// eventually. +// [*protoobject.HeadResponse], [mem.BufferSlice] or [mem.Buffer]. All buffers +// must be freed eventually. 
func (s *Server) HeadBuffered(ctx context.Context, req *protoobject.HeadRequest) any { var ( err error @@ -670,15 +674,28 @@ func (s *Server) HeadBuffered(ctx context.Context, req *protoobject.HeadRequest) p.WithBuffer(hdrBuf, func(ln int) { hdrLen = ln }) + var proxyRespBuf mem.BufferSlice + var proxyHdrBuf iprotobuf.BuffersSlice + p.SetSubmitHeadResponseFunc(func(respBuf mem.BufferSlice, hdrBuf iprotobuf.BuffersSlice) { + proxyRespBuf, proxyHdrBuf = respBuf, hdrBuf + }) + err = s.handlers.Head(ctx, p) if err != nil { return s.makeStatusHeadResponse(err, needSignResp) } - buffered := hdrLen >= 0 - + var buffered bool var sigf, hdrf iprotobuf.FieldBounds - if buffered { + if proxyRespBuf != nil { + if !recheckEACL || proxyHdrBuf.IsEmpty() { + return proxyRespBuf + } + // TODO: this can be optimized by passing iprotobuf.BuffersSlice into eACL checker + hdrBuf = proxyHdrBuf.ReadOnlyData() + hdrf.To = len(hdrBuf) + buffered = true + } else if buffered = hdrLen >= 0; buffered { _, sigf, hdrf, err = iobject.GetNonPayloadFieldBounds(hdrBuf[:hdrLen]) if err != nil { return s.makeStatusHeadResponse(err, needSignResp) @@ -697,6 +714,10 @@ func (s *Server) HeadBuffered(ctx context.Context, req *protoobject.HeadRequest) err = eACLErr(reqInfo, err) // defer return s.makeStatusHeadResponse(err, needSignResp) } + + if proxyRespBuf != nil { + return proxyRespBuf + } } if !buffered { @@ -774,7 +795,7 @@ func convertHeadPrm(signer ecdsa.PrivateKey, req *protoobject.HeadRequest, resp if meta == nil { return getsvc.HeadPrm{}, errors.New("missing meta header") } - p.SetRequestForwarder(func(ctx context.Context, c client.MultiAddressClient) (*object.Object, error) { + p.SetRequestForwarder(func(ctx context.Context, c client.MultiAddressClient) (mem.BufferSlice, iprotobuf.BuffersSlice, error) { var err error onceResign.Do(func() { req.MetaHeader = &protosession.RequestMetaHeader{ @@ -785,77 +806,20 @@ func convertHeadPrm(signer ecdsa.PrivateKey, req *protoobject.HeadRequest, resp 
req.VerifyHeader, err = neofscrypto.SignRequestWithBuffer(neofsecdsa.Signer(signer), req, nil) }) if err != nil { - return nil, err + return nil, iprotobuf.BuffersSlice{}, err } - var hdr *object.Object - return hdr, c.ForEachGRPCConn(ctx, func(ctx context.Context, conn *grpc.ClientConn) error { + var respBuf mem.BufferSlice + var hdr iprotobuf.BuffersSlice + return respBuf, hdr, c.ForEachGRPCConn(ctx, func(ctx context.Context, conn *grpc.ClientConn) error { var err error - hdr, err = getHeaderFromRemoteNode(ctx, conn, req, addr.Object()) + respBuf, hdr, err = getHeaderFromRemoteNode(ctx, conn, req, addr.Object()) return err // TODO: log error }) }) return p, nil } -func getHeaderFromRemoteNode(ctx context.Context, conn *grpc.ClientConn, req *protoobject.HeadRequest, reqOID oid.ID) (*object.Object, error) { - resp, err := protoobject.NewObjectServiceClient(conn).Head(ctx, req) - if err != nil { - return nil, fmt.Errorf("sending the request failed: %w", err) - } - - if err := checkStatus(resp.GetMetaHeader().GetStatus()); err != nil { - return nil, err - } - - var hdr *protoobject.Header - var idSig *refs.Signature - switch v := resp.GetBody().GetHead().(type) { - case nil: - return nil, fmt.Errorf("unexpected header type %T", v) - case *protoobject.HeadResponse_Body_ShortHeader: - return nil, fmt.Errorf("unsupported short header") - case *protoobject.HeadResponse_Body_Header: - if v == nil || v.Header == nil { - return nil, errors.New("nil header oneof field") - } - if v.Header.Header == nil { - return nil, errors.New("missing header") - } - if v.Header.Signature == nil { - // TODO(@cthulhu-rider): #1387 use "const" error - return nil, errors.New("missing signature") - } - - if err := checkHeaderAgainstID(v.Header.Header, reqOID); err != nil { - return nil, err - } - - hdr = v.Header.Header - idSig = v.Header.Signature - case *protoobject.HeadResponse_Body_SplitInfo: - if v == nil || v.SplitInfo == nil { - return nil, errors.New("nil split info oneof field") - } - 
si := object.NewSplitInfo() - err := si.FromProtoMessage(v.SplitInfo) - if err != nil { - return nil, err - } - return nil, object.NewSplitInfoError(si) - } - - mObj := &protoobject.Object{ - Signature: idSig, - Header: hdr, - } - var obj = new(object.Object) - if err := obj.FromProtoMessage(mObj); err != nil { - return nil, err - } - return obj, nil -} - func (s *Server) signHashResponse(resp *protoobject.GetRangeHashResponse, req *protoobject.GetRangeHashRequest) *protoobject.GetRangeHashResponse { resp.VerifyHeader = util.SignResponseIfNeeded(&s.signer, resp, req) return resp @@ -1352,7 +1316,7 @@ func convertGetPrm(signer ecdsa.PrivateKey, req *protoobject.GetRequest, stream respStream: stream, } - p.SetRequestForwarder(func(ctx context.Context, c client.MultiAddressClient) (*object.Object, error) { + p.SetRequestForwarder(func(ctx context.Context, c client.MultiAddressClient) error { var err error onceResign.Do(func() { req.MetaHeader = &protosession.RequestMetaHeader{ @@ -1363,15 +1327,11 @@ func convertGetPrm(signer ecdsa.PrivateKey, req *protoobject.GetRequest, stream req.VerifyHeader, err = neofscrypto.SignRequestWithBuffer(neofsecdsa.Signer(signer), req, nil) }) if err != nil { - return nil, err + return err } - return nil, c.ForEachGRPCConn(ctx, func(ctx context.Context, conn *grpc.ClientConn) error { - err := proxyCtx.continueWithConn(ctx, conn) - if errors.Is(err, io.EOF) { - return nil - } - return err // TODO: log error + return c.ForEachGRPCConn(ctx, func(ctx context.Context, conn *grpc.ClientConn) error { + return proxyCtx.continueWithConn(ctx, conn) // TODO: log error }) }) return p, nil @@ -1391,110 +1351,6 @@ type getProxyContext struct { payloadHashGot hash.Hash } -func (x *getProxyContext) continueWithConn(ctx context.Context, conn *grpc.ClientConn) error { - getStream, err := protoobject.NewObjectServiceClient(conn).Get(ctx, x.req) - if err != nil { - return fmt.Errorf("stream opening failed: %w", err) - } - - var headWas bool - var 
readPayload int - for { - resp, err := getStream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - if !headWas { - return io.ErrUnexpectedEOF - } - return io.EOF - } - return fmt.Errorf("reading the response failed: %w", err) - } - - if err := checkStatus(resp.GetMetaHeader().GetStatus()); err != nil { - return err - } - - switch v := resp.GetBody().GetObjectPart().(type) { - default: - return fmt.Errorf("unexpected object part %T", v) - case *protoobject.GetResponse_Body_Init_: - if headWas { - return errors.New("incorrect message sequence") - } - headWas = true - if v == nil || v.Init == nil { - return errors.New("nil header oneof field") - } - - if v.Init.Header == nil { - return errors.New("invalid response: missing header") - } - if v.Init.Header.PayloadHash == nil { - return errors.New("invalid response: invalid header: missing payload hash") - } - if err := checkHeaderAgainstID(v.Init.Header, x.reqOID); err != nil { - return err - } - - mo := &protoobject.Object{ - ObjectId: v.Init.ObjectId, - Signature: v.Init.Signature, - Header: v.Init.Header, - } - obj := new(object.Object) - err := obj.FromProtoMessage(mo) - if err != nil { - return err - } - x.onceHdr.Do(func() { - err = x.respStream.WriteHeader(obj) - }) - if err != nil { - return fmt.Errorf("could not write object header in Get forwarder: %w", err) - } - - x.payloadLenCheck = v.Init.Header.PayloadLength - x.payloadHashCheck = v.Init.Header.PayloadHash.Sum - x.payloadHashGot = sha256.New() - case *protoobject.GetResponse_Body_Chunk: - if !headWas { - return errors.New("incorrect message sequence") - } - fullChunk := v.Chunk - respChunk := chunkToSend(x.respondedPayload, readPayload, fullChunk) - if len(respChunk) == 0 { - readPayload += len(fullChunk) - continue - } - - x.payloadHashGot.Write(respChunk) // never returns an error according to docs - - if uint64(x.respondedPayload+len(respChunk)) == x.payloadLenCheck { - if !bytes.Equal(x.payloadHashGot.Sum(nil), x.payloadHashCheck) { // not 
merged via && for readability - return errors.New("received payload mismatches checksum from header") - } - } - - if err := x.respStream.WriteChunk(respChunk); err != nil { - return fmt.Errorf("could not write object chunk in Get forwarder: %w", err) - } - readPayload += len(fullChunk) - x.respondedPayload += len(respChunk) - case *protoobject.GetResponse_Body_SplitInfo: - if v == nil || v.SplitInfo == nil { - return errors.New("nil split info oneof field") - } - si := object.NewSplitInfo() - err := si.FromProtoMessage(v.SplitInfo) - if err != nil { - return err - } - return object.NewSplitInfoError(si) - } - } -} - func (s *Server) sendRangeResponse(stream protoobject.ObjectService_GetRangeServer, resp *protoobject.GetRangeResponse, req *protoobject.GetRangeRequest) error { resp.VerifyHeader = util.SignResponseIfNeeded(&s.signer, resp, req) return stream.Send(resp) @@ -2523,14 +2379,19 @@ func checkStatus(st *protostatus.Status) error { } func chunkToSend(global, local int, chunk []byte) []byte { + from, to := chunkBoundsToSend(global, local, len(chunk)) + return chunk[from:to] +} + +func chunkBoundsToSend(global, local, chunkLen int) (int, int) { if global == local { - return chunk + return 0, chunkLen } - if local+len(chunk) <= global { + if local+chunkLen <= global { // chunk has already been sent - return nil + return 0, 0 } - return chunk[global-local:] + return global - local, chunkLen } func needSignGetResponse(req util.Request) bool { @@ -2540,10 +2401,74 @@ func needSignGetResponse(req util.Request) bool { func checkHeaderAgainstID(hdr *protoobject.Header, id oid.ID) error { b := make([]byte, hdr.MarshaledSize()) hdr.MarshalStable(b) + return checkOrderedHeaderProtobufAgainstID(b, id) +} + +func checkHeaderProtobufAgainstID(buffers iprotobuf.BuffersSlice, id oid.ID, ordered bool) error { + b := buffers.ReadOnlyData() + if !ordered { + // TODO: consider optimization + // Either require direct order in protocol (for example, current node does this) or use 
buffer from pool. + var hdr protoobject.Header + if err := proto.Unmarshal(buffers.ReadOnlyData(), &hdr); err != nil { + return fmt.Errorf("unmarshal: %w", err) + } + b = make([]byte, hdr.MarshaledSize()) + hdr.MarshalStable(b) + } + + return checkOrderedHeaderProtobufAgainstID(b, id) +} + +func checkOrderedHeaderProtobufAgainstID(b []byte, id oid.ID) error { if oid.NewFromObjectHeaderBinary(b) != id { return errors.New("received header mismatches ID") } return nil } + +// getStatusCodeFromResponseMetaHeader checks whether buf is a valid response +// meta header. If so, status code field is returned. In case of nesting +// headers, code from the root is returned. +// +// Absence of any fields is ignored. Unknown fields are allowed and checked. +// Repeated fields are allowed: if status field is repeated (including nested), +// code from the last one is returned. +func getStatusCodeFromResponseMetaHeader(buffers iprotobuf.BuffersSlice) (uint32, error) { + var code uint32 + var gotOrigin bool + var opts protoscan.ScanMessageOptions + + opts.InterceptNested = func(num protowire.Number, buffers iprotobuf.BuffersSlice) error { + switch num { + case protosession.FieldResponseMetaHeaderOrigin: + var err error + code, err = getStatusCodeFromResponseMetaHeader(buffers) + if err != nil { + return fmt.Errorf("handle origin field: %w", err) + } + gotOrigin = true + return nil + case protosession.FieldResponseMetaHeaderStatus: + if gotOrigin { + break + } + var opts protoscan.ScanMessageOptions + opts.InterceptUint32 = func(num protowire.Number, u uint32) error { + if num == protostatus.FieldStatusCode { + code = u + } + return nil + } + err := protoscan.ScanMessage(buffers, protoscan.ResponseStatusScheme, opts) + if err != nil { + return fmt.Errorf("handle status field: %w", err) + } + return nil + } + return protoscan.ErrContinue + } + + return code, protoscan.ScanMessage(buffers, protoscan.ResponseMetaHeaderScheme, opts) +} diff --git a/pkg/services/object/server_test.go 
b/pkg/services/object/server_test.go index 13c1b5b8d0..99bd54e7e2 100644 --- a/pkg/services/object/server_test.go +++ b/pkg/services/object/server_test.go @@ -834,3 +834,16 @@ type emptyRemoteNode struct { func (emptyRemoteNode) ObjectHead(context.Context, cid.ID, oid.ID, user.Signer, client.PrmObjectHead) (*object.Object, error) { return nil, apistatus.ErrObjectNotFound } + +func (emptyRemoteNode) ForEachGRPCConn(context.Context, func(context.Context, *grpc.ClientConn) error) error { + return errors.New("any transport error") +} + +type mockGRPCConn struct { + unimplementedConn + conn *grpc.ClientConn +} + +func (x *mockGRPCConn) ForEachGRPCConn(ctx context.Context, f func(context.Context, *grpc.ClientConn) error) error { + return f(ctx, x.conn) +} diff --git a/pkg/services/object/util_internal_test.go b/pkg/services/object/util_internal_test.go new file mode 100644 index 0000000000..ed6fcfdf47 --- /dev/null +++ b/pkg/services/object/util_internal_test.go @@ -0,0 +1,185 @@ +package object + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/nspcc-dev/neofs-node/internal/testutil" + "github.com/nspcc-dev/neofs-sdk-go/checksum" + cidtest "github.com/nspcc-dev/neofs-sdk-go/container/id/test" + neofscryptotest "github.com/nspcc-dev/neofs-sdk-go/crypto/test" + "github.com/nspcc-dev/neofs-sdk-go/object" + oidtest "github.com/nspcc-dev/neofs-sdk-go/object/id/test" + protoobject "github.com/nspcc-dev/neofs-sdk-go/proto/object" + protorefs "github.com/nspcc-dev/neofs-sdk-go/proto/refs" + protosession "github.com/nspcc-dev/neofs-sdk-go/proto/session" + protostatus "github.com/nspcc-dev/neofs-sdk-go/proto/status" + "github.com/nspcc-dev/neofs-sdk-go/session" + sessionv2 "github.com/nspcc-dev/neofs-sdk-go/session/v2" + usertest "github.com/nspcc-dev/neofs-sdk-go/user/test" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/mem" + "google.golang.org/protobuf/proto" +) + +func messageToSingleMemBuffer(t testing.TB, m proto.Message) 
mem.BufferSlice { + buf, err := proto.Marshal(m) + require.NoError(t, err) + return mem.BufferSlice{mem.SliceBuffer(buf)} +} + +func nestMetaHeader(m *protosession.ResponseMetaHeader, n int) *protosession.ResponseMetaHeader { + for range n { + st := m.GetStatus() + m = &protosession.ResponseMetaHeader{ + Version: m.Version, + Epoch: m.Epoch, + Ttl: m.Ttl, + XHeaders: m.XHeaders, + Origin: m, + Status: &protostatus.Status{ + Code: st.GetCode() + 1, + Message: st.GetMessage(), + Details: st.GetDetails(), + }, + } + } + return m +} + +func newBlankMetaHeader() *protosession.ResponseMetaHeader { + return &protosession.ResponseMetaHeader{ + Version: &protorefs.Version{ + Major: 3651676384, + Minor: 2829345803, + }, + Epoch: 10699904184716558895, + Ttl: 1657590594, + XHeaders: []*protosession.XHeader{ + {Key: "key1", Value: ""}, + {Key: "", Value: "value2"}, + {Key: "key3", Value: "value3"}, + {Key: "", Value: ""}, + }, + } +} + +func newTestObject(t testing.TB) object.Object { + var tokV1 session.Object + tokV1.SetID(uuid.New()) + tokV1.SetIssuer(usertest.ID()) + tokV1.SetAuthKey(neofscryptotest.Signer().Public()) + tokV1.SetExp(17170611258521075862) + tokV1.SetNbf(738661496128559041) + tokV1.SetIat(4380256499670963239) + tokV1.ForVerb(session.VerbObjectHead) + tokV1.BindContainer(cidtest.ID()) + tokV1.LimitByObjects(oidtest.IDs(10)...) 
+ require.NoError(t, tokV1.Sign(usertest.User())) + + originSessionToken := newTestUnsignedSessionToken(t) + require.NoError(t, originSessionToken.Sign(usertest.User())) + + sessionToken := newTestUnsignedSessionToken(t) + sessionToken.SetOrigin(&originSessionToken) + require.NoError(t, sessionToken.Sign(usertest.User())) + + par := *object.New(cidtest.ID(), usertest.ID()) + par.SetCreationEpoch(10738628592919807909) + par.SetPayloadSize(12653732379852397698) + par.SetPayloadHomomorphicHash(checksum.NewTillichZemor([64]byte(testutil.RandByteSlice(64)))) + par.SetType(object.TypeLink) + par.SetPayload(testutil.RandByteSlice(1 << 10)) + par.SetAttributes( + object.NewAttribute("pk1", "pv1"), + object.NewAttribute("pk2", "pv2"), + ) + par.SetSessionToken(&tokV1) + par.SetSessionTokenV2(&sessionToken) + require.NoError(t, par.SetVerificationFields(usertest.User())) + + obj := *object.New(cidtest.ID(), usertest.ID()) + obj.SetCreationEpoch(18093843418564698707) + obj.SetPayloadSize(895946572232418931) + obj.SetPayloadHomomorphicHash(checksum.NewTillichZemor([64]byte(testutil.RandByteSlice(64)))) + obj.SetType(object.TypeLink) + obj.SetPayload(testutil.RandByteSlice(1 << 10)) + obj.SetAttributes( + object.NewAttribute("k1", "v1"), + object.NewAttribute("k2", "v2"), + ) + obj.SetSessionToken(&tokV1) + obj.SetSessionTokenV2(&sessionToken) + obj.SetParent(&par) + require.NoError(t, obj.SetVerificationFields(usertest.User())) + + return obj +} + +func newTestUnsignedSessionToken(t testing.TB) sessionv2.Token { + var tok sessionv2.Token + tok.SetVersion(2357623054) + tok.SetAppData(testutil.RandByteSlice(1024)) + tok.SetIssuer(usertest.ID()) + tm := time.Unix(1774601168, 0) + tok.Lifetime = sessionv2.NewLifetime(tm, tm.Add(time.Minute), tm.Add(time.Hour)) + + ctx1, err := sessionv2.NewContext(cidtest.ID(), []sessionv2.Verb{sessionv2.VerbObjectHead, sessionv2.VerbObjectGet}) + require.NoError(t, err) + require.NoError(t, tok.AddContext(ctx1)) + ctx2, err := 
sessionv2.NewContext(cidtest.ID(), []sessionv2.Verb{sessionv2.VerbObjectPut, sessionv2.VerbObjectDelete}) + require.NoError(t, err) + require.NoError(t, tok.AddContext(ctx2)) + + require.NoError(t, tok.AddSubject(sessionv2.NewTargetUser(usertest.ID()))) + require.NoError(t, tok.AddSubject(sessionv2.NewTargetNamed("Bob"))) + + return tok +} + +func newTestSplitInfo() *protoobject.SplitInfo { + var res object.SplitInfo + res.SetSplitID(object.NewSplitID()) + res.SetLastPart(oidtest.ID()) + res.SetLink(oidtest.ID()) + res.SetFirstPart(oidtest.ID()) + return res.ProtoMessage() +} + +func newAnyVerificationHeader() *protosession.ResponseVerificationHeader { + return &protosession.ResponseVerificationHeader{ + BodySignature: &protorefs.Signature{ + Key: []byte("any_body_key"), + Sign: []byte("any_body_signature"), + Scheme: 123, + }, + MetaSignature: &protorefs.Signature{ + Key: []byte("any_meta_key"), + Sign: []byte("any_meta_signature"), + Scheme: 456, + }, + OriginSignature: &protorefs.Signature{ + Key: []byte("any_origin_key"), + Sign: []byte("any_origin_signature"), + Scheme: 789, + }, + Origin: &protosession.ResponseVerificationHeader{ + BodySignature: &protorefs.Signature{ + Key: []byte("any_origin_body_key"), + Sign: []byte("any_origin_body_signature"), + Scheme: 321, + }, + MetaSignature: &protorefs.Signature{ + Key: []byte("any_origin_meta_key"), + Sign: []byte("any_origin_meta_signature"), + Scheme: 654, + }, + OriginSignature: &protorefs.Signature{ + Key: []byte("any_origin_origin_key"), + Sign: []byte("any_origin_origin_signature"), + Scheme: 987, + }, + }, + } +} diff --git a/pkg/services/object/util_test.go b/pkg/services/object/util_test.go new file mode 100644 index 0000000000..19395aace2 --- /dev/null +++ b/pkg/services/object/util_test.go @@ -0,0 +1,92 @@ +package object_test + +import ( + "github.com/nspcc-dev/neofs-sdk-go/object" + oidtest "github.com/nspcc-dev/neofs-sdk-go/object/id/test" + protoobject 
"github.com/nspcc-dev/neofs-sdk-go/proto/object" + protorefs "github.com/nspcc-dev/neofs-sdk-go/proto/refs" + protosession "github.com/nspcc-dev/neofs-sdk-go/proto/session" + protostatus "github.com/nspcc-dev/neofs-sdk-go/proto/status" +) + +func newBlankMetaHeader() *protosession.ResponseMetaHeader { + return &protosession.ResponseMetaHeader{ + Version: &protorefs.Version{ + Major: 3651676384, + Minor: 2829345803, + }, + Epoch: 10699904184716558895, + Ttl: 1657590594, + XHeaders: []*protosession.XHeader{ + {Key: "key1", Value: ""}, + {Key: "", Value: "value2"}, + {Key: "key3", Value: "value3"}, + {Key: "", Value: ""}, + }, + } +} + +func nestMetaHeader(m *protosession.ResponseMetaHeader, n int) *protosession.ResponseMetaHeader { + for range n { + st := m.GetStatus() + m = &protosession.ResponseMetaHeader{ + Version: m.Version, + Epoch: m.Epoch, + Ttl: m.Ttl, + XHeaders: m.XHeaders, + Origin: m, + Status: &protostatus.Status{ + Code: st.GetCode() + 1, + Message: st.GetMessage(), + Details: st.GetDetails(), + }, + } + } + return m +} + +func newTestSplitInfo() *protoobject.SplitInfo { + var res object.SplitInfo + res.SetSplitID(object.NewSplitID()) + res.SetLastPart(oidtest.ID()) + res.SetLink(oidtest.ID()) + res.SetFirstPart(oidtest.ID()) + return res.ProtoMessage() +} + +func newAnyVerificationHeader() *protosession.ResponseVerificationHeader { + return &protosession.ResponseVerificationHeader{ + BodySignature: &protorefs.Signature{ + Key: []byte("any_body_key"), + Sign: []byte("any_body_signature"), + Scheme: 123, + }, + MetaSignature: &protorefs.Signature{ + Key: []byte("any_meta_key"), + Sign: []byte("any_meta_signature"), + Scheme: 456, + }, + OriginSignature: &protorefs.Signature{ + Key: []byte("any_origin_key"), + Sign: []byte("any_origin_signature"), + Scheme: 789, + }, + Origin: &protosession.ResponseVerificationHeader{ + BodySignature: &protorefs.Signature{ + Key: []byte("any_origin_body_key"), + Sign: []byte("any_origin_body_signature"), + Scheme: 321, + 
}, + MetaSignature: &protorefs.Signature{ + Key: []byte("any_origin_meta_key"), + Sign: []byte("any_origin_meta_signature"), + Scheme: 654, + }, + OriginSignature: &protorefs.Signature{ + Key: []byte("any_origin_origin_key"), + Sign: []byte("any_origin_origin_signature"), + Scheme: 987, + }, + }, + } +}