diff --git a/src/PlainSession.cpp b/src/PlainSession.cpp index 79994d4c56f2673eb53fc6a67fc9e78b5e4c2939..b2062a533b8bdd382877973a93657464c81f2af7 100644 --- a/src/PlainSession.cpp +++ b/src/PlainSession.cpp @@ -164,6 +164,61 @@ void PlainSession::startWriteResponse() void PlainSession::startWriteData() { DEBUG_STREAM << "PlainSession::startWriteData()" << endl; + + try + { + if(!m_inputStream.bad()) + { + if(m_inputStream.tellg()<m_inputStreamSize) + { + int leftToRead = m_inputStreamSize - m_inputStream.tellg(); + + DEBUG_STREAM << "PlainSession::startWriteData() left to read " << leftToRead << endl; + + int bufferSize = 0; + + if(leftToRead < BUFFER_SIZE) + bufferSize = leftToRead; + else + bufferSize = BUFFER_SIZE; + + DEBUG_STREAM << "PlainSession::startWriteData() buffer size " << bufferSize << endl; + + std::vector<char> writeBuff; + writeBuff.resize(bufferSize); + + m_inputStream.read(&writeBuff[0], bufferSize); + + boost::asio::async_write(m_plainSocket, boost::asio::buffer(writeBuff), + m_strand.wrap(boost::bind(&PlainSession::handleWriteData, + shared_from_this(), boost::asio::placeholders::error))); + } + else + { + INFO_STREAM << "SSLSession::startWriteData() " + << " transfer completed " << endl; + + m_inputStream.close(); + + startReadRequestHeader(); + } + } + else + { + ERROR_STREAM << "SSLSession::startWriteData() error on file I/O " + << "from " << m_remoteEndpoint << endl; + } + } + catch(std::exception& ec) + { + ERROR_STREAM << "PlainSession::startWriteData() " + << ec.what() << " from " << m_remoteEndpoint << endl; + } + catch(...) + { + ERROR_STREAM << "PlainSession::startWriteData() unknown error from " + << m_remoteEndpoint << endl; + } } } //namespace \ No newline at end of file diff --git a/src/ProtocolManager.cpp b/src/ProtocolManager.cpp index 07a4c679d8fc17a1b61537849411dfafcfefddf4..41925eacaa34cc25f466e1605bf21f2ef89214f1 100644 --- a/src/ProtocolManager.cpp +++ b/src/ProtocolManager.cpp @@ -17,6 +17,8 @@ ProtocolManager::ProtocolManager(Tango::DeviceImpl* deviceImpl_p, DEBUG_STREAM << "ProtocolManager::ProtocolManager()" << endl; m_isAuthorised = false; + m_isValidated = false; + m_isTransferRequest = false; } //============================================================================== @@ -94,6 +96,36 @@ ResponseSP ProtocolManager::prepareResponse(RequestSP request_sp) return response_sp; } +//============================================================================== +// ProtocolManager::isTransferRequest() +//============================================================================== +bool ProtocolManager::isTransferRequest() +{ + DEBUG_STREAM << "ProtocolManager::isTransferRequest()" << endl; + + return m_isTransferRequest; +} + +//============================================================================== +// ProtocolManager::getFilePath() +//============================================================================== +std::string ProtocolManager::getFilePath() +{ + DEBUG_STREAM << "ProtocolManager::getFilePath()" << endl; + + return m_filePath; +} + +//============================================================================== +// ProtocolManager::getFileSize() +//============================================================================== +int ProtocolManager::getFileSize() +{ + DEBUG_STREAM << "ProtocolManager::getFileSize()" << endl; + + return m_fileSize; +} + //============================================================================== // ProtocolManager::prepareAuthroisation() //============================================================================== @@ -233,6 +265,8 @@ ResponseSP ProtocolManager::prepareTransfer(RequestSP request_sp) try { + m_isTransferRequest = false; + DBManager::FileTuple fileTuple = m_dBManager_sp->retrieveFileInfo(m_validatedSchema, m_validatedTable, fileVersion, fileName); @@ -276,6 +310,10 @@ ResponseSP ProtocolManager::prepareTransfer(RequestSP request_sp) transferRes->set_state(Response::Transfer::ACCEPTED); transferRes->set_status("File found"); + + m_isTransferRequest = true; + m_filePath = absPath.string(); + m_fileSize = boost::filesystem::file_size(absPath); } catch(std::exception& ex) { diff --git a/src/ProtocolManager.h b/src/ProtocolManager.h index 815c662dfb29c61adb51474753bf734219addac9..9b38336edfc9b63b081ed92bae2ee885a21365a5 100644 --- a/src/ProtocolManager.h +++ b/src/ProtocolManager.h @@ -59,6 +59,15 @@ public: virtual ResponseSP prepareResponse(RequestSP) throw(std::runtime_error); +//------------------------------------------------------------------------------ +// [Public] File transfer methods +//------------------------------------------------------------------------------ + virtual bool isTransferRequest(); + + virtual std::string getFilePath(); + + virtual int getFileSize(); + protected: //------------------------------------------------------------------------------ // [Protected] Request specific methods @@ -94,6 +103,12 @@ protected: //Address and port of remote endpoint std::string m_remoteEndpoint; + + bool m_isTransferRequest; + + std::string m_filePath; + + int m_fileSize; }; } //End of namespace diff --git a/src/SSLSession.cpp b/src/SSLSession.cpp index 8f2a896884d5b0eea0e04686064e7427c21db3fb..9ebcd9498c5d4eca020d74cc6bfc3a4c10f5ddce 100644 --- a/src/SSLSession.cpp +++ b/src/SSLSession.cpp @@ -199,6 +199,61 @@ void SSLSession::startWriteResponse() void SSLSession::startWriteData() { DEBUG_STREAM << "SSLSession::startWriteData()" << endl; + + try + { + if(!m_inputStream.bad()) + { + if(m_inputStream.tellg()<m_inputStreamSize) + { + int leftToRead = m_inputStreamSize - m_inputStream.tellg(); + + DEBUG_STREAM << "SSLSession::startWriteData() left to read " << leftToRead << endl; + + int bufferSize = 0; + + if(leftToRead < BUFFER_SIZE) + bufferSize = leftToRead; + else + bufferSize = BUFFER_SIZE; + + DEBUG_STREAM << "SSLSession::startWriteData() buffer size " << bufferSize << endl; + + std::vector<char> writeBuff; + writeBuff.resize(bufferSize); + + m_inputStream.read(&writeBuff[0], bufferSize); + + boost::asio::async_write(m_sslSocket, boost::asio::buffer(writeBuff), + m_strand.wrap(boost::bind(&SSLSession::handleWriteData, + shared_from_this(), boost::asio::placeholders::error))); + } + else + { + INFO_STREAM << "SSLSession::startWriteData() " + << " transfer completed " << endl; + + m_inputStream.close(); + + startReadRequestHeader(); + } + } + else + { + ERROR_STREAM << "SSLSession::startWriteData() error on file I/O " + << "from " << m_remoteEndpoint << endl; + } + } + catch(std::exception& ec) + { + ERROR_STREAM << "SSLSession::startWriteData() " + << ec.what() << " from " << m_remoteEndpoint << endl; + } + catch(...) + { + ERROR_STREAM << "SSLSession::startWriteData() unknown error from " + << m_remoteEndpoint << endl; + } } } //namespace diff --git a/src/Session.cpp b/src/Session.cpp index 481de1e90f0662d105a8ea28e2138a045e18e128..15366eef77863bfd8c01d9476594a491ccbde05b 100644 --- a/src/Session.cpp +++ b/src/Session.cpp @@ -82,7 +82,36 @@ void Session::handleWriteResponse(const boost::system::error_code& errorCode) if(!errorCode) { - startReadRequestHeader(); + if(m_protocolManager_sp->isTransferRequest()) + { + std::string filePath = m_protocolManager_sp->getFilePath(); + int fileSize = m_protocolManager_sp->getFileSize(); + + INFO_STREAM << "Session::handleWriteResponse() transfer request " + << filePath << " size " << fileSize << " from " + << m_remoteEndpoint << endl; + + m_inputStreamSize = fileSize; + + if(m_inputStream.is_open()) + m_inputStream.close(); + + m_inputStream.open(filePath.c_str(), std::ios::binary); + + if(m_inputStream) + { + startWriteData(); + } + else + { + ERROR_STREAM << "Session::handleWriteResponse() Cannot open " + << filePath << endl; + } + } + else + { + startReadRequestHeader(); + } } else if(errorCode == boost::asio::error::eof) { @@ -99,9 +128,24 @@ void Session::handleWriteResponse(const boost::system::error_code& errorCode) //============================================================================== // Session::handleWriteData() //============================================================================== -void Session::handleWriteData(const boost::system::error_code&) +void Session::handleWriteData(const boost::system::error_code& errorCode) { DEBUG_STREAM << "Session::handleWriteData()" << endl; + + if(!errorCode) + { + startWriteData(); + } + else if(errorCode == boost::asio::error::eof) + { + DEBUG_STREAM << "Session::handleWriteResponse() end of file from " + << m_remoteEndpoint << endl; + } + else + { + ERROR_STREAM << "Session::handleWriteResponse() " + << errorCode.message() << " from " << m_remoteEndpoint << endl; + } } //============================================================================== diff --git a/src/Session.h b/src/Session.h index ed022a636c690db94c96256001268e116f4beff6..5f7b2cc7b9868a204165199087e7b7e2cf957a2a 100644 --- a/src/Session.h +++ b/src/Session.h @@ -97,6 +97,12 @@ protected: //Address and port of remote endpoint std::string m_remoteEndpoint; + + const int BUFFER_SIZE = 1024; + + std::ifstream m_inputStream; + + int m_inputStreamSize; }; } //End of namespace