Skip to content

Commit

Permalink
Refactored run.py and display method
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanin69 committed Jun 23, 2024
1 parent 1d13754 commit b31ad0f
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 288 deletions.
83 changes: 79 additions & 4 deletions AgriFieldGenerator/processing/data_processor_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import os

import matplotlib.figure
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pickle
from PIL import Image
from shapely.geometry import MultiPolygon, Polygon

# Increase the maximum image pixels limit beacause Enfusion image files are very large
Image.MAX_IMAGE_PIXELS = 400000000
Expand All @@ -27,7 +29,7 @@ class DataProcessorBaseClass:
Base class for data processing tasks. Provides methods for loading and saving data,
and a method for processing data that should be implemented by subclasses.
"""
def __init__(self, source_path, save_path, save_data_path):
def __init__(self, source_path, save_path, save_data_path, svg_path, svg_height, svg_width):
"""
Initializes a new instance of the class. Sets the source, save, and save data directories.
Expand All @@ -38,6 +40,12 @@ def __init__(self, source_path, save_path, save_data_path):
self.source_directory = source_path
self.save_directory = save_path
self.save_data_directory = save_data_path
self.svg_path = svg_path
self.svg_height = svg_height
self.svg_width = svg_width
self.points = None
self.polygon = None
self.polygons = None

def process(self):
"""
Expand Down Expand Up @@ -88,8 +96,75 @@ def save(self, result, filename, data_file=False, dpi=100):
else:
raise TypeError(f"Unable to save object of type {type(result)}")

def display(self):
def display(self, file_to_display):
"""
This method should be implemented by subclasses. It should contain the logic for processing the data.
"""
raise NotImplementedError("Subclasses should implement this!")
try:
self.polygon = self.load('polygon.pkl', data_file=True)
except FileNotFoundError:
raise FileNotFoundError("Polygon data are missing. Please run the relevant generator(s) first.")

if file_to_display == 'main_polygon':
self._plot()
return

if file_to_display == 'seed_points':
try:
self.points = self.load('points.pkl', data_file=True)
except FileNotFoundError:
raise FileNotFoundError("Main polygon or seed points data are missing. Please run the relevant generator(s) first.")
self._plot(points=True, bounding_box=True)
return

if file_to_display == 'voronoi':
try:
self.polygons = self.load('voronoi.pkl', data_file=True)
except FileNotFoundError:
raise FileNotFoundError("Voronoi diagram data are missing. Please run the relevant generator(s) first.")
self._plot(polygons=True)
return

def _plot(self, points=False, bounding_box=False, polygons=False):

# Create a new figure and axes
fig, ax = plt.subplots()

# Set the limits of the axes to the SVG dimensions
ax.set_xlim(0, self.svg_width)
ax.set_ylim(0, self.svg_height)

# Display polygon
if isinstance(self.polygon, Polygon):
x, y = self.polygon.exterior.xy
ax.fill(x, y, alpha=0.5, fc='r', ec='none')
elif isinstance(self.polygon, MultiPolygon):
for poly in self.polygon.geoms:
x, y = poly.exterior.xy
ax.plot(x, y, color='r')

if bounding_box:
# Display bounding box
minx, miny, maxx, maxy = self.polygon.bounds
rect = patches.Rectangle((minx, miny), maxx-minx, maxy-miny, linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)

if points:
# Display points
for point in self.points:
ax.plot(*point, 'ko', markersize=1)

if polygons:
for polygon in self.polygons:
if isinstance(polygon, Polygon):
x, y = polygon.exterior.xy
ax.plot(x, y, color='b')
elif isinstance(polygon, MultiPolygon):
for poly in polygon.geoms:
x, y = poly.exterior.xy
ax.plot(x, y, color='b')

plt.show()


10 changes: 4 additions & 6 deletions AgriFieldGenerator/processing/mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import cv2
import numpy as np
from PIL import Image
from skimage import measure
from tqdm import tqdm

from .data_processor_base_class import DataProcessorBaseClass
Expand Down Expand Up @@ -59,6 +58,7 @@ def __init__(self,
source_path,
save_path,
save_data_path,
svg_path,
svg_height,
svg_width,
palette,
Expand Down Expand Up @@ -89,12 +89,10 @@ def __init__(self,
max_border_width : float
The maximum border width.
"""
super().__init__(source_path=source_path, save_path=save_path, save_data_path=save_data_path)
super().__init__(source_path=source_path, save_path=save_path, save_data_path=save_data_path, svg_path=svg_path, svg_height=svg_height, svg_width=svg_width)
self.source_path = source_path
self.save_path = save_path
self.save_data_path = save_data_path
self.svg_height = svg_height
self.svg_width = svg_width
self.palette = palette
self.enfusion_texture_masks = enfusion_texture_masks
self.min_border_width = min_border_width
Expand All @@ -117,7 +115,7 @@ def process(self):
# For each color in the palette
description = "Generating masks"
description += " " * (26 - len(description))
for i, color in tqdm(enumerate(palette), desc=description, total=len(palette)):
for i, color in tqdm(enumerate(palette), desc=description, total=len(palette), unit=" step"):
# Remove the '#' from the color string
color = color.lstrip('#')

Expand Down Expand Up @@ -153,7 +151,7 @@ def merge_masks(self, reset=True):
# For each color in the palette
description = "Merging masks"
description += " " * (26 - len(description))
for i, color in tqdm(enumerate(self.palette), desc=description, total=len(self.palette)):
for i, color in tqdm(enumerate(self.palette), desc=description, total=len(self.palette), unit=" step"):
# Remove the '#' from the color string
color = color.lstrip('#')

Expand Down
55 changes: 8 additions & 47 deletions AgriFieldGenerator/processing/points_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@

import os

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from shapely.geometry import Point, Polygon, MultiPolygon
from shapely.geometry import Point
from tqdm import tqdm

from .data_processor_base_class import DataProcessorBaseClass
Expand Down Expand Up @@ -86,6 +84,7 @@ def __init__(self,
source_path,
save_path,
save_data_path,
svg_path,
svg_height,
svg_width,
num_points,
Expand All @@ -102,13 +101,11 @@ def __init__(self,
min_height,
max_height):

super().__init__(source_path=source_path, save_path=save_path, save_data_path=save_data_path)
super().__init__(source_path=source_path, save_path=save_path, save_data_path=save_data_path, svg_path=svg_path, svg_height=svg_height, svg_width=svg_width)

self.source_path = source_path
self.save_path = save_path
self.save_data_path = save_data_path
self.svg_height = svg_height
self.svg_width = svg_width
self.num_points = num_points
self.nx = nx
self.ny = ny
Expand Down Expand Up @@ -145,9 +142,9 @@ def random_generator(self):

minx, miny, maxx, maxy = self.polygon.bounds
# Create a tqdm object with a total
description = "Generating points"
description = "Generating seed points"
description += " " * (26 - len(description))
pbar = tqdm(total=self.num_points, desc=description, unit="points")
pbar = tqdm(total=self.num_points, desc=description, unit=" seed point(s)")
minx, miny, maxx, maxy = self.polygon.bounds
self.points = []
while len(self.points) < self.num_points:
Expand Down Expand Up @@ -176,7 +173,7 @@ def grid_generator(self):
# Create a tqdm object with a total
description = "Generating points"
description += " " * (26 - len(description))
pbar = tqdm(total=5, desc=description, unit="step")
pbar = tqdm(total=5, desc=description, unit=" seed point(s)")
minx, miny, maxx, maxy = self.polygon.bounds

# Define the number of points in the x and y directions
Expand Down Expand Up @@ -229,7 +226,7 @@ def rectangle_generator(self):
description = "Generating points"
description += " " * (26 - len(description))

for _ in tqdm(range(self.num_rectangles), desc=description, unit="rectangle"):
for _ in tqdm(range(self.num_rectangles), desc=description, unit=" seed point(s)"):
# Choose a random location for the bottom left corner of the rectangle
x0 = np.random.uniform(minx, maxx)
y0 = np.random.uniform(miny, maxy)
Expand Down Expand Up @@ -261,7 +258,7 @@ def rectangle_tiling_generator(self):
x0 = minx
y0 = miny

for _ in tqdm(range(self.num_rectangles), desc=description, unit="rectangle"):
for _ in tqdm(range(self.num_rectangles), desc=description, unit=" seed point(s)"):
# Choose a random width and height for the rectangle
width = np.random.uniform(self.min_width, self.max_width)
height = np.random.uniform(self.min_height, self.max_height)
Expand All @@ -285,42 +282,6 @@ def rectangle_tiling_generator(self):
self.save(self.points, 'points.pkl', data_file=True)
return self.points

def display(self):
"""
Displays the polygon and the generated points. If the points have not been generated yet, an error message is printed.
"""

# Create a new figure and axes
fig, ax = plt.subplots()

# Set the limits of the axes to the SVG dimensions
ax.set_xlim(0, self.svg_width)
ax.set_ylim(0, self.svg_height)

# Display polygon
if isinstance(self.polygon, Polygon):
x, y = self.polygon.exterior.xy
ax.fill(x, y, alpha=0.5, fc='r', ec='none')
elif isinstance(self.polygon, MultiPolygon):
for poly in self.polygon.geoms:
x, y = poly.exterior.xy
ax.plot(x, y, color='r')

# Display bounding box
minx, miny, maxx, maxy = self.polygon.bounds
rect = patches.Rectangle((minx, miny), maxx-minx, maxy-miny, linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)

# Display points
if self.points is None:
print("Error: self.points is None. You need to generate points first by calling one of the *_generator() method.")
return

for point in self.points:
ax.plot(*point, 'ko', markersize=1)

plt.show()

def _clean_data(self):
"""
Deletes 'voronoi.pkl' and 'colored.pkl' files if they exist. This method is called before generating a new set of points.
Expand Down
2 changes: 1 addition & 1 deletion AgriFieldGenerator/processing/polyline_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def generate_polylines(self):
with open(self.save_path + 'polylines_colored.layer', 'w') as f:
f.write('\n'.join(polylines))

print(f"Generated {len(polylines)} Enfusion polylines for colored polygons (see 'polylines_colored.layer' in the saves directory).")
print(f"Generated {len(polylines)} Enfusion polylines for colored polygons.\n(see {self.save_path} polylines_colored.layer file).")

return polylines

Expand Down
16 changes: 7 additions & 9 deletions AgriFieldGenerator/processing/spline_to_svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@
from .data_processor_base_class import DataProcessorBaseClass

class SplineToSVG(DataProcessorBaseClass):
def __init__(self, svg_file_name, source_path, save_path, save_data_path, svg_height, svg_width, surface_map_resolution, enfusion_spline_layer_file):
super().__init__(source_path=source_path, save_path=save_path, save_data_path=save_data_path)
self.source_path = source_path
self.svg_file_name = source_path + svg_file_name
self.source_layer = self.source_path + enfusion_spline_layer_file
def __init__(self, source_path, save_path, save_data_path, svg_path, svg_height, svg_width, surface_map_resolution, enfusion_spline_layer_file):
super().__init__(source_path=source_path, save_path=save_path, save_data_path=save_data_path, svg_path=svg_path, svg_height=svg_height, svg_width=svg_width)
self.svg_path = svg_path
self.source_layer = source_path + enfusion_spline_layer_file
self.save_path = save_path
self.svg_height = svg_height
self.svg_width = svg_width
self.surface_map_resolution = surface_map_resolution

def parse_spline_file(self):
Expand Down Expand Up @@ -74,7 +71,7 @@ def parse_spline_file(self):

def hermite_to_bezier(self, splines):

dwg = svgwrite.Drawing(self.svg_file_name , profile='tiny', size=(self.svg_height, self.svg_width))
dwg = svgwrite.Drawing(self.svg_path , profile='tiny', size=(self.svg_height, self.svg_width))
for spline in splines:
points = spline['points'][1:] # Exclude only the first point
# Because we use this svg to generate surface masks, we have to divide the spline coordinates by the surface map resolution
Expand All @@ -97,6 +94,7 @@ def hermite_to_bezier(self, splines):
dwg.add(path)

dwg.save()
return self.svg_file_name
print(f"SVG file saved to {self.svg_path}")
return self.svg_path


Loading

0 comments on commit b31ad0f

Please sign in to comment.