From 03cc2e3fa73872ed003104b131e73e557bcc0037 Mon Sep 17 00:00:00 2001 From: arvidn Date: Sun, 9 Jul 2023 12:39:49 -0700 Subject: [PATCH] use reference counted buffer in MemoryPlot, to allow it to be copied (which it is) without a double-free --- src/tools/PlotReader.cpp | 73 +++++++++++----------------------------- src/tools/PlotReader.h | 12 ++++--- 2 files changed, 28 insertions(+), 57 deletions(-) diff --git a/src/tools/PlotReader.cpp b/src/tools/PlotReader.cpp index a55be7e9..744100ac 100644 --- a/src/tools/PlotReader.cpp +++ b/src/tools/PlotReader.cpp @@ -975,46 +975,15 @@ GreenReaperContext* PlotReader::GetGRContext() return _grContext; } +struct VirtualFreeDeleter { + void operator()(byte* ptr) const { + SysHost::VirtualFree( ptr ); + } +}; + /// /// Memory Plot /// -//----------------------------------------------------------- -MemoryPlot::MemoryPlot() - : _bytes( nullptr, 0 ) -{} - -//----------------------------------------------------------- -MemoryPlot::MemoryPlot( const MemoryPlot& plotFile ) -{ - _bytes = plotFile._bytes; - _err = 0; - _position = 0; - - int headerError = 0; - if( !ReadHeader( headerError ) ) - { - if( headerError ) - _err = headerError; - - if( _err == 0 ) - _err = -1; // #TODO: Set generic plot header read error - - _bytes.values = nullptr; - return; - } - - _plotPath = plotFile._plotPath; -} - -//----------------------------------------------------------- -MemoryPlot::~MemoryPlot() -{ - // #TODO: Don't destroy bytes unless we own them. Use a shared ptr here. - if( _bytes.values ) - SysHost::VirtualFree( _bytes.values ); - - _bytes = Span( nullptr, 0 ); -} //----------------------------------------------------------- bool MemoryPlot::Open( const char* path ) @@ -1047,7 +1016,8 @@ bool MemoryPlot::Open( const char* path ) // we have any remainder that does not align to a block const size_t allocSize = RoundUpToNextBoundary( (size_t)plotSize, (int)file.BlockSize() ) + file.BlockSize(); - byte* bytes = (byte*)SysHost::VirtualAlloc( allocSize ); + auto bytes = std::shared_ptr( + (byte*)SysHost::VirtualAlloc( allocSize ), VirtualFreeDeleter()); if( !bytes ) { _err = -1; // #TODO: Assign an actual user error. @@ -1058,7 +1028,7 @@ bool MemoryPlot::Open( const char* path ) size_t readSize = RoundUpToNextBoundary( plotSize, (int)file.BlockSize() );/// file.BlockSize() * file.BlockSize(); // size_t readRemainder = plotSize - readSize; const size_t readEnd = readSize - plotSize; - byte* reader = bytes; + byte* reader = bytes.get(); // Read blocks while( readSize > readEnd ) @@ -1069,8 +1039,6 @@ bool MemoryPlot::Open( const char* path ) if( read < 0 ) { _err = file.GetError(); - SysHost::VirtualFree( bytes ); - return false; } @@ -1089,8 +1057,6 @@ bool MemoryPlot::Open( const char* path ) // if( read < 0 ) // { // _err = file.GetError(); - // SysHost::VirtualFree( bytes ); - // return false; // } @@ -1098,7 +1064,8 @@ bool MemoryPlot::Open( const char* path ) // memmove( reader, block, readRemainder ); // } - _bytes = Span( bytes, (size_t)plotSize ); + _buffer = std::move(bytes); + _size = plotSize; // Read the header int headerError = 0; @@ -1110,13 +1077,13 @@ bool MemoryPlot::Open( const char* path ) if( _err == 0 ) _err = -1; // #TODO: Set generic plot header read error - _bytes.values = nullptr; - SysHost::VirtualFree( bytes ); + _buffer.reset(); + _size = 0; return false; } // Lock the plot memory into read-only mode - SysHost::VirtualProtect( bytes, allocSize, VProtect::Read ); + SysHost::VirtualProtect( _buffer.get(), allocSize, VProtect::Read ); // Save data, good to go _plotPath = path; @@ -1127,13 +1094,13 @@ bool MemoryPlot::Open( const char* path ) //----------------------------------------------------------- bool MemoryPlot::IsOpen() const { - return _bytes.values != nullptr; + return _buffer.get(); } //----------------------------------------------------------- size_t MemoryPlot::PlotSize() const { - return _bytes.length; + return _size; } //----------------------------------------------------------- @@ -1152,7 +1119,7 @@ bool MemoryPlot::Seek( SeekOrigin origin, int64 offset ) break; case SeekOrigin::End: - absPosition = (ssize_t)_bytes.length + offset; + absPosition = (ssize_t)_size + offset; break; default: @@ -1160,7 +1127,7 @@ bool MemoryPlot::Seek( SeekOrigin origin, int64 offset ) return false; } - if( absPosition < 0 || absPosition > (ssize_t)_bytes.length ) + if( absPosition < 0 || absPosition > (ssize_t)_size ) { _err = -1; // #TODO: Set proper user error. return false; @@ -1180,13 +1147,13 @@ ssize_t MemoryPlot::Read( size_t size, void* buffer ) const size_t endPos = (size_t)_position + size; - if( endPos > _bytes.length ) + if( endPos > _size ) { _err = -1; // #TODO: Set proper user error return false; } - memcpy( buffer, _bytes.values + _position, size ); + memcpy( buffer, _buffer.get() + _position, size ); _position = (ssize_t)endPos; return (ssize_t)size; diff --git a/src/tools/PlotReader.h b/src/tools/PlotReader.h index 9517c78a..d5c27db5 100644 --- a/src/tools/PlotReader.h +++ b/src/tools/PlotReader.h @@ -6,6 +6,7 @@ #include "io/FileStream.h" #include "util/Util.h" #include +#include class CPBitReader; @@ -119,9 +120,10 @@ class IPlotFile class MemoryPlot : public IPlotFile { public: - MemoryPlot(); - MemoryPlot( const MemoryPlot& plotFile ); - ~MemoryPlot(); + MemoryPlot() = default; + MemoryPlot( const MemoryPlot& plotFile ) = default; + MemoryPlot( MemoryPlot&& plotFile ) = default; + ~MemoryPlot() = default; bool Open( const char* path ) override; bool IsOpen() const override; @@ -135,7 +137,9 @@ class MemoryPlot : public IPlotFile int GetError() override; private: - Span _bytes; // Plot bytes + + std::shared_ptr _buffer; + size_t _size = 0; int _err = 0; ssize_t _position = 0; std::string _plotPath = "";