diff --git a/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt b/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt index fb03050..eaf237f 100644 --- a/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt +++ b/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt @@ -31,7 +31,7 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont } @ReactMethod - override fun removeBackground(imageURI: String, promise: Promise) { + override fun removeBackground(imageURI: String, redValue: Int, greenValue: Int, blueValue: Int, promise: Promise) { val segmenter = this.segmenter ?: createSegmenter() val image = getImageBitmap(imageURI) @@ -45,14 +45,19 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont for (y in 0 until result.height) { for (x in 0 until result.width) { - val alpha = maskBuffer.getFloat().pow(4) - mask.setPixel(x, y, Color.argb((alpha * 255).toInt(), 0, 0, 0)) + val alpha = maskBuffer.getFloat() + if (alpha < 0.5f) { + image.setPixel(x, y, Color.rgb(redValue, greenValue, blueValue)) + } + else { + mask.setPixel(x, y, Color.argb((alpha * 255).toInt(), 255, 255, 255)) + } } } val paint = Paint(Paint.ANTI_ALIAS_FLAG) paint.setXfermode(PorterDuffXfermode(PorterDuff.Mode.DST_IN)) - val canvas = Canvas(image) + val canvas = Canvas() canvas.drawBitmap(mask, 0f, 0f, paint) val fileName = URI(imageURI).path.split("/").last() diff --git a/android/src/oldarch/BackgroundRemoverSpec.kt b/android/src/oldarch/BackgroundRemoverSpec.kt index b22d057..c50f011 100644 --- a/android/src/oldarch/BackgroundRemoverSpec.kt +++ b/android/src/oldarch/BackgroundRemoverSpec.kt @@ -7,5 +7,5 @@ import com.facebook.react.bridge.Promise abstract class BackgroundRemoverSpec internal constructor(context: ReactApplicationContext) : ReactContextBaseJavaModule(context) { - abstract fun removeBackground(imageURI: String, promise: Promise) + abstract fun removeBackground(imageURI: String, redValue: Int, greenValue: Int, blueValue: Int, promise: Promise) } diff --git a/ios/ReactNativeBackgroundRemover.mm b/ios/ReactNativeBackgroundRemover.mm index 1717f93..2f15685 100644 --- a/ios/ReactNativeBackgroundRemover.mm +++ b/ios/ReactNativeBackgroundRemover.mm @@ -1,5 +1,5 @@ #import "ReactNativeBackgroundRemover.h" -#import "ReactNativeBackgroundRemover-Swift.h" +#import "ReactNativeBackgroundRemover/ReactNativeBackgroundRemover-Swift.h" @implementation BackgroundRemover { BackgroundRemoverSwift *backgroundRemover; @@ -18,10 +18,18 @@ - (id)init { } RCT_EXPORT_METHOD(removeBackground:(NSString *)imageURI - resolve:(RCTPromiseResolveBlock)resolve + redValue:(NSInteger)redValue + greenValue:(NSInteger)greenValue + blueValue:(NSInteger)blueValue + resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject) { - [backgroundRemover removeBackground:imageURI resolve:resolve reject:reject]; + [backgroundRemover removeBackground:imageURI + redValue:redValue + greenValue:greenValue + blueValue:blueValue + resolve:resolve + reject:reject]; } // Don't compile this code when we build for the old architecture. diff --git a/ios/ReactNativeBackgroundRemover.swift b/ios/ReactNativeBackgroundRemover.swift index 8a88f7d..b8f78c3 100644 --- a/ios/ReactNativeBackgroundRemover.swift +++ b/ios/ReactNativeBackgroundRemover.swift @@ -4,7 +4,7 @@ import CoreImage public class BackgroundRemoverSwift: NSObject { @objc - public func removeBackground(_ imageURI: String, resolve: @escaping RCTPromiseResolveBlock, reject: @escaping RCTPromiseRejectBlock)->Void { + public func removeBackground(_ imageURI: String, redValue: Int, greenValue: Int, blueValue: Int, resolve: @escaping RCTPromiseResolveBlock, reject: @escaping RCTPromiseRejectBlock)->Void { #if targetEnvironment(simulator) reject("BackgroundRemover", "SimulatorError", NSError(domain: "BackgroundRemover", code: 2)) return @@ -31,11 +31,18 @@ public class BackgroundRemoverSwift: NSObject { let scaleY = originalImage.extent.height / maskImage.extent.height maskImage = maskImage.transformed(by: .init(scaleX: scaleX, y: scaleY)) - let maskedImage = originalImage.applyingFilter("CIBlendWithMask", parameters: [kCIInputMaskImageKey: maskImage]) + // Create a solid color background image + let backgroundColor = CIColor(red: CGFloat(redValue) / 255.0, green: CGFloat(greenValue) / 255.0, blue: CGFloat(blueValue) / 255.0) + let backgroundImage = CIImage(color: backgroundColor).cropped(to: originalImage.extent) + + let maskedImage = originalImage.applyingFilter("CIBlendWithMask", parameters: [kCIInputImageKey: originalImage,kCIInputMaskImageKey: maskImage]) + + // Combine the masked image with the background + let finalImage = maskedImage.composited(over: backgroundImage) // Save the masked image to a temporary file let tempURL = URL(fileURLWithPath: NSTemporaryDirectory()).appendingPathComponent(url.lastPathComponent) - let uiImage = UIImage(ciImage: maskedImage) + let uiImage = UIImage(ciImage: finalImage) if let data = uiImage.pngData() { try data.write(to: tempURL) resolve(tempURL.absoluteString) diff --git a/src/NativeBackgroundRemover.ts b/src/NativeBackgroundRemover.ts index 1865738..885eabc 100644 --- a/src/NativeBackgroundRemover.ts +++ b/src/NativeBackgroundRemover.ts @@ -1,8 +1,9 @@ import type { TurboModule } from 'react-native'; import { TurboModuleRegistry } from 'react-native'; +import { Int32 } from 'react-native/Libraries/Types/CodegenTypes'; export interface Spec extends TurboModule { - removeBackground(imageURI: string): Promise; + removeBackground(imageURI: string, redValue: Int32, greenValue: Int32, blueValue: Int32): Promise; } export default TurboModuleRegistry.getEnforcing('BackgroundRemover'); diff --git a/src/index.tsx b/src/index.tsx index 42b251f..4327053 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,4 +1,5 @@ import { NativeModules, Platform } from 'react-native'; +import { Int32 } from 'react-native/Libraries/Types/CodegenTypes'; const LINKING_ERROR = `The package 'react-native-background-remover' doesn't seem to be linked. Make sure: \n\n` + @@ -32,9 +33,9 @@ const BackgroundRemover = BackgroundRemoverModule * @throws Error if the iOS device is not at least on iOS 15.0. * @throws Error if the image could not be processed for an unknown reason. */ -export async function removeBackground(imageURI: string): Promise { +export async function removeBackground(imageURI: string, redValue: Int32, greenValue: Int32, blueValue: Int32): Promise { try { - const result: string = await BackgroundRemover.removeBackground(imageURI); + const result: string = await BackgroundRemover.removeBackground(imageURI, redValue, greenValue, blueValue); return result; } catch (error) { if (error instanceof Error && error.message === 'SimulatorError') {