diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/stream.go')
| -rw-r--r-- | vendor/google.golang.org/grpc/stream.go | 101 | 
1 files changed, 73 insertions, 28 deletions
diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index 0f8e6c0..d1226a4 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -168,10 +168,19 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth  }  func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { -	if md, _, ok := metadata.FromOutgoingContextRaw(ctx); ok { +	if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok { +		// validate md  		if err := imetadata.Validate(md); err != nil {  			return nil, status.Error(codes.Internal, err.Error())  		} +		// validate added +		for _, kvs := range added { +			for i := 0; i < len(kvs); i += 2 { +				if err := imetadata.ValidatePair(kvs[i], kvs[i+1]); err != nil { +					return nil, status.Error(codes.Internal, err.Error()) +				} +			} +		}  	}  	if channelz.IsOn() {  		cc.incrCallsStarted() @@ -352,7 +361,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client  			}  		}  		for _, binlog := range cs.binlogs { -			binlog.Log(logEntry) +			binlog.Log(cs.ctx, logEntry)  		}  	} @@ -438,7 +447,7 @@ func (a *csAttempt) getTransport() error {  	cs := a.cs  	var err error -	a.t, a.done, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method) +	a.t, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)  	if err != nil {  		if de, ok := err.(dropError); ok {  			err = de.error @@ -455,6 +464,25 @@ func (a *csAttempt) getTransport() error {  func (a *csAttempt) newStream() error {  	cs := a.cs  	cs.callHdr.PreviousAttempts = cs.numRetries + +	// Merge metadata stored in PickResult, if any, with existing call metadata. +	// It is safe to overwrite the csAttempt's context here, since all state +	// maintained in it are local to the attempt. When the attempt has to be +	// retried, a new instance of csAttempt will be created. +	if a.pickResult.Metatada != nil { +		// We currently do not have a function it the metadata package which +		// merges given metadata with existing metadata in a context. Existing +		// function `AppendToOutgoingContext()` takes a variadic argument of key +		// value pairs. +		// +		// TODO: Make it possible to retrieve key value pairs from metadata.MD +		// in a form passable to AppendToOutgoingContext(), or create a version +		// of AppendToOutgoingContext() that accepts a metadata.MD. +		md, _ := metadata.FromOutgoingContext(a.ctx) +		md = metadata.Join(md, a.pickResult.Metatada) +		a.ctx = metadata.NewOutgoingContext(a.ctx, md) +	} +  	s, err := a.t.NewStream(a.ctx, cs.callHdr)  	if err != nil {  		nse, ok := err.(*transport.NewStreamError) @@ -529,12 +557,12 @@ type clientStream struct {  // csAttempt implements a single transport stream attempt within a  // clientStream.  type csAttempt struct { -	ctx  context.Context -	cs   *clientStream -	t    transport.ClientTransport -	s    *transport.Stream -	p    *parser -	done func(balancer.DoneInfo) +	ctx        context.Context +	cs         *clientStream +	t          transport.ClientTransport +	s          *transport.Stream +	p          *parser +	pickResult balancer.PickResult  	finished  bool  	dc        Decompressor @@ -781,7 +809,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {  		}  		cs.serverHeaderBinlogged = true  		for _, binlog := range cs.binlogs { -			binlog.Log(logEntry) +			binlog.Log(cs.ctx, logEntry)  		}  	}  	return m, nil @@ -862,7 +890,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {  			Message:      data,  		}  		for _, binlog := range cs.binlogs { -			binlog.Log(cm) +			binlog.Log(cs.ctx, cm)  		}  	}  	return err @@ -886,7 +914,7 @@ func (cs *clientStream) RecvMsg(m interface{}) error {  			Message:      recvInfo.uncompressedBytes,  		}  		for _, binlog := range cs.binlogs { -			binlog.Log(sm) +			binlog.Log(cs.ctx, sm)  		}  	}  	if err != nil || !cs.desc.ServerStreams { @@ -907,7 +935,7 @@ func (cs *clientStream) RecvMsg(m interface{}) error {  				logEntry.PeerAddr = peer.Addr  			}  			for _, binlog := range cs.binlogs { -				binlog.Log(logEntry) +				binlog.Log(cs.ctx, logEntry)  			}  		}  	} @@ -934,7 +962,7 @@ func (cs *clientStream) CloseSend() error {  			OnClientSide: true,  		}  		for _, binlog := range cs.binlogs { -			binlog.Log(chc) +			binlog.Log(cs.ctx, chc)  		}  	}  	// We never returned an error here for reasons. @@ -952,6 +980,9 @@ func (cs *clientStream) finish(err error) {  		return  	}  	cs.finished = true +	for _, onFinish := range cs.callInfo.onFinish { +		onFinish(err) +	}  	cs.commitAttemptLocked()  	if cs.attempt != nil {  		cs.attempt.finish(err) @@ -973,7 +1004,7 @@ func (cs *clientStream) finish(err error) {  			OnClientSide: true,  		}  		for _, binlog := range cs.binlogs { -			binlog.Log(c) +			binlog.Log(cs.ctx, c)  		}  	}  	if err == nil { @@ -1062,9 +1093,10 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) {  			RecvTime: time.Now(),  			Payload:  m,  			// TODO truncate large payload. -			Data:       payInfo.uncompressedBytes, -			WireLength: payInfo.wireLength + headerLen, -			Length:     len(payInfo.uncompressedBytes), +			Data:             payInfo.uncompressedBytes, +			WireLength:       payInfo.compressedLength + headerLen, +			CompressedLength: payInfo.compressedLength, +			Length:           len(payInfo.uncompressedBytes),  		})  	}  	if channelz.IsOn() { @@ -1103,12 +1135,12 @@ func (a *csAttempt) finish(err error) {  		tr = a.s.Trailer()  	} -	if a.done != nil { +	if a.pickResult.Done != nil {  		br := false  		if a.s != nil {  			br = a.s.BytesReceived()  		} -		a.done(balancer.DoneInfo{ +		a.pickResult.Done(balancer.DoneInfo{  			Err:           err,  			Trailer:       tr,  			BytesSent:     a.s != nil, @@ -1464,6 +1496,9 @@ type ServerStream interface {  	// It is safe to have a goroutine calling SendMsg and another goroutine  	// calling RecvMsg on the same stream at the same time, but it is not safe  	// to call SendMsg on the same stream in different goroutines. +	// +	// It is not safe to modify the message after calling SendMsg. Tracing +	// libraries and stats handlers may use the message lazily.  	SendMsg(m interface{}) error  	// RecvMsg blocks until it receives a message into m or the stream is  	// done. It returns io.EOF when the client has performed a CloseSend. On @@ -1489,6 +1524,8 @@ type serverStream struct {  	comp   encoding.Compressor  	decomp encoding.Compressor +	sendCompressorName string +  	maxReceiveMessageSize int  	maxSendMessageSize    int  	trInfo                *traceInfo @@ -1536,7 +1573,7 @@ func (ss *serverStream) SendHeader(md metadata.MD) error {  		}  		ss.serverHeaderBinlogged = true  		for _, binlog := range ss.binlogs { -			binlog.Log(sh) +			binlog.Log(ss.ctx, sh)  		}  	}  	return err @@ -1581,6 +1618,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {  		}  	}() +	// Server handler could have set new compressor by calling SetSendCompressor. +	// In case it is set, we need to use it for compressing outbound message. +	if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName { +		ss.comp = encoding.GetCompressor(sendCompressorsName) +		ss.sendCompressorName = sendCompressorsName +	} +  	// load hdr, payload, data  	hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)  	if err != nil { @@ -1602,14 +1646,14 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {  			}  			ss.serverHeaderBinlogged = true  			for _, binlog := range ss.binlogs { -				binlog.Log(sh) +				binlog.Log(ss.ctx, sh)  			}  		}  		sm := &binarylog.ServerMessage{  			Message: data,  		}  		for _, binlog := range ss.binlogs { -			binlog.Log(sm) +			binlog.Log(ss.ctx, sm)  		}  	}  	if len(ss.statsHandler) != 0 { @@ -1657,7 +1701,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {  			if len(ss.binlogs) != 0 {  				chc := &binarylog.ClientHalfClose{}  				for _, binlog := range ss.binlogs { -					binlog.Log(chc) +					binlog.Log(ss.ctx, chc)  				}  			}  			return err @@ -1673,9 +1717,10 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {  				RecvTime: time.Now(),  				Payload:  m,  				// TODO truncate large payload. -				Data:       payInfo.uncompressedBytes, -				WireLength: payInfo.wireLength + headerLen, -				Length:     len(payInfo.uncompressedBytes), +				Data:             payInfo.uncompressedBytes, +				Length:           len(payInfo.uncompressedBytes), +				WireLength:       payInfo.compressedLength + headerLen, +				CompressedLength: payInfo.compressedLength,  			})  		}  	} @@ -1684,7 +1729,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {  			Message: payInfo.uncompressedBytes,  		}  		for _, binlog := range ss.binlogs { -			binlog.Log(cm) +			binlog.Log(ss.ctx, cm)  		}  	}  	return nil  | 
