summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/stream.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/stream.go')
-rw-r--r--vendor/google.golang.org/grpc/stream.go101
1 files changed, 73 insertions, 28 deletions
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go
index 0f8e6c0..d1226a4 100644
--- a/vendor/google.golang.org/grpc/stream.go
+++ b/vendor/google.golang.org/grpc/stream.go
@@ -168,10 +168,19 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
- if md, _, ok := metadata.FromOutgoingContextRaw(ctx); ok {
+ if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
+ // validate md
if err := imetadata.Validate(md); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
+ // validate added
+ for _, kvs := range added {
+ for i := 0; i < len(kvs); i += 2 {
+ if err := imetadata.ValidatePair(kvs[i], kvs[i+1]); err != nil {
+ return nil, status.Error(codes.Internal, err.Error())
+ }
+ }
+ }
}
if channelz.IsOn() {
cc.incrCallsStarted()
@@ -352,7 +361,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
}
}
for _, binlog := range cs.binlogs {
- binlog.Log(logEntry)
+ binlog.Log(cs.ctx, logEntry)
}
}
@@ -438,7 +447,7 @@ func (a *csAttempt) getTransport() error {
cs := a.cs
var err error
- a.t, a.done, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
+ a.t, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
if err != nil {
if de, ok := err.(dropError); ok {
err = de.error
@@ -455,6 +464,25 @@ func (a *csAttempt) getTransport() error {
func (a *csAttempt) newStream() error {
cs := a.cs
cs.callHdr.PreviousAttempts = cs.numRetries
+
+ // Merge metadata stored in PickResult, if any, with existing call metadata.
+ // It is safe to overwrite the csAttempt's context here, since all state
+ // maintained in it are local to the attempt. When the attempt has to be
+ // retried, a new instance of csAttempt will be created.
+ if a.pickResult.Metatada != nil {
+ // We currently do not have a function it the metadata package which
+ // merges given metadata with existing metadata in a context. Existing
+ // function `AppendToOutgoingContext()` takes a variadic argument of key
+ // value pairs.
+ //
+ // TODO: Make it possible to retrieve key value pairs from metadata.MD
+ // in a form passable to AppendToOutgoingContext(), or create a version
+ // of AppendToOutgoingContext() that accepts a metadata.MD.
+ md, _ := metadata.FromOutgoingContext(a.ctx)
+ md = metadata.Join(md, a.pickResult.Metatada)
+ a.ctx = metadata.NewOutgoingContext(a.ctx, md)
+ }
+
s, err := a.t.NewStream(a.ctx, cs.callHdr)
if err != nil {
nse, ok := err.(*transport.NewStreamError)
@@ -529,12 +557,12 @@ type clientStream struct {
// csAttempt implements a single transport stream attempt within a
// clientStream.
type csAttempt struct {
- ctx context.Context
- cs *clientStream
- t transport.ClientTransport
- s *transport.Stream
- p *parser
- done func(balancer.DoneInfo)
+ ctx context.Context
+ cs *clientStream
+ t transport.ClientTransport
+ s *transport.Stream
+ p *parser
+ pickResult balancer.PickResult
finished bool
dc Decompressor
@@ -781,7 +809,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {
}
cs.serverHeaderBinlogged = true
for _, binlog := range cs.binlogs {
- binlog.Log(logEntry)
+ binlog.Log(cs.ctx, logEntry)
}
}
return m, nil
@@ -862,7 +890,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
Message: data,
}
for _, binlog := range cs.binlogs {
- binlog.Log(cm)
+ binlog.Log(cs.ctx, cm)
}
}
return err
@@ -886,7 +914,7 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
Message: recvInfo.uncompressedBytes,
}
for _, binlog := range cs.binlogs {
- binlog.Log(sm)
+ binlog.Log(cs.ctx, sm)
}
}
if err != nil || !cs.desc.ServerStreams {
@@ -907,7 +935,7 @@ func (cs *clientStream) RecvMsg(m interface{}) error {
logEntry.PeerAddr = peer.Addr
}
for _, binlog := range cs.binlogs {
- binlog.Log(logEntry)
+ binlog.Log(cs.ctx, logEntry)
}
}
}
@@ -934,7 +962,7 @@ func (cs *clientStream) CloseSend() error {
OnClientSide: true,
}
for _, binlog := range cs.binlogs {
- binlog.Log(chc)
+ binlog.Log(cs.ctx, chc)
}
}
// We never returned an error here for reasons.
@@ -952,6 +980,9 @@ func (cs *clientStream) finish(err error) {
return
}
cs.finished = true
+ for _, onFinish := range cs.callInfo.onFinish {
+ onFinish(err)
+ }
cs.commitAttemptLocked()
if cs.attempt != nil {
cs.attempt.finish(err)
@@ -973,7 +1004,7 @@ func (cs *clientStream) finish(err error) {
OnClientSide: true,
}
for _, binlog := range cs.binlogs {
- binlog.Log(c)
+ binlog.Log(cs.ctx, c)
}
}
if err == nil {
@@ -1062,9 +1093,10 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) {
RecvTime: time.Now(),
Payload: m,
// TODO truncate large payload.
- Data: payInfo.uncompressedBytes,
- WireLength: payInfo.wireLength + headerLen,
- Length: len(payInfo.uncompressedBytes),
+ Data: payInfo.uncompressedBytes,
+ WireLength: payInfo.compressedLength + headerLen,
+ CompressedLength: payInfo.compressedLength,
+ Length: len(payInfo.uncompressedBytes),
})
}
if channelz.IsOn() {
@@ -1103,12 +1135,12 @@ func (a *csAttempt) finish(err error) {
tr = a.s.Trailer()
}
- if a.done != nil {
+ if a.pickResult.Done != nil {
br := false
if a.s != nil {
br = a.s.BytesReceived()
}
- a.done(balancer.DoneInfo{
+ a.pickResult.Done(balancer.DoneInfo{
Err: err,
Trailer: tr,
BytesSent: a.s != nil,
@@ -1464,6 +1496,9 @@ type ServerStream interface {
// It is safe to have a goroutine calling SendMsg and another goroutine
// calling RecvMsg on the same stream at the same time, but it is not safe
// to call SendMsg on the same stream in different goroutines.
+ //
+ // It is not safe to modify the message after calling SendMsg. Tracing
+ // libraries and stats handlers may use the message lazily.
SendMsg(m interface{}) error
// RecvMsg blocks until it receives a message into m or the stream is
// done. It returns io.EOF when the client has performed a CloseSend. On
@@ -1489,6 +1524,8 @@ type serverStream struct {
comp encoding.Compressor
decomp encoding.Compressor
+ sendCompressorName string
+
maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
@@ -1536,7 +1573,7 @@ func (ss *serverStream) SendHeader(md metadata.MD) error {
}
ss.serverHeaderBinlogged = true
for _, binlog := range ss.binlogs {
- binlog.Log(sh)
+ binlog.Log(ss.ctx, sh)
}
}
return err
@@ -1581,6 +1618,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()
+ // Server handler could have set new compressor by calling SetSendCompressor.
+ // In case it is set, we need to use it for compressing outbound message.
+ if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
+ ss.comp = encoding.GetCompressor(sendCompressorsName)
+ ss.sendCompressorName = sendCompressorsName
+ }
+
// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil {
@@ -1602,14 +1646,14 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
ss.serverHeaderBinlogged = true
for _, binlog := range ss.binlogs {
- binlog.Log(sh)
+ binlog.Log(ss.ctx, sh)
}
}
sm := &binarylog.ServerMessage{
Message: data,
}
for _, binlog := range ss.binlogs {
- binlog.Log(sm)
+ binlog.Log(ss.ctx, sm)
}
}
if len(ss.statsHandler) != 0 {
@@ -1657,7 +1701,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{}
for _, binlog := range ss.binlogs {
- binlog.Log(chc)
+ binlog.Log(ss.ctx, chc)
}
}
return err
@@ -1673,9 +1717,10 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
RecvTime: time.Now(),
Payload: m,
// TODO truncate large payload.
- Data: payInfo.uncompressedBytes,
- WireLength: payInfo.wireLength + headerLen,
- Length: len(payInfo.uncompressedBytes),
+ Data: payInfo.uncompressedBytes,
+ Length: len(payInfo.uncompressedBytes),
+ WireLength: payInfo.compressedLength + headerLen,
+ CompressedLength: payInfo.compressedLength,
})
}
}
@@ -1684,7 +1729,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
Message: payInfo.uncompressedBytes,
}
for _, binlog := range ss.binlogs {
- binlog.Log(cm)
+ binlog.Log(ss.ctx, cm)
}
}
return nil