LogCabin
|
00001 /* Copyright (c) 2010-2014 Stanford University 00002 * Copyright (c) 2015 Diego Ongaro 00003 * 00004 * Permission to use, copy, modify, and distribute this software for any 00005 * purpose with or without fee is hereby granted, provided that the above 00006 * copyright notice and this permission notice appear in all copies. 00007 * 00008 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES 00009 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 00010 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR 00011 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 00012 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 00013 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 00014 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 00015 */ 00016 00017 #include <cassert> 00018 #include <errno.h> 00019 #include <netinet/in.h> 00020 #include <netinet/tcp.h> 00021 #include <string.h> 00022 #include <sys/epoll.h> 00023 #include <sys/socket.h> 00024 #include <sys/types.h> 00025 #include <unistd.h> 00026 00027 #include "Core/Debug.h" 00028 #include "Core/Endian.h" 00029 #include "Event/Loop.h" 00030 #include "RPC/MessageSocket.h" 00031 00032 namespace LogCabin { 00033 namespace RPC { 00034 00035 namespace { 00036 00037 /// Wrapper for dup(). 00038 int 00039 dupOrPanic(int oldfd) 00040 { 00041 int newfd = dup(oldfd); 00042 if (newfd < 0) 00043 PANIC("Failed to dup(%d): %s", oldfd, strerror(errno)); 00044 return newfd; 00045 } 00046 00047 } // anonymous namespace 00048 00049 ////////// MessageSocket::SendSocket ////////// 00050 00051 MessageSocket::SendSocket::SendSocket(int fd, 00052 MessageSocket& messageSocket) 00053 : Event::File(fd) 00054 , messageSocket(messageSocket) 00055 { 00056 int flag = 1; 00057 int r = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); 00058 if (r < 0) { 00059 // This should be a warning, but some unit tests pass weird types of 00060 // file descriptors in here. It's not very important, anyhow. 00061 NOTICE("Could not set TCP_NODELAY flag on sending socket %d: %s", 00062 fd, strerror(errno)); 00063 } 00064 } 00065 00066 MessageSocket::SendSocket::~SendSocket() 00067 { 00068 } 00069 00070 void 00071 MessageSocket::SendSocket::handleFileEvent(uint32_t events) 00072 { 00073 messageSocket.writable(); 00074 } 00075 00076 ////////// MessageSocket::ReceiveSocket ////////// 00077 00078 MessageSocket::ReceiveSocket::ReceiveSocket(int fd, 00079 MessageSocket& messageSocket) 00080 : Event::File(fd) 00081 , messageSocket(messageSocket) 00082 { 00083 // I don't know that TCP_NODELAY has any effect if we're only reading from 00084 // this file descriptor, but I guess it can't hurt. 00085 int flag = 1; 00086 int r = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); 00087 if (r < 0) { 00088 // This should be a warning, but some unit tests pass weird types of 00089 // file descriptors in here. It's not very important, anyhow. 00090 NOTICE("Could not set TCP_NODELAY flag on receiving socket %d: %s", 00091 fd, strerror(errno)); 00092 } 00093 } 00094 00095 MessageSocket::ReceiveSocket::~ReceiveSocket() 00096 { 00097 } 00098 00099 void 00100 MessageSocket::ReceiveSocket::handleFileEvent(uint32_t events) 00101 { 00102 messageSocket.readable(); 00103 } 00104 00105 ////////// MessageSocket::Header ////////// 00106 00107 void 00108 MessageSocket::Header::fromBigEndian() 00109 { 00110 fixed = be16toh(fixed); 00111 version = be16toh(version); 00112 payloadLength = be32toh(payloadLength); 00113 messageId = be64toh(messageId); 00114 } 00115 00116 void 00117 MessageSocket::Header::toBigEndian() 00118 { 00119 fixed = htobe16(fixed); 00120 version = htobe16(version); 00121 payloadLength = htobe32(payloadLength); 00122 messageId = htobe64(messageId); 00123 } 00124 00125 ////////// MessageSocket::Inbound ////////// 00126 00127 MessageSocket::Inbound::Inbound() 00128 : bytesRead(0) 00129 , header() 00130 , message() 00131 { 00132 } 00133 00134 ////////// MessageSocket::Outbound ////////// 00135 00136 MessageSocket::Outbound::Outbound() 00137 : bytesSent(0) 00138 , header() 00139 , message() 00140 { 00141 } 00142 00143 MessageSocket::Outbound::Outbound(Outbound&& other) 00144 : bytesSent(other.bytesSent) 00145 , header(other.header) 00146 , message(std::move(other.message)) 00147 { 00148 } 00149 00150 MessageSocket::Outbound::Outbound(MessageId messageId, 00151 Core::Buffer message) 00152 : bytesSent(0) 00153 , header() 00154 , message(std::move(message)) 00155 { 00156 header.fixed = 0xdaf4; 00157 header.version = 1; 00158 header.payloadLength = uint32_t(this->message.getLength()); 00159 header.messageId = messageId; 00160 header.toBigEndian(); 00161 } 00162 00163 MessageSocket::Outbound& 00164 MessageSocket::Outbound::operator=(Outbound&& other) 00165 { 00166 bytesSent = other.bytesSent; 00167 header = other.header; 00168 message = std::move(other.message); 00169 return *this; 00170 } 00171 00172 ////////// MessageSocket ////////// 00173 00174 MessageSocket::MessageSocket(Handler& handler, 00175 Event::Loop& eventLoop, int fd, 00176 uint32_t maxMessageLength) 00177 : maxMessageLength(maxMessageLength) 00178 , handler(handler) 00179 , eventLoop(eventLoop) 00180 , inbound() 00181 , outboundQueueMutex() 00182 , outboundQueue() 00183 , receiveSocket(dupOrPanic(fd), *this) 00184 , sendSocket(fd, *this) 00185 , receiveSocketMonitor(eventLoop, receiveSocket, EPOLLIN) 00186 , sendSocketMonitor(eventLoop, sendSocket, 0) 00187 { 00188 } 00189 00190 MessageSocket::~MessageSocket() 00191 { 00192 } 00193 00194 void 00195 MessageSocket::close() 00196 { 00197 receiveSocketMonitor.disableForever(); 00198 sendSocketMonitor.disableForever(); 00199 00200 // Take an Event::Loop::Lock in case the handler assumes it's being 00201 // executed on the event loop thread. 00202 Event::Loop::Lock lock(eventLoop); 00203 handler.handleDisconnect(); 00204 } 00205 00206 void 00207 MessageSocket::sendMessage(MessageId messageId, Core::Buffer contents) 00208 { 00209 // Check the message length. 00210 if (contents.getLength() > maxMessageLength) { 00211 PANIC("Message of length %lu bytes is too long to send " 00212 "(limit is %u bytes)", 00213 contents.getLength(), maxMessageLength); 00214 } 00215 00216 bool kick; 00217 { // Place the message on the outbound queue. 00218 std::lock_guard<Core::Mutex> lock(outboundQueueMutex); 00219 kick = outboundQueue.empty(); 00220 outboundQueue.emplace_back(messageId, std::move(contents)); 00221 } 00222 // Make sure the SendSocket is set up to call writable(). 00223 if (kick) 00224 sendSocketMonitor.setEvents(EPOLLOUT|EPOLLONESHOT); 00225 } 00226 00227 void 00228 MessageSocket::disconnect() 00229 { 00230 receiveSocketMonitor.disableForever(); 00231 sendSocketMonitor.disableForever(); 00232 // TODO(ongaro): to make it safe for epoll_wait to return multiple events, 00233 // need to somehow queue the handleDisconnect for later. 00234 handler.handleDisconnect(); 00235 } 00236 00237 void 00238 MessageSocket::readable() 00239 { 00240 // Try to read data from the kernel until there is no more left. 00241 while (true) { 00242 if (inbound.bytesRead < sizeof(Header)) { 00243 // Receiving header 00244 ssize_t bytesRead = read( 00245 reinterpret_cast<char*>(&inbound.header) + inbound.bytesRead, 00246 sizeof(Header) - inbound.bytesRead); 00247 if (bytesRead == -1) { 00248 disconnect(); 00249 return; 00250 } 00251 inbound.bytesRead += size_t(bytesRead); 00252 if (inbound.bytesRead < sizeof(Header)) 00253 return; 00254 // Transition to receiving data 00255 inbound.header.fromBigEndian(); 00256 if (inbound.header.fixed != 0xdaf4) { 00257 WARNING("Disconnecting since message doesn't start with magic " 00258 "0xdaf4 (first two bytes are 0x%02x)", 00259 inbound.header.fixed); 00260 disconnect(); 00261 return; 00262 } 00263 if (inbound.header.version != 1) { 00264 WARNING("Disconnecting since message uses version %u, but " 00265 "this code only understands version 1", 00266 inbound.header.version); 00267 disconnect(); 00268 return; 00269 } 00270 if (inbound.header.payloadLength > maxMessageLength) { 00271 WARNING("Disconnecting since message is too long to receive " 00272 "(message is %u bytes, limit is %u bytes)", 00273 inbound.header.payloadLength, maxMessageLength); 00274 disconnect(); 00275 return; 00276 } 00277 inbound.message.setData(new char[inbound.header.payloadLength], 00278 inbound.header.payloadLength, 00279 Core::Buffer::deleteArrayFn<char>); 00280 } 00281 // Don't use 'else' here; we want to check this branch for two reasons: 00282 // First, if there is a header with a length of 0, the socket won't be 00283 // readable, but we still need to process the message. Second, most of 00284 // the time the header will arrive with at least some data. It makes 00285 // sense to go ahead and try a non-blocking read, rather than going 00286 // back to the event loop. 00287 if (inbound.bytesRead >= sizeof(Header)) { 00288 // Receiving data 00289 size_t payloadBytesRead = inbound.bytesRead - sizeof(Header); 00290 ssize_t bytesRead = read( 00291 (static_cast<char*>(inbound.message.getData()) + 00292 payloadBytesRead), 00293 inbound.header.payloadLength - payloadBytesRead); 00294 if (bytesRead == -1) { 00295 disconnect(); 00296 return; 00297 } 00298 inbound.bytesRead += size_t(bytesRead); 00299 if (inbound.bytesRead < (sizeof(Header) + 00300 inbound.header.payloadLength)) { 00301 return; 00302 } 00303 handler.handleReceivedMessage(inbound.header.messageId, 00304 std::move(inbound.message)); 00305 // Transition to receiving header 00306 inbound.bytesRead = 0; 00307 } 00308 } 00309 } 00310 00311 ssize_t 00312 MessageSocket::read(void* buf, size_t maxBytes) 00313 { 00314 ssize_t actual = recv(receiveSocket.fd, buf, maxBytes, MSG_DONTWAIT); 00315 if (actual > 0) 00316 return actual; 00317 if (actual == 0 || // peer performed orderly shutdown. 00318 errno == ECONNRESET || errno == ETIMEDOUT || errno == EHOSTUNREACH) { 00319 return -1; 00320 } 00321 if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) 00322 return 0; 00323 PANIC("Error while reading from socket: %s", strerror(errno)); 00324 } 00325 00326 void 00327 MessageSocket::writable() 00328 { 00329 // Each iteration of this loop tries to write one message 00330 // from outboundQueue. 00331 while (true) { 00332 00333 // Get the next outbound message. 00334 Outbound outbound; 00335 int flags = MSG_DONTWAIT | MSG_NOSIGNAL; 00336 { 00337 std::lock_guard<Core::Mutex> lock(outboundQueueMutex); 00338 if (outboundQueue.empty()) 00339 return; 00340 outbound = std::move(outboundQueue.front()); 00341 outboundQueue.pop_front(); 00342 if (!outboundQueue.empty()) 00343 flags |= MSG_MORE; 00344 } 00345 00346 // Use an iovec to send everything in one kernel call: one iov for the 00347 // header, another for the payload. 00348 enum { IOV_LEN = 2 }; 00349 struct iovec iov[IOV_LEN]; 00350 iov[0].iov_base = &outbound.header; 00351 iov[0].iov_len = sizeof(Header); 00352 iov[1].iov_base = outbound.message.getData(); 00353 iov[1].iov_len = outbound.message.getLength(); 00354 00355 { // Skip the parts of the iovec that have already been sent. 00356 size_t bytesSent = outbound.bytesSent; 00357 for (uint32_t i = 0; i < IOV_LEN; ++i) { 00358 iov[i].iov_base = (static_cast<char*>(iov[i].iov_base) + 00359 bytesSent); 00360 if (bytesSent < iov[i].iov_len) { 00361 iov[i].iov_len -= bytesSent; 00362 break; 00363 } else { 00364 bytesSent -= iov[i].iov_len; 00365 iov[i].iov_len = 0; 00366 } 00367 } 00368 } 00369 00370 struct msghdr msg; 00371 memset(&msg, 0, sizeof(msg)); 00372 msg.msg_iov = iov; 00373 msg.msg_iovlen = IOV_LEN; 00374 00375 // Do the actual send 00376 ssize_t bytesSent = sendmsg(sendSocket.fd, &msg, flags); 00377 if (bytesSent < 0) { 00378 if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { 00379 // Wasn't able to send, try again later. 00380 bytesSent = 0; 00381 } else if (errno == ECONNRESET || errno == EPIPE) { 00382 // Connection closed; disconnect this end. 00383 // This must be the last line to touch this object, in case 00384 // handleDisconnect() deletes this object. 00385 disconnect(); 00386 return; 00387 } else { 00388 // Unexpected error. 00389 PANIC("Error while writing to socket %d: %s", 00390 sendSocket.fd, strerror(errno)); 00391 } 00392 } 00393 00394 // Sent successfully. 00395 outbound.bytesSent += size_t(bytesSent); 00396 if (outbound.bytesSent != (sizeof(Header) + 00397 outbound.message.getLength())) { 00398 sendSocketMonitor.setEvents(EPOLLOUT|EPOLLONESHOT); 00399 std::lock_guard<Core::Mutex> lockGuard(outboundQueueMutex); 00400 outboundQueue.emplace_front(std::move(outbound)); 00401 return; 00402 } 00403 } 00404 } 00405 00406 } // namespace LogCabin::RPC 00407 } // namespace LogCabin