Add LaMa Core ML model for AI-powered inpainting

- Add LaMaFP16_512.mlpackage (~90MB) for high-quality object removal
- Add LaMaInpainter.swift wrapper with image preprocessing and merging
- Modify InpaintEngine to use LaMa first, gradient fill as fallback
- Fix brush mask size (use scale 1.0 instead of screen scale)
- Fix LaMa output size (use scale 1.0 in merge function)
- Add model loading wait with 5 second timeout

The LaMa model provides significantly better inpainting quality compared
to the gradient fill method, especially for complex backgrounds.
2026-01-24 14:31:54 -05:00
parent 14e2502cf4
commit eb047e27b8
7 changed files with 132598 additions and 3 deletions

.gitignore

@@ -54,3 +54,4 @@ logs/
 # Temporary files
 *.tmp
 *.temp
+inpaint-ios-reference/

BrushCanvasView.swift

@@ -205,8 +205,10 @@ struct BrushCanvasView: View {
             }
         }
-        // Create mask image from strokes
-        let renderer = UIGraphicsImageRenderer(size: imageSize)
+        // Create mask image from strokes (use scale 1.0 to match actual image pixels)
+        let format = UIGraphicsImageRendererFormat()
+        format.scale = 1.0 // Don't use screen scale, use actual pixel size
+        let renderer = UIGraphicsImageRenderer(size: imageSize, format: format)
         let maskImage = renderer.image { ctx in
             // Fill with black (not masked)
             UIColor.black.setFill()
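
For context on this fix: UIGraphicsImageRenderer's default format inherits the main screen's scale, so on a 3x device a renderer created for the image's point size produces a bitmap three times larger in each dimension than the photo's actual pixel buffer, and the brush mask no longer lines up with the image. A minimal sketch of the difference, separate from the diff above (the sizes are illustrative):

    import UIKit

    let size = CGSize(width: 1000, height: 1000)

    // Default format: scale follows the main screen (2x/3x on modern iPhones),
    // so the backing CGImage is 2000-3000 px wide, not 1000.
    let scaled = UIGraphicsImageRenderer(size: size).image { _ in }

    // Explicit scale 1.0: one point == one pixel, matching the photo's buffer.
    let format = UIGraphicsImageRendererFormat()
    format.scale = 1.0
    let exact = UIGraphicsImageRenderer(size: size, format: format).image { _ in }
    // exact.cgImage?.width == 1000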

InpaintEngine.swift

@@ -74,8 +74,19 @@ actor InpaintEngine {
             throw InpaintError.memoryPressure
         }
+        // Try LaMa AI inpainting first for best quality
+        do {
+            DebugLogger.log("Attempting LaMa AI inpainting...")
+            let lamaResult = try await lamaInpainter.inpaint(image: image, mask: mask)
+            DebugLogger.imageInfo("Inpaint result (LaMa)", image: lamaResult)
+            return lamaResult
+        } catch {
+            DebugLogger.log("LaMa failed: \(error.localizedDescription), falling back to gradient fill")
+        }
+        // Fallback to gradient-based Metal inpainting
         if isMetalAvailable {
-            DebugLogger.log("Using Metal for inpainting")
+            DebugLogger.log("Using Metal gradient fill for inpainting")
             return try await inpaintWithMetal(image: image, mask: mask, isPreview: false)
         } else {
             DebugLogger.log("Using Accelerate fallback for inpainting")
@@ -83,6 +94,9 @@ actor InpaintEngine {
         }
     }
+    /// LaMa inpainter for AI-powered inpainting (initialized lazily on first use)
+    private nonisolated(unsafe) lazy var lamaInpainter: LaMaInpainter = LaMaInpainter()
+
     func inpaintPreview(image: CGImage, mask: CGImage) async throws -> CGImage {
         // Scale down for preview if needed
         let scaledImage: CGImage
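
The hunk above gives the engine a three-tier strategy: LaMa when the model is available, the Metal gradient fill when it is not, and the Accelerate path when Metal is missing too. A hypothetical call site (the InpaintEngine actor is real; the inpaint(image:mask:) method name and the sourceImage/maskImage values are assumed for illustration, since the hunk only shows the method body):

    let engine = InpaintEngine()
    Task {
        do {
            // Tries LaMa first; the engine silently falls back to the
            // gradient fill if the model fails to load or inference throws.
            let result = try await engine.inpaint(image: sourceImage, mask: maskImage)
            // ... hand `result` back to the UI
        } catch {
            // Reached only when every backend fails.
            print("Inpainting failed: \(error)")
        }
    }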

LaMaFP16_512.mlpackage/Manifest.json

@@ -0,0 +1,18 @@
{
  "fileFormatVersion": "1.0.0",
  "itemInfoEntries": {
    "834F1D4D-C413-4927-9314-FF1187E2F6B4": {
      "author": "com.apple.CoreML",
      "description": "CoreML Model Weights",
      "name": "weights",
      "path": "com.apple.CoreML/weights"
    },
    "DFC457AE-DC53-4BC5-B66A-1A6B1CB59064": {
      "author": "com.apple.CoreML",
      "description": "CoreML Model Specification",
      "name": "model.mlmodel",
      "path": "com.apple.CoreML/model.mlmodel"
    }
  },
  "rootModelIdentifier": "DFC457AE-DC53-4BC5-B66A-1A6B1CB59064"
}
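
This is the standard .mlpackage manifest: each itemInfoEntries UUID maps to a file inside the package (the protobuf model specification and the weight blobs), and rootModelIdentifier selects the spec entry. Xcode compiles the package into an .mlmodelc bundle and generates the LaMaFP16_512 class used in the next file. As a sketch, the same compiled model could also be loaded through the untyped MLModel API (the resource name is assumed from Xcode's default compilation of LaMaFP16_512.mlpackage):

    import CoreML

    let config = MLModelConfiguration()
    config.computeUnits = .cpuAndGPU
    // Compiled model bundle produced from LaMaFP16_512.mlpackage at build time.
    if let url = Bundle.main.url(forResource: "LaMaFP16_512", withExtension: "mlmodelc"),
       let model = try? MLModel(contentsOf: url, configuration: config) {
        // `model` exposes prediction(from:) with untyped feature providers.
    }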

LaMaInpainter.swift

@@ -0,0 +1,346 @@
//
// LaMaInpainter.swift
// CheapRetouch
//
// LaMa (Large Mask Inpainting) Core ML wrapper for AI-powered object removal.
// Based on https://github.com/wudijimao/Inpaint-iOS
//

import UIKit
import CoreML

/// LaMa-based inpainting using Core ML
/// Provides high-quality AI-powered inpainting for object removal
final class LaMaInpainter {

    /// Fixed input size for the LaMa model
    private let modelSize: Int = 512

    /// The Core ML model instance
    private var model: LaMaFP16_512?

    /// Configuration for the model
    private let config: MLModelConfiguration

    /// Work queue for inference
    private let workQueue = DispatchQueue(label: "com.cheapretouch.lama", qos: .userInitiated)

    /// CI context for image operations
    private let ciContext = CIContext()

    init() {
        config = MLModelConfiguration()
        config.computeUnits = .cpuAndGPU
        preloadModel()
    }

    /// Preload the model in the background
    private func preloadModel() {
        workQueue.async { [weak self] in
            guard let self = self else { return }
            do {
                self.model = try LaMaFP16_512(configuration: self.config)
                DebugLogger.log("LaMa model loaded successfully")
            } catch {
                DebugLogger.error("Failed to load LaMa model: \(error)")
            }
        }
    }

    /// Check if the model is ready
    var isModelReady: Bool {
        return model != nil
    }

    /// Inpaint the masked region of an image using LaMa
    /// - Parameters:
    ///   - image: The source CGImage
    ///   - mask: The mask CGImage (white = areas to inpaint)
    /// - Returns: The inpainted CGImage
    func inpaint(image: CGImage, mask: CGImage) async throws -> CGImage {
        // Wait for model to be ready (with timeout)
        let maxWaitTime: TimeInterval = 5.0
        let startTime = Date()
        while model == nil {
            if -startTime.timeIntervalSinceNow > maxWaitTime {
                throw LaMaError.modelNotLoaded
            }
            try await Task.sleep(nanoseconds: 100_000_000) // 100ms
        }
        guard let model = model else {
            throw LaMaError.modelNotLoaded
        }
        let originalSize = CGSize(width: image.width, height: image.height)
        return try await withCheckedThrowingContinuation { continuation in
            workQueue.async {
                do {
                    let result = try self.performInpainting(
                        model: model,
                        image: image,
                        mask: mask,
                        originalSize: originalSize
                    )
                    continuation.resume(returning: result)
                } catch {
                    continuation.resume(throwing: error)
                }
            }
        }
    }

    /// Perform the actual inpainting operation
    private func performInpainting(
        model: LaMaFP16_512,
        image: CGImage,
        mask: CGImage,
        originalSize: CGSize
    ) throws -> CGImage {
        DebugLogger.processing("LaMa inpainting started")
        let startTime = Date()

        // Find the bounding box of the mask to crop efficiently
        let maskBounds = findMaskBounds(mask: mask)
        DebugLogger.log("Mask bounds: \(maskBounds)")

        // Calculate the crop region (expand to square and add padding)
        let cropRegion = calculateCropRegion(
            maskBounds: maskBounds,
            imageSize: originalSize,
            targetSize: modelSize
        )
        DebugLogger.log("Crop region: \(cropRegion)")

        // Crop the image and mask to the region
        guard let croppedImage = cropImage(image, to: cropRegion),
              let croppedMask = cropImage(mask, to: cropRegion) else {
            throw LaMaError.imageProcessingFailed
        }

        // Resize to model input size
        guard let resizedImage = resizeImage(croppedImage, to: CGSize(width: modelSize, height: modelSize)),
              let resizedMask = resizeImage(croppedMask, to: CGSize(width: modelSize, height: modelSize)) else {
            throw LaMaError.imageProcessingFailed
        }

        // Convert to pixel buffers
        guard let imageBuffer = createPixelBuffer(from: resizedImage, format: kCVPixelFormatType_32ARGB),
              let maskBuffer = createGrayscalePixelBuffer(from: resizedMask) else {
            throw LaMaError.bufferCreationFailed
        }

        // Run inference
        DebugLogger.log("Running LaMa inference...")
        let output = try model.prediction(image: imageBuffer, mask: maskBuffer)

        // Convert output to CGImage
        guard let outputImage = cgImageFromPixelBuffer(output.output) else {
            throw LaMaError.outputConversionFailed
        }

        // Resize output back to crop region size
        guard let resizedOutput = resizeImage(outputImage, to: cropRegion.size) else {
            throw LaMaError.imageProcessingFailed
        }

        // Merge back into original image
        let finalImage = mergeIntoOriginal(
            original: image,
            inpainted: resizedOutput,
            at: cropRegion.origin
        )

        let elapsed = -startTime.timeIntervalSinceNow
        DebugLogger.processing("LaMa inpainting completed in \(String(format: "%.2f", elapsed))s")
        return finalImage
    }

    // MARK: - Helper Methods

    /// Find the bounding box of white pixels in the mask
    private func findMaskBounds(mask: CGImage) -> CGRect {
        let width = mask.width
        let height = mask.height
        guard let data = mask.dataProvider?.data,
              let bytes = CFDataGetBytePtr(data) else {
            return CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))
        }
        var minX = width, minY = height, maxX = 0, maxY = 0
        let bytesPerPixel = mask.bitsPerPixel / 8
        let bytesPerRow = mask.bytesPerRow
        for y in 0..<height {
            for x in 0..<width {
                let offset = y * bytesPerRow + x * bytesPerPixel
                let value = bytes[offset]
                if value > 127 {
                    minX = min(minX, x)
                    minY = min(minY, y)
                    maxX = max(maxX, x)
                    maxY = max(maxY, y)
                }
            }
        }
        if minX > maxX || minY > maxY {
            // No white pixels found, return full image
            return CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))
        }
        return CGRect(
            x: CGFloat(minX),
            y: CGFloat(minY),
            width: CGFloat(maxX - minX + 1),
            height: CGFloat(maxY - minY + 1)
        )
    }

    /// Calculate the crop region for the mask area
    private func calculateCropRegion(maskBounds: CGRect, imageSize: CGSize, targetSize: Int) -> CGRect {
        // Add 20% padding around the mask
        let padding = max(maskBounds.width, maskBounds.height) * 0.2
        var region = maskBounds.insetBy(dx: -padding, dy: -padding)

        // Make it square (use the larger dimension)
        let maxSide = max(region.width, region.height)
        let centerX = region.midX
        let centerY = region.midY
        region = CGRect(
            x: centerX - maxSide / 2,
            y: centerY - maxSide / 2,
            width: maxSide,
            height: maxSide
        )

        // Ensure minimum size matches model size
        if region.width < CGFloat(targetSize) {
            let diff = CGFloat(targetSize) - region.width
            region = region.insetBy(dx: -diff / 2, dy: -diff / 2)
        }

        // Clamp to image bounds
        region.origin.x = max(0, min(region.origin.x, imageSize.width - region.width))
        region.origin.y = max(0, min(region.origin.y, imageSize.height - region.height))

        // Ensure we don't exceed image bounds
        if region.maxX > imageSize.width {
            region.origin.x = imageSize.width - region.width
        }
        if region.maxY > imageSize.height {
            region.origin.y = imageSize.height - region.height
        }

        // Final clamp if region is larger than image
        region.origin.x = max(0, region.origin.x)
        region.origin.y = max(0, region.origin.y)
        region.size.width = min(region.width, imageSize.width)
        region.size.height = min(region.height, imageSize.height)

        return CGRect(
            x: floor(region.origin.x),
            y: floor(region.origin.y),
            width: ceil(region.width),
            height: ceil(region.height)
        )
    }

    /// Crop an image to the specified region
    private func cropImage(_ image: CGImage, to rect: CGRect) -> CGImage? {
        return image.cropping(to: rect)
    }

    /// Resize an image to the specified size
    private func resizeImage(_ image: CGImage, to size: CGSize) -> CGImage? {
        let ciImage = CIImage(cgImage: image)
        let scaleX = size.width / CGFloat(image.width)
        let scaleY = size.height / CGFloat(image.height)
        let scaled = ciImage.transformed(by: CGAffineTransform(scaleX: scaleX, y: scaleY))
        return ciContext.createCGImage(scaled, from: scaled.extent)
    }
    /// Create a color pixel buffer from a CGImage in the given pixel format
    private func createPixelBuffer(from image: CGImage, format: OSType) -> CVPixelBuffer? {
        do {
            let feature = try MLFeatureValue(
                cgImage: image,
                pixelsWide: modelSize,
                pixelsHigh: modelSize,
                pixelFormatType: format
            )
            return feature.imageBufferValue
        } catch {
            DebugLogger.error("Failed to create pixel buffer: \(error)")
            return nil
        }
    }

    /// Create a grayscale pixel buffer from a CGImage
    private func createGrayscalePixelBuffer(from image: CGImage) -> CVPixelBuffer? {
        do {
            let feature = try MLFeatureValue(
                cgImage: image,
                pixelsWide: modelSize,
                pixelsHigh: modelSize,
                pixelFormatType: kCVPixelFormatType_OneComponent8
            )
            return feature.imageBufferValue
        } catch {
            DebugLogger.error("Failed to create grayscale buffer: \(error)")
            return nil
        }
    }

    /// Convert a CVPixelBuffer to CGImage
    private func cgImageFromPixelBuffer(_ buffer: CVPixelBuffer) -> CGImage? {
        let ciImage = CIImage(cvPixelBuffer: buffer)
        return ciContext.createCGImage(ciImage, from: ciImage.extent)
    }

    /// Merge the inpainted region back into the original image
    private func mergeIntoOriginal(original: CGImage, inpainted: CGImage, at position: CGPoint) -> CGImage {
        let size = CGSize(width: original.width, height: original.height)
        let inpaintedSize = CGSize(width: inpainted.width, height: inpainted.height)

        // Use scale 1.0 to match actual pixel size (not screen scale)
        let format = UIGraphicsImageRendererFormat()
        format.scale = 1.0
        let renderer = UIGraphicsImageRenderer(size: size, format: format)
        let resultImage = renderer.image { context in
            // Draw original
            UIImage(cgImage: original).draw(at: .zero)
            // Draw inpainted region on top
            UIImage(cgImage: inpainted).draw(in: CGRect(origin: position, size: inpaintedSize))
        }
        return resultImage.cgImage!
    }
}

// MARK: - Errors

enum LaMaError: Error, LocalizedError {
    case modelNotLoaded
    case imageProcessingFailed
    case bufferCreationFailed
    case outputConversionFailed

    var errorDescription: String? {
        switch self {
        case .modelNotLoaded:
            return "LaMa model is not loaded"
        case .imageProcessingFailed:
            return "Failed to process image for inpainting"
        case .bufferCreationFailed:
            return "Failed to create pixel buffer"
        case .outputConversionFailed:
            return "Failed to convert model output to image"
        }
    }
}
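
A hypothetical end-to-end use of the wrapper above (the image and mask are assumed to be same-sized CGImages, with white mask pixels marking the region to remove):

    let inpainter = LaMaInpainter()

    func removeObject(from image: CGImage, mask: CGImage) async -> CGImage? {
        do {
            // Waits up to 5 s (polling) if the model is still loading, crops to
            // the mask's neighborhood, runs the 512x512 model, and merges the
            // patch back into the full-resolution original.
            return try await inpainter.inpaint(image: image, mask: mask)
        } catch LaMaError.modelNotLoaded {
            print("LaMa model did not load within the timeout")
            return nil
        } catch {
            print("LaMa inpainting failed: \(error)")
            return nil
        }
    }

To make the crop step concrete: for a 60x40-pixel mask bounding box, calculateCropRegion pads by 20% of the longer side (12 px) to 84x64, squares the region to 84x84, then, because that is below the 512-pixel model input, insets it outward to 512x512 before clamping it to the image bounds.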