Skip to content

Commit

Permalink
augmentation constructor takes str or np.random.Generator
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonDill committed Nov 12, 2024
1 parent 367e6fb commit af5193a
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions ctc_otfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ctc_utils
import os
import json
from typing import Union
from abc import ABC, abstractmethod
import ctc_otfa
import tensorflow as tf
Expand Down Expand Up @@ -103,18 +104,31 @@ def apply_augmentations(image, augmentations):


class augmentation(ABC):
def __init__(self, type, distribution, variance):
def __init__(self, type:str, distribution: Union[str, np.random.Generator], variance:float):
self.type = str(type)

if isinstance(distribution, np.random.Generator):
self.distribution = distribution
else:
raise Exception('Invalid distribution type \"' + str(distribution) + '\"')
self.variance = float(variance)

if isinstance(distribution, str):
self.distribution = self.str_to_distribution(distribution=distribution) #convert to np.random.Generator
else:
self.distribution = distribution

def __str__(self) -> str:
return 'AUGMENTATION: ' +self.type + ', ' + str(self.variance) + ', ' + str(self.distribution)

def str_to_distribution(self, distribution: str) -> np.random.Generator:
distribution.lower()
if distribution == "gaussian" or distribution == "normal":
return np.random.normal
elif distribution == "uniform":
return np.random.uniform
elif distribution == "random":
return np.random.random
elif distribution == "binomial":
return np.random.binomial
else:
raise Exception('Invalid distribution type \"' + str(distribution) + '\"')

# This function should take a tensor and return a tensor
@abstractmethod
def augment(self):
Expand Down

0 comments on commit af5193a

Please sign in to comment.