Skip to content

Commit

Permalink
use reference counted buffer in MemoryPlot, to allow it to be copied …
Browse files Browse the repository at this point in the history
…(which it is) without a double-free
  • Loading branch information
arvidn committed Jul 9, 2023
1 parent c0216c9 commit 03cc2e3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 57 deletions.
73 changes: 20 additions & 53 deletions src/tools/PlotReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,46 +975,15 @@ GreenReaperContext* PlotReader::GetGRContext()
return _grContext;
}

struct VirtualFreeDeleter {
void operator()(byte* ptr) const {
SysHost::VirtualFree( ptr );
}
};

///
/// Memory Plot
///
//-----------------------------------------------------------
MemoryPlot::MemoryPlot()
: _bytes( nullptr, 0 )
{}

//-----------------------------------------------------------
MemoryPlot::MemoryPlot( const MemoryPlot& plotFile )
{
_bytes = plotFile._bytes;
_err = 0;
_position = 0;

int headerError = 0;
if( !ReadHeader( headerError ) )
{
if( headerError )
_err = headerError;

if( _err == 0 )
_err = -1; // #TODO: Set generic plot header read error

_bytes.values = nullptr;
return;
}

_plotPath = plotFile._plotPath;
}

//-----------------------------------------------------------
MemoryPlot::~MemoryPlot()
{
// #TODO: Don't destroy bytes unless we own them. Use a shared ptr here.
if( _bytes.values )
SysHost::VirtualFree( _bytes.values );

_bytes = Span<byte>( nullptr, 0 );
}

//-----------------------------------------------------------
bool MemoryPlot::Open( const char* path )
Expand Down Expand Up @@ -1047,7 +1016,8 @@ bool MemoryPlot::Open( const char* path )
// we have any remainder that does not align to a block
const size_t allocSize = RoundUpToNextBoundary( (size_t)plotSize, (int)file.BlockSize() ) + file.BlockSize();

byte* bytes = (byte*)SysHost::VirtualAlloc( allocSize );
auto bytes = std::shared_ptr<byte[]>(
(byte*)SysHost::VirtualAlloc( allocSize ), VirtualFreeDeleter());
if( !bytes )
{
_err = -1; // #TODO: Assign an actual user error.
Expand All @@ -1058,7 +1028,7 @@ bool MemoryPlot::Open( const char* path )
size_t readSize = RoundUpToNextBoundary( plotSize, (int)file.BlockSize() );/// file.BlockSize() * file.BlockSize();
// size_t readRemainder = plotSize - readSize;
const size_t readEnd = readSize - plotSize;
byte* reader = bytes;
byte* reader = bytes.get();

// Read blocks
while( readSize > readEnd )
Expand All @@ -1069,8 +1039,6 @@ bool MemoryPlot::Open( const char* path )
if( read < 0 )
{
_err = file.GetError();
SysHost::VirtualFree( bytes );

return false;
}

Expand All @@ -1089,16 +1057,15 @@ bool MemoryPlot::Open( const char* path )
// if( read < 0 )
// {
// _err = file.GetError();
// SysHost::VirtualFree( bytes );

// return false;
// }

// if( reader != block )
// memmove( reader, block, readRemainder );
// }

_bytes = Span<byte>( bytes, (size_t)plotSize );
_buffer = std::move(bytes);
_size = plotSize;

// Read the header
int headerError = 0;
Expand All @@ -1110,13 +1077,13 @@ bool MemoryPlot::Open( const char* path )
if( _err == 0 )
_err = -1; // #TODO: Set generic plot header read error

_bytes.values = nullptr;
SysHost::VirtualFree( bytes );
_buffer.reset();
_size = 0;
return false;
}

// Lock the plot memory into read-only mode
SysHost::VirtualProtect( bytes, allocSize, VProtect::Read );
SysHost::VirtualProtect( _buffer.get(), allocSize, VProtect::Read );

// Save data, good to go
_plotPath = path;
Expand All @@ -1127,13 +1094,13 @@ bool MemoryPlot::Open( const char* path )
//-----------------------------------------------------------
bool MemoryPlot::IsOpen() const
{
return _bytes.values != nullptr;
return _buffer.get();
}

//-----------------------------------------------------------
size_t MemoryPlot::PlotSize() const
{
return _bytes.length;
return _size;
}

//-----------------------------------------------------------
Expand All @@ -1152,15 +1119,15 @@ bool MemoryPlot::Seek( SeekOrigin origin, int64 offset )
break;

case SeekOrigin::End:
absPosition = (ssize_t)_bytes.length + offset;
absPosition = (ssize_t)_size + offset;
break;

default:
_err = -1; // #TODO: Set proper user error.
return false;
}

if( absPosition < 0 || absPosition > (ssize_t)_bytes.length )
if( absPosition < 0 || absPosition > (ssize_t)_size )
{
_err = -1; // #TODO: Set proper user error.
return false;
Expand All @@ -1180,13 +1147,13 @@ ssize_t MemoryPlot::Read( size_t size, void* buffer )

const size_t endPos = (size_t)_position + size;

if( endPos > _bytes.length )
if( endPos > _size )
{
_err = -1; // #TODO: Set proper user error
return false;
}

memcpy( buffer, _bytes.values + _position, size );
memcpy( buffer, _buffer.get() + _position, size );
_position = (ssize_t)endPos;

return (ssize_t)size;
Expand Down
12 changes: 8 additions & 4 deletions src/tools/PlotReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "io/FileStream.h"
#include "util/Util.h"
#include <vector>
#include <memory>

class CPBitReader;

Expand Down Expand Up @@ -119,9 +120,10 @@ class IPlotFile
class MemoryPlot : public IPlotFile
{
public:
MemoryPlot();
MemoryPlot( const MemoryPlot& plotFile );
~MemoryPlot();
MemoryPlot() = default;
MemoryPlot( const MemoryPlot& plotFile ) = default;
MemoryPlot( MemoryPlot&& plotFile ) = default;
~MemoryPlot() = default;

bool Open( const char* path ) override;
bool IsOpen() const override;
Expand All @@ -135,7 +137,9 @@ class MemoryPlot : public IPlotFile
int GetError() override;

private:
Span<byte> _bytes; // Plot bytes

std::shared_ptr<byte[]> _buffer;
size_t _size = 0;
int _err = 0;
ssize_t _position = 0;
std::string _plotPath = "";
Expand Down

0 comments on commit 03cc2e3

Please sign in to comment.