diff --git a/filters.py b/filters.py index ad246db..60a8145 100644 --- a/filters.py +++ b/filters.py @@ -349,20 +349,20 @@ def computePixel(self, sub): return statistics.median(trimmedSub) class AdaptiveWeightedMedian(SpatialFilter): - def __init__(self, maskSize): + def __init__(self, maskSize, constant, centralWeight): # Create kernel with weights representing distance from centre using equivalent of pythagoras ax = np.linspace(-(maskSize - 1) / 2., (maskSize - 1) / 2., maskSize) xx, yy = np.meshgrid(ax, ax) kernel = np.sqrt(np.square(xx) + np.square(yy)) + # set max weight, used for centre of kernel, and constant used in formula + self.constant = constant + self.centralWeight = centralWeight + super().__init__(maskSize, kernel, name='adaptive-weighted-median', linearity='non-linear') def computePixel(self, sub): - # set max weight, used for centre of kernel, and constant - centralWeight = 100 - k = 10 #TODO: make configurable - # calculate the standard deviation and mean of sub matrix std = np.std(sub) mean = np.mean(sub) @@ -373,18 +373,24 @@ def computePixel(self, sub): pass # create matrix of weights based on sub matrix, using formula for adaptive weighted median filter - # truncate negative weights to zero to ensure low pass characteristics - weights = centralWeight - np.divide(k*std*self.kernel, mean) + weights = self.centralWeight - self.constant*std*np.divide(self.kernel, mean) + + # Identify any negative weights in boolean array + mask = weights < 0 + # Use as inverse mask truncate negative weights to zero to ensure low pass characteristics + weights = np.multiply(np.invert(mask), weights) # use list comprehension to pair each element from sub matrix with respective weighting in tuple - # and sort based on sub matrix elements/ pixel intensities - b = sorted((elementSub, elementKernel) for elementSub, elementKernel in zip(sub.flatten(), weights.flatten())) + # and sort based on sub matrix values/ pixel intensities + pairings = sorted((pixelIntensity, weight) for pixelIntensity, weight in zip(sub.flatten(), weights.flatten())) - # multiply weights with sub matrix pixel intensity values - combinedElements = [elementSub*elementKernel for elementSub, elementKernel in b] + # calculate where median position will be + medIndex = ceil((np.sum(weights) + 1)/ 2) + cs = np.cumsum([pair[1] for pair in pairings]) + medPairIndex = np.searchsorted(cs, medIndex) # return median of list of weighted sub matrix values - return statistics.median(combinedElements) + return pairings[medPairIndex][0] class Mean(SpatialFilter): """ @@ -440,8 +446,7 @@ def __init__(self, maskSize): kernel = np.full((maskSize, maskSize), -1/(maskSize**2)) middle = int((maskSize-1)/2) kernel[middle, middle] = 1 - 1/(maskSize**2) - #TODO: Check for high and low pass filter if they are non-linear or linear - super().__init__(maskSize, kernel, name='high-pass', linearity='non-linear') + super().__init__(maskSize, kernel, name='high-pass', linearity='linear') def computePixel(self, sub): try: @@ -452,13 +457,13 @@ def computePixel(self, sub): return (self.kernel * sub).sum() class LowPass(SpatialFilter): - def __init__(self, maskSize): + def __init__(self, maskSize, middleWeight=1/2, otherWeights=1/8): kernel = np.zeros((maskSize, maskSize)) middle = int((maskSize-1)/2) - kernel[middle, :] = 1/8 - kernel[:, middle] = 1/8 - kernel[middle, middle] = 1/2 + kernel[middle, :] = otherWeights + kernel[:, middle] = otherWeights + kernel[middle, middle] = middleWeight super().__init__(maskSize, kernel, name='low-pass', linearity='non-linear')