LogCabin
RPC/MessageSocket.cc
Go to the documentation of this file.
00001 /* Copyright (c) 2010-2014 Stanford University
00002  * Copyright (c) 2015 Diego Ongaro
00003  *
00004  * Permission to use, copy, modify, and distribute this software for any
00005  * purpose with or without fee is hereby granted, provided that the above
00006  * copyright notice and this permission notice appear in all copies.
00007  *
00008  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES
00009  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
00010  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR
00011  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
00012  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
00013  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
00014  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
00015  */
00016 
00017 #include <cassert>
00018 #include <errno.h>
00019 #include <netinet/in.h>
00020 #include <netinet/tcp.h>
00021 #include <string.h>
00022 #include <sys/epoll.h>
00023 #include <sys/socket.h>
00024 #include <sys/types.h>
00025 #include <unistd.h>
00026 
00027 #include "Core/Debug.h"
00028 #include "Core/Endian.h"
00029 #include "Event/Loop.h"
00030 #include "RPC/MessageSocket.h"
00031 
00032 namespace LogCabin {
00033 namespace RPC {
00034 
00035 namespace {
00036 
00037 /// Wrapper for dup().
00038 int
00039 dupOrPanic(int oldfd)
00040 {
00041     int newfd = dup(oldfd);
00042     if (newfd < 0)
00043         PANIC("Failed to dup(%d): %s", oldfd, strerror(errno));
00044     return newfd;
00045 }
00046 
00047 } // anonymous namespace
00048 
00049 ////////// MessageSocket::SendSocket //////////
00050 
00051 MessageSocket::SendSocket::SendSocket(int fd,
00052                                       MessageSocket& messageSocket)
00053     : Event::File(fd)
00054     , messageSocket(messageSocket)
00055 {
00056     int flag = 1;
00057     int r = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
00058     if (r < 0) {
00059         // This should be a warning, but some unit tests pass weird types of
00060         // file descriptors in here. It's not very important, anyhow.
00061         NOTICE("Could not set TCP_NODELAY flag on sending socket %d: %s",
00062                fd, strerror(errno));
00063     }
00064 }
00065 
00066 MessageSocket::SendSocket::~SendSocket()
00067 {
00068 }
00069 
00070 void
00071 MessageSocket::SendSocket::handleFileEvent(uint32_t events)
00072 {
00073     messageSocket.writable();
00074 }
00075 
00076 ////////// MessageSocket::ReceiveSocket //////////
00077 
00078 MessageSocket::ReceiveSocket::ReceiveSocket(int fd,
00079                                             MessageSocket& messageSocket)
00080     : Event::File(fd)
00081     , messageSocket(messageSocket)
00082 {
00083     // I don't know that TCP_NODELAY has any effect if we're only reading from
00084     // this file descriptor, but I guess it can't hurt.
00085     int flag = 1;
00086     int r = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
00087     if (r < 0) {
00088         // This should be a warning, but some unit tests pass weird types of
00089         // file descriptors in here. It's not very important, anyhow.
00090         NOTICE("Could not set TCP_NODELAY flag on receiving socket %d: %s",
00091                 fd, strerror(errno));
00092     }
00093 }
00094 
00095 MessageSocket::ReceiveSocket::~ReceiveSocket()
00096 {
00097 }
00098 
00099 void
00100 MessageSocket::ReceiveSocket::handleFileEvent(uint32_t events)
00101 {
00102     messageSocket.readable();
00103 }
00104 
00105 ////////// MessageSocket::Header //////////
00106 
00107 void
00108 MessageSocket::Header::fromBigEndian()
00109 {
00110     fixed = be16toh(fixed);
00111     version = be16toh(version);
00112     payloadLength = be32toh(payloadLength);
00113     messageId = be64toh(messageId);
00114 }
00115 
00116 void
00117 MessageSocket::Header::toBigEndian()
00118 {
00119     fixed = htobe16(fixed);
00120     version = htobe16(version);
00121     payloadLength = htobe32(payloadLength);
00122     messageId = htobe64(messageId);
00123 }
00124 
00125 ////////// MessageSocket::Inbound //////////
00126 
00127 MessageSocket::Inbound::Inbound()
00128     : bytesRead(0)
00129     , header()
00130     , message()
00131 {
00132 }
00133 
00134 ////////// MessageSocket::Outbound //////////
00135 
00136 MessageSocket::Outbound::Outbound()
00137     : bytesSent(0)
00138     , header()
00139     , message()
00140 {
00141 }
00142 
00143 MessageSocket::Outbound::Outbound(Outbound&& other)
00144     : bytesSent(other.bytesSent)
00145     , header(other.header)
00146     , message(std::move(other.message))
00147 {
00148 }
00149 
00150 MessageSocket::Outbound::Outbound(MessageId messageId,
00151                                   Core::Buffer message)
00152     : bytesSent(0)
00153     , header()
00154     , message(std::move(message))
00155 {
00156     header.fixed = 0xdaf4;
00157     header.version = 1;
00158     header.payloadLength = uint32_t(this->message.getLength());
00159     header.messageId = messageId;
00160     header.toBigEndian();
00161 }
00162 
00163 MessageSocket::Outbound&
00164 MessageSocket::Outbound::operator=(Outbound&& other)
00165 {
00166     bytesSent = other.bytesSent;
00167     header = other.header;
00168     message = std::move(other.message);
00169     return *this;
00170 }
00171 
00172 ////////// MessageSocket //////////
00173 
00174 MessageSocket::MessageSocket(Handler& handler,
00175                              Event::Loop& eventLoop, int fd,
00176                              uint32_t maxMessageLength)
00177     : maxMessageLength(maxMessageLength)
00178     , handler(handler)
00179     , eventLoop(eventLoop)
00180     , inbound()
00181     , outboundQueueMutex()
00182     , outboundQueue()
00183     , receiveSocket(dupOrPanic(fd), *this)
00184     , sendSocket(fd, *this)
00185     , receiveSocketMonitor(eventLoop, receiveSocket, EPOLLIN)
00186     , sendSocketMonitor(eventLoop, sendSocket, 0)
00187 {
00188 }
00189 
00190 MessageSocket::~MessageSocket()
00191 {
00192 }
00193 
00194 void
00195 MessageSocket::close()
00196 {
00197     receiveSocketMonitor.disableForever();
00198     sendSocketMonitor.disableForever();
00199 
00200     // Take an Event::Loop::Lock in case the handler assumes it's being
00201     // executed on the event loop thread.
00202     Event::Loop::Lock lock(eventLoop);
00203     handler.handleDisconnect();
00204 }
00205 
00206 void
00207 MessageSocket::sendMessage(MessageId messageId, Core::Buffer contents)
00208 {
00209     // Check the message length.
00210     if (contents.getLength() > maxMessageLength) {
00211         PANIC("Message of length %lu bytes is too long to send "
00212               "(limit is %u bytes)",
00213               contents.getLength(), maxMessageLength);
00214     }
00215 
00216     bool kick;
00217     { // Place the message on the outbound queue.
00218         std::lock_guard<Core::Mutex> lock(outboundQueueMutex);
00219         kick = outboundQueue.empty();
00220         outboundQueue.emplace_back(messageId, std::move(contents));
00221     }
00222     // Make sure the SendSocket is set up to call writable().
00223     if (kick)
00224         sendSocketMonitor.setEvents(EPOLLOUT|EPOLLONESHOT);
00225 }
00226 
00227 void
00228 MessageSocket::disconnect()
00229 {
00230     receiveSocketMonitor.disableForever();
00231     sendSocketMonitor.disableForever();
00232     // TODO(ongaro): to make it safe for epoll_wait to return multiple events,
00233     // need to somehow queue the handleDisconnect for later.
00234     handler.handleDisconnect();
00235 }
00236 
00237 void
00238 MessageSocket::readable()
00239 {
00240     // Try to read data from the kernel until there is no more left.
00241     while (true) {
00242         if (inbound.bytesRead < sizeof(Header)) {
00243             // Receiving header
00244             ssize_t bytesRead = read(
00245                 reinterpret_cast<char*>(&inbound.header) + inbound.bytesRead,
00246                 sizeof(Header) - inbound.bytesRead);
00247             if (bytesRead == -1) {
00248                 disconnect();
00249                 return;
00250             }
00251             inbound.bytesRead += size_t(bytesRead);
00252             if (inbound.bytesRead < sizeof(Header))
00253                 return;
00254             // Transition to receiving data
00255             inbound.header.fromBigEndian();
00256             if (inbound.header.fixed != 0xdaf4) {
00257                 WARNING("Disconnecting since message doesn't start with magic "
00258                         "0xdaf4 (first two bytes are 0x%02x)",
00259                         inbound.header.fixed);
00260                 disconnect();
00261                 return;
00262             }
00263             if (inbound.header.version != 1) {
00264                 WARNING("Disconnecting since message uses version %u, but "
00265                         "this code only understands version 1",
00266                         inbound.header.version);
00267                 disconnect();
00268                 return;
00269             }
00270             if (inbound.header.payloadLength > maxMessageLength) {
00271                 WARNING("Disconnecting since message is too long to receive "
00272                         "(message is %u bytes, limit is %u bytes)",
00273                         inbound.header.payloadLength, maxMessageLength);
00274                 disconnect();
00275                 return;
00276             }
00277             inbound.message.setData(new char[inbound.header.payloadLength],
00278                                     inbound.header.payloadLength,
00279                                     Core::Buffer::deleteArrayFn<char>);
00280         }
00281         // Don't use 'else' here; we want to check this branch for two reasons:
00282         // First, if there is a header with a length of 0, the socket won't be
00283         // readable, but we still need to process the message. Second, most of
00284         // the time the header will arrive with at least some data. It makes
00285         // sense to go ahead and try a non-blocking read, rather than going
00286         // back to the event loop.
00287         if (inbound.bytesRead >= sizeof(Header)) {
00288             // Receiving data
00289             size_t payloadBytesRead = inbound.bytesRead - sizeof(Header);
00290             ssize_t bytesRead = read(
00291                 (static_cast<char*>(inbound.message.getData()) +
00292                  payloadBytesRead),
00293                 inbound.header.payloadLength - payloadBytesRead);
00294             if (bytesRead == -1) {
00295                 disconnect();
00296                 return;
00297             }
00298             inbound.bytesRead += size_t(bytesRead);
00299             if (inbound.bytesRead < (sizeof(Header) +
00300                                      inbound.header.payloadLength)) {
00301                 return;
00302             }
00303             handler.handleReceivedMessage(inbound.header.messageId,
00304                                           std::move(inbound.message));
00305             // Transition to receiving header
00306             inbound.bytesRead = 0;
00307         }
00308     }
00309 }
00310 
00311 ssize_t
00312 MessageSocket::read(void* buf, size_t maxBytes)
00313 {
00314     ssize_t actual = recv(receiveSocket.fd, buf, maxBytes, MSG_DONTWAIT);
00315     if (actual > 0)
00316         return actual;
00317     if (actual == 0 || // peer performed orderly shutdown.
00318         errno == ECONNRESET || errno == ETIMEDOUT || errno == EHOSTUNREACH) {
00319         return -1;
00320     }
00321     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
00322         return 0;
00323     PANIC("Error while reading from socket: %s", strerror(errno));
00324 }
00325 
00326 void
00327 MessageSocket::writable()
00328 {
00329     // Each iteration of this loop tries to write one message
00330     // from outboundQueue.
00331     while (true) {
00332 
00333         // Get the next outbound message.
00334         Outbound outbound;
00335         int flags = MSG_DONTWAIT | MSG_NOSIGNAL;
00336         {
00337             std::lock_guard<Core::Mutex> lock(outboundQueueMutex);
00338             if (outboundQueue.empty())
00339                 return;
00340             outbound = std::move(outboundQueue.front());
00341             outboundQueue.pop_front();
00342             if (!outboundQueue.empty())
00343                 flags |= MSG_MORE;
00344         }
00345 
00346         // Use an iovec to send everything in one kernel call: one iov for the
00347         // header, another for the payload.
00348         enum { IOV_LEN = 2 };
00349         struct iovec iov[IOV_LEN];
00350         iov[0].iov_base = &outbound.header;
00351         iov[0].iov_len = sizeof(Header);
00352         iov[1].iov_base = outbound.message.getData();
00353         iov[1].iov_len = outbound.message.getLength();
00354 
00355         { // Skip the parts of the iovec that have already been sent.
00356             size_t bytesSent = outbound.bytesSent;
00357             for (uint32_t i = 0; i < IOV_LEN; ++i) {
00358                 iov[i].iov_base = (static_cast<char*>(iov[i].iov_base) +
00359                                    bytesSent);
00360                 if (bytesSent < iov[i].iov_len) {
00361                     iov[i].iov_len -= bytesSent;
00362                     break;
00363                 } else {
00364                     bytesSent -= iov[i].iov_len;
00365                     iov[i].iov_len = 0;
00366                 }
00367             }
00368         }
00369 
00370         struct msghdr msg;
00371         memset(&msg, 0, sizeof(msg));
00372         msg.msg_iov = iov;
00373         msg.msg_iovlen = IOV_LEN;
00374 
00375         // Do the actual send
00376         ssize_t bytesSent = sendmsg(sendSocket.fd, &msg, flags);
00377         if (bytesSent < 0) {
00378             if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
00379                 // Wasn't able to send, try again later.
00380                 bytesSent = 0;
00381             } else if (errno == ECONNRESET || errno == EPIPE) {
00382                 // Connection closed; disconnect this end.
00383                 // This must be the last line to touch this object, in case
00384                 // handleDisconnect() deletes this object.
00385                 disconnect();
00386                 return;
00387             } else {
00388                 // Unexpected error.
00389                 PANIC("Error while writing to socket %d: %s",
00390                       sendSocket.fd, strerror(errno));
00391             }
00392         }
00393 
00394         // Sent successfully.
00395         outbound.bytesSent += size_t(bytesSent);
00396         if (outbound.bytesSent != (sizeof(Header) +
00397                                    outbound.message.getLength())) {
00398             sendSocketMonitor.setEvents(EPOLLOUT|EPOLLONESHOT);
00399             std::lock_guard<Core::Mutex> lockGuard(outboundQueueMutex);
00400             outboundQueue.emplace_front(std::move(outbound));
00401             return;
00402         }
00403     }
00404 }
00405 
00406 } // namespace LogCabin::RPC
00407 } // namespace LogCabin
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines