Skip to content

Commit

Permalink
Corrected implementation of adaptive weighted median filter.
Browse files Browse the repository at this point in the history
  • Loading branch information
tbsfchnr committed Nov 17, 2020
1 parent bea3f01 commit b54883d
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,20 +349,20 @@ def computePixel(self, sub):
return statistics.median(trimmedSub)

class AdaptiveWeightedMedian(SpatialFilter):
def __init__(self, maskSize):
def __init__(self, maskSize, constant, centralWeight):

# Create kernel with weights representing distance from centre using equivalent of pythagoras
ax = np.linspace(-(maskSize - 1) / 2., (maskSize - 1) / 2., maskSize)
xx, yy = np.meshgrid(ax, ax)
kernel = np.sqrt(np.square(xx) + np.square(yy))

# set max weight, used for centre of kernel, and constant used in formula
self.constant = constant
self.centralWeight = centralWeight

super().__init__(maskSize, kernel, name='adaptive-weighted-median', linearity='non-linear')

def computePixel(self, sub):
# set max weight, used for centre of kernel, and constant
centralWeight = 100
k = 10 #TODO: make configurable

# calculate the standard deviation and mean of sub matrix
std = np.std(sub)
mean = np.mean(sub)
Expand All @@ -373,18 +373,24 @@ def computePixel(self, sub):
pass

# create matrix of weights based on sub matrix, using formula for adaptive weighted median filter
# truncate negative weights to zero to ensure low pass characteristics
weights = centralWeight - np.divide(k*std*self.kernel, mean)
weights = self.centralWeight - self.constant*std*np.divide(self.kernel, mean)

# Identify any negative weights in boolean array
mask = weights < 0
# Use as inverse mask truncate negative weights to zero to ensure low pass characteristics
weights = np.multiply(np.invert(mask), weights)

# use list comprehension to pair each element from sub matrix with respective weighting in tuple
# and sort based on sub matrix elements/ pixel intensities
b = sorted((elementSub, elementKernel) for elementSub, elementKernel in zip(sub.flatten(), weights.flatten()))
# and sort based on sub matrix values/ pixel intensities
pairings = sorted((pixelIntensity, weight) for pixelIntensity, weight in zip(sub.flatten(), weights.flatten()))

# multiply weights with sub matrix pixel intensity values
combinedElements = [elementSub*elementKernel for elementSub, elementKernel in b]
# calculate where median position will be
medIndex = ceil((np.sum(weights) + 1)/ 2)
cs = np.cumsum([pair[1] for pair in pairings])
medPairIndex = np.searchsorted(cs, medIndex)

# return median of list of weighted sub matrix values
return statistics.median(combinedElements)
return pairings[medPairIndex][0]

class Mean(SpatialFilter):
"""
Expand Down Expand Up @@ -440,8 +446,7 @@ def __init__(self, maskSize):
kernel = np.full((maskSize, maskSize), -1/(maskSize**2))
middle = int((maskSize-1)/2)
kernel[middle, middle] = 1 - 1/(maskSize**2)
#TODO: Check for high and low pass filter if they are non-linear or linear
super().__init__(maskSize, kernel, name='high-pass', linearity='non-linear')
super().__init__(maskSize, kernel, name='high-pass', linearity='linear')

def computePixel(self, sub):
try:
Expand All @@ -452,13 +457,13 @@ def computePixel(self, sub):
return (self.kernel * sub).sum()

class LowPass(SpatialFilter):
def __init__(self, maskSize):
def __init__(self, maskSize, middleWeight=1/2, otherWeights=1/8):

kernel = np.zeros((maskSize, maskSize))
middle = int((maskSize-1)/2)
kernel[middle, :] = 1/8
kernel[:, middle] = 1/8
kernel[middle, middle] = 1/2
kernel[middle, :] = otherWeights
kernel[:, middle] = otherWeights
kernel[middle, middle] = middleWeight

super().__init__(maskSize, kernel, name='low-pass', linearity='non-linear')

Expand Down

0 comments on commit b54883d

Please sign in to comment.