diff --git a/socket.cxx b/socket.cxx index 1bd5090..7ddcf78 100644 --- a/socket.cxx +++ b/socket.cxx @@ -74,15 +74,34 @@ namespace xerxes MysqlData& data, int flags) { - return ::recv(socket.fd, data.first.get(), data.second, flags); + int ret = ::recv(socket.fd, data.first.get(), data.second, flags); + if(ret == 0) + { + throw ConResetErr(); + } + if(ret < 0) + { + throw ConDataErr(); + } + return ret; } int send (Socket& socket, MysqlData& data, + int len, int flags) { - return ::send(socket.fd, data.first.get(), data.second, flags); + int ret = ::send(socket.fd, data.first.get(), len, flags); + if(ret == 0) + { + throw ConResetErr(); + } + if(ret < 0) + { + throw ConDataErr(); + } + return ret; } int diff --git a/socket.hxx b/socket.hxx index b70d527..8f92ea9 100644 --- a/socket.hxx +++ b/socket.hxx @@ -62,6 +62,7 @@ namespace xerxes int send(Socket& socket, MysqlData& data, + int len, int flags); int @@ -95,7 +96,7 @@ namespace xerxes { events[source.fd] = event_t(new epoll_event); events[target.fd] = event_t(new epoll_event); - events[source.fd]->events = EPOLLIN; + events[source.fd]->events = EPOLLIN | EPOLLPRI | EPOLLERR | EPOLLHUP; events[target.fd]->events = events[source.fd]->events; events[source.fd]->data.fd = target.fd; @@ -126,6 +127,10 @@ namespace xerxes int fd; std::map events; }; + + class SocketErr{}; + class ConResetErr: public SocketErr{}; + class ConDataErr: public SocketErr{}; } #endif diff --git a/xerxes.cxx b/xerxes.cxx index 299bbac..d1a5011 100644 --- a/xerxes.cxx +++ b/xerxes.cxx @@ -41,6 +41,10 @@ main(int argc, char* argv[]) freeaddrinfo(res); int ret = listen(lstn, 3); + if( ret != 0) + { + exit(1); + } EPoll epoll; epoll.add(lstn); @@ -49,6 +53,8 @@ main(int argc, char* argv[]) boost::shared_array events(new epoll_event[max_events]); std::map > sockets; + MysqlData buffer = makeData(1024); + for(;;) { int num = epoll_wait(epoll.fd, events.get(), max_events, -1); @@ -70,11 +76,9 @@ main(int argc, char* argv[]) hints.ai_socktype = SOCK_STREAM; hints.ai_protocol = IPPROTO_TCP; hints.ai_flags = AI_PASSIVE | AI_NUMERICHOST; - getaddrinfo("127.0.0.1", "3306", &hints, &res); - + //getaddrinfo("127.0.0.1", "25", &hints, &res); connect(*target, res->ai_addr, res->ai_addrlen); - freeaddrinfo(res); sockets[target->fd] = target; @@ -85,9 +89,47 @@ main(int argc, char* argv[]) epoll.add(*source, *target); } - else{ cout << "hollo!" << events[i].data.fd << endl; } - } - } - + else + { + //lookup + boost::shared_ptr target(sockets[events[i].data.fd]); + boost::shared_ptr source(sockets[epoll.events[target->fd]->data.fd]); + if((events[i].events & EPOLLIN ) + || (events[i].events & EPOLLPRI)) + { + // read -> write + cout << "writer: "<< source->fd << endl; + cout << "reader: "<< target->fd << endl; + try + { + int len = recv(*source, buffer, 0); + send(*target, buffer, len, 0); + + } + catch (SocketErr e) + { + // hangup + cout << "Socket Error, closing " << target->fd << " and " << source->fd << endl; + epoll.del(target->fd); + epoll.del(source->fd); + sockets.erase(target->fd); + sockets.erase(source->fd); + continue; + } + } + if((events[i].events & EPOLLERR) + || (events[i].events & EPOLLHUP)) + { + // hangup + cout << target->fd << " closed by " << source->fd << endl; + epoll.del(target->fd); + epoll.del(source->fd); + sockets.erase(target->fd); + sockets.erase(source->fd); + continue; + } + } + } + } return 0; }