Add Metal-based inpainting and brush selection tools

Implement GPU-accelerated inpainting using Metal compute shaders:
- Add Shaders.metal with dilateMask, gaussianBlur, diffuseInpaint, edgeAwareBlend kernels
- Add PatchMatchInpainter class for exemplar-based inpainting
- Update InpaintEngine to use Metal with Accelerate fallback
- Add BrushCanvasView for manual brush-based mask painting
- Add LineBrushView for wire removal line drawing
- Update CanvasView to integrate brush canvas overlay

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-23 23:41:43 -05:00
parent eec086e727
commit 48ee7ecd7c
6 changed files with 1114 additions and 7 deletions

View File

@@ -257,9 +257,9 @@
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
@@ -297,9 +297,9 @@
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;

View File

@@ -0,0 +1,331 @@
//
// BrushCanvasView.swift
// CheapRetouch
//
// Canvas overlay for brush-based manual selection.
//
import SwiftUI
import UIKit
struct BrushCanvasView: View {
@Bindable var viewModel: EditorViewModel
let imageSize: CGSize
let displayedImageFrame: CGRect
@State private var currentStroke: [CGPoint] = []
@State private var allStrokes: [[CGPoint]] = []
@State private var isErasing = false
var body: some View {
Canvas { context, size in
// Draw all completed strokes
for stroke in allStrokes {
drawStroke(stroke, in: &context, color: isErasing ? .black : .white)
}
// Draw current stroke
if !currentStroke.isEmpty {
drawStroke(currentStroke, in: &context, color: isErasing ? .black : .white)
}
}
.gesture(drawingGesture)
.overlay(alignment: .bottom) {
brushControls
}
}
private func drawStroke(_ points: [CGPoint], in context: inout GraphicsContext, color: Color) {
guard points.count >= 2 else { return }
var path = Path()
path.move(to: points[0])
if points.count == 2 {
path.addLine(to: points[1])
} else {
for i in 1..<points.count {
let mid = CGPoint(
x: (points[i-1].x + points[i].x) / 2,
y: (points[i-1].y + points[i].y) / 2
)
path.addQuadCurve(to: mid, control: points[i-1])
}
path.addLine(to: points[points.count - 1])
}
context.stroke(
path,
with: .color(color.opacity(0.7)),
style: StrokeStyle(
lineWidth: viewModel.brushSize,
lineCap: .round,
lineJoin: .round
)
)
}
private var drawingGesture: some Gesture {
DragGesture(minimumDistance: 0)
.onChanged { value in
let point = value.location
// Only add points within the image bounds
if displayedImageFrame.contains(point) {
currentStroke.append(point)
}
}
.onEnded { _ in
if !currentStroke.isEmpty {
allStrokes.append(currentStroke)
currentStroke = []
}
}
}
private var brushControls: some View {
HStack(spacing: 16) {
// Erase toggle
Button {
isErasing.toggle()
} label: {
Image(systemName: isErasing ? "eraser.fill" : "eraser")
.font(.title2)
.frame(width: 44, height: 44)
.background(isErasing ? Color.accentColor : Color.clear)
.clipShape(Circle())
}
.accessibilityLabel(isErasing ? "Eraser active" : "Switch to eraser")
// Clear all
Button {
allStrokes.removeAll()
currentStroke.removeAll()
} label: {
Image(systemName: "trash")
.font(.title2)
.frame(width: 44, height: 44)
}
.disabled(allStrokes.isEmpty)
.accessibilityLabel("Clear all strokes")
Spacer()
// Done button
Button {
Task {
await applyBrushMask()
}
} label: {
Text("Done")
.font(.headline)
.padding(.horizontal, 20)
.padding(.vertical, 10)
}
.buttonStyle(.borderedProminent)
.disabled(allStrokes.isEmpty)
}
.padding()
.background(.ultraThinMaterial)
}
private func applyBrushMask() async {
guard !allStrokes.isEmpty else { return }
// Create mask image from strokes
let renderer = UIGraphicsImageRenderer(size: imageSize)
let maskImage = renderer.image { ctx in
// Fill with black (not masked)
UIColor.black.setFill()
ctx.fill(CGRect(origin: .zero, size: imageSize))
// Draw strokes in white (masked areas)
UIColor.white.setStroke()
let scaleX = imageSize.width / displayedImageFrame.width
let scaleY = imageSize.height / displayedImageFrame.height
for stroke in allStrokes {
guard stroke.count >= 2 else { continue }
let path = UIBezierPath()
let firstPoint = CGPoint(
x: (stroke[0].x - displayedImageFrame.minX) * scaleX,
y: (stroke[0].y - displayedImageFrame.minY) * scaleY
)
path.move(to: firstPoint)
for i in 1..<stroke.count {
let point = CGPoint(
x: (stroke[i].x - displayedImageFrame.minX) * scaleX,
y: (stroke[i].y - displayedImageFrame.minY) * scaleY
)
path.addLine(to: point)
}
path.lineWidth = viewModel.brushSize * scaleX
path.lineCapStyle = .round
path.lineJoinStyle = .round
path.stroke()
}
}
if let cgImage = maskImage.cgImage {
await viewModel.applyBrushMask(cgImage)
}
// Clear strokes after applying
allStrokes.removeAll()
}
}
// MARK: - Line Brush View for Wire Tool
struct LineBrushView: View {
@Bindable var viewModel: EditorViewModel
let imageSize: CGSize
let displayedImageFrame: CGRect
@State private var linePoints: [CGPoint] = []
var body: some View {
Canvas { context, size in
guard linePoints.count >= 2 else { return }
var path = Path()
path.move(to: linePoints[0])
for point in linePoints.dropFirst() {
path.addLine(to: point)
}
context.stroke(
path,
with: .color(.white.opacity(0.7)),
style: StrokeStyle(
lineWidth: viewModel.wireWidth,
lineCap: .round,
lineJoin: .round
)
)
}
.gesture(lineDrawingGesture)
.overlay(alignment: .bottom) {
lineControls
}
}
private var lineDrawingGesture: some Gesture {
DragGesture(minimumDistance: 0)
.onChanged { value in
let point = value.location
if displayedImageFrame.contains(point) {
// For line brush, we sample less frequently for smoother lines
if linePoints.isEmpty || distance(from: linePoints.last!, to: point) > 5 {
linePoints.append(point)
}
}
}
.onEnded { _ in
// Line complete, ready to apply
}
}
private var lineControls: some View {
HStack(spacing: 16) {
// Clear
Button {
linePoints.removeAll()
} label: {
Image(systemName: "trash")
.font(.title2)
.frame(width: 44, height: 44)
}
.disabled(linePoints.isEmpty)
.accessibilityLabel("Clear line")
Spacer()
// Cancel
Button {
linePoints.removeAll()
viewModel.selectedTool = .wire
} label: {
Text("Cancel")
.font(.headline)
.padding(.horizontal, 16)
.padding(.vertical, 10)
}
.buttonStyle(.bordered)
// Done button
Button {
Task {
await applyLineMask()
}
} label: {
Text("Remove Line")
.font(.headline)
.padding(.horizontal, 16)
.padding(.vertical, 10)
}
.buttonStyle(.borderedProminent)
.disabled(linePoints.count < 2)
}
.padding()
.background(.ultraThinMaterial)
}
private func applyLineMask() async {
guard linePoints.count >= 2 else { return }
let renderer = UIGraphicsImageRenderer(size: imageSize)
let maskImage = renderer.image { ctx in
UIColor.black.setFill()
ctx.fill(CGRect(origin: .zero, size: imageSize))
UIColor.white.setStroke()
let scaleX = imageSize.width / displayedImageFrame.width
let scaleY = imageSize.height / displayedImageFrame.height
let path = UIBezierPath()
let firstPoint = CGPoint(
x: (linePoints[0].x - displayedImageFrame.minX) * scaleX,
y: (linePoints[0].y - displayedImageFrame.minY) * scaleY
)
path.move(to: firstPoint)
for point in linePoints.dropFirst() {
let scaledPoint = CGPoint(
x: (point.x - displayedImageFrame.minX) * scaleX,
y: (point.y - displayedImageFrame.minY) * scaleY
)
path.addLine(to: scaledPoint)
}
path.lineWidth = viewModel.wireWidth * scaleX
path.lineCapStyle = .round
path.lineJoinStyle = .round
path.stroke()
}
if let cgImage = maskImage.cgImage {
await viewModel.applyBrushMask(cgImage)
}
linePoints.removeAll()
}
private func distance(from p1: CGPoint, to p2: CGPoint) -> CGFloat {
sqrt(pow(p2.x - p1.x, 2) + pow(p2.y - p1.y, 2))
}
}
#Preview {
let viewModel = EditorViewModel()
return BrushCanvasView(
viewModel: viewModel,
imageSize: CGSize(width: 1000, height: 1000),
displayedImageFrame: CGRect(x: 0, y: 0, width: 300, height: 300)
)
}

View File

@@ -44,6 +44,15 @@ struct CanvasView: View {
.blendMode(.multiply)
.colorMultiply(.red.opacity(0.5))
}
// Brush canvas overlay
if viewModel.selectedTool == .brush && !viewModel.showingMaskConfirmation {
BrushCanvasView(
viewModel: viewModel,
imageSize: viewModel.imageSize,
displayedImageFrame: displayedImageFrame(in: geometry.size)
)
}
}
}
.contentShape(Rectangle())
@@ -64,6 +73,43 @@ struct CanvasView: View {
.clipped()
}
// MARK: - Computed Properties
private func displayedImageFrame(in viewSize: CGSize) -> CGRect {
let imageSize = viewModel.imageSize
guard imageSize.width > 0, imageSize.height > 0 else {
return .zero
}
let imageAspect = imageSize.width / imageSize.height
let viewAspect = viewSize.width / viewSize.height
let displayedSize: CGSize
if imageAspect > viewAspect {
displayedSize = CGSize(
width: viewSize.width,
height: viewSize.width / imageAspect
)
} else {
displayedSize = CGSize(
width: viewSize.height * imageAspect,
height: viewSize.height
)
}
let scaledSize = CGSize(
width: displayedSize.width * scale,
height: displayedSize.height * scale
)
let origin = CGPoint(
x: (viewSize.width - scaledSize.width) / 2 + offset.width,
y: (viewSize.height - scaledSize.height) / 2 + offset.height
)
return CGRect(origin: origin, size: scaledSize)
}
// MARK: - Gestures
private func tapGesture(in geometry: GeometryProxy) -> some Gesture {

View File

@@ -45,6 +45,8 @@ actor InpaintEngine {
private let patchRadius: Int
private let maxPreviewSize: Int = 2048
private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB
private let previewDiffusionIterations: Int = 30
private let fullDiffusionIterations: Int = 100
init(patchRadius: Int = 9) {
self.patchRadius = patchRadius
@@ -97,12 +99,30 @@ actor InpaintEngine {
// MARK: - Metal Implementation
private func inpaintWithMetal(image: CGImage, mask: CGImage, isPreview: Bool) async throws -> CGImage {
guard let device = device, let commandQueue = commandQueue else {
guard let device = device else {
throw InpaintError.metalNotAvailable
}
// For now, fall back to Accelerate while Metal shaders are being developed
return try await inpaintWithAccelerate(image: image, mask: mask)
// Create inpainter with appropriate iteration count for preview vs full
let iterations = isPreview ? previewDiffusionIterations : fullDiffusionIterations
let featherAmount: Float = isPreview ? 2.0 : 4.0
do {
// Metal operations need to run on main actor due to GPU resource management
let result = try await MainActor.run {
let inpainter = try PatchMatchInpainter(
device: device,
patchRadius: 4,
diffusionIterations: iterations
)
return try inpainter.inpaint(image: image, mask: mask, featherAmount: featherAmount)
}
return result
} catch {
// If Metal inpainting fails, fall back to Accelerate
print("Metal inpainting failed: \(error.localizedDescription), falling back to Accelerate")
return try await inpaintWithAccelerate(image: image, mask: mask)
}
}
// MARK: - Accelerate Fallback
@@ -126,8 +146,6 @@ actor InpaintEngine {
}
// Simple inpainting: for each masked pixel, average nearby unmasked pixels
let patchSize = patchRadius * 2 + 1
for y in 0..<height {
for x in 0..<width {
let maskIndex = y * width + x

View File

@@ -0,0 +1,445 @@
//
// PatchMatch.swift
// CheapRetouch
//
// Exemplar-based inpainting algorithm implementation using Metal.
//
import Foundation
import Metal
import MetalKit
import CoreGraphics
import simd
final class PatchMatchInpainter {
private let device: MTLDevice
private let commandQueue: MTLCommandQueue
private let library: MTLLibrary
// Compute pipelines
private let dilateMaskPipeline: MTLComputePipelineState
private let gaussianBlurPipeline: MTLComputePipelineState
private let diffuseInpaintPipeline: MTLComputePipelineState
private let edgeAwareBlendPipeline: MTLComputePipelineState
private let patchRadius: Int
private let diffusionIterations: Int
init(device: MTLDevice, patchRadius: Int = 4, diffusionIterations: Int = 100) throws {
self.device = device
self.patchRadius = patchRadius
self.diffusionIterations = diffusionIterations
guard let queue = device.makeCommandQueue() else {
throw PatchMatchError.commandQueueCreationFailed
}
self.commandQueue = queue
guard let library = device.makeDefaultLibrary() else {
throw PatchMatchError.libraryNotFound
}
self.library = library
// Create compute pipelines
guard let dilateFunc = library.makeFunction(name: "dilateMask") else {
throw PatchMatchError.functionNotFound("dilateMask")
}
self.dilateMaskPipeline = try device.makeComputePipelineState(function: dilateFunc)
guard let blurFunc = library.makeFunction(name: "gaussianBlur") else {
throw PatchMatchError.functionNotFound("gaussianBlur")
}
self.gaussianBlurPipeline = try device.makeComputePipelineState(function: blurFunc)
guard let diffuseFunc = library.makeFunction(name: "diffuseInpaint") else {
throw PatchMatchError.functionNotFound("diffuseInpaint")
}
self.diffuseInpaintPipeline = try device.makeComputePipelineState(function: diffuseFunc)
guard let blendFunc = library.makeFunction(name: "edgeAwareBlend") else {
throw PatchMatchError.functionNotFound("edgeAwareBlend")
}
self.edgeAwareBlendPipeline = try device.makeComputePipelineState(function: blendFunc)
}
func inpaint(image: CGImage, mask: CGImage, featherAmount: Float = 4.0) throws -> CGImage {
let width = image.width
let height = image.height
// Create textures
let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
pixelFormat: .rgba8Unorm,
width: width,
height: height,
mipmapped: false
)
textureDescriptor.usage = [.shaderRead, .shaderWrite]
guard let sourceTexture = device.makeTexture(descriptor: textureDescriptor),
let resultTexture = device.makeTexture(descriptor: textureDescriptor),
let tempTexture = device.makeTexture(descriptor: textureDescriptor) else {
throw PatchMatchError.textureCreationFailed
}
// Mask texture (single channel)
let maskDescriptor = MTLTextureDescriptor.texture2DDescriptor(
pixelFormat: .r8Unorm,
width: width,
height: height,
mipmapped: false
)
maskDescriptor.usage = [.shaderRead, .shaderWrite]
guard let maskTexture = device.makeTexture(descriptor: maskDescriptor),
let dilatedMaskTexture = device.makeTexture(descriptor: maskDescriptor),
let featheredMaskTexture = device.makeTexture(descriptor: maskDescriptor) else {
throw PatchMatchError.textureCreationFailed
}
// Load image data into texture
try loadCGImage(image, into: sourceTexture)
try loadCGImageGrayscale(mask, into: maskTexture)
// Copy source to result for initial state
try copyTexture(from: sourceTexture, to: resultTexture)
guard let commandBuffer = commandQueue.makeCommandBuffer() else {
throw PatchMatchError.commandBufferCreationFailed
}
// Step 1: Dilate mask
encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: patchRadius)
// Step 2: Feather mask
encodeGaussianBlur(commandBuffer: commandBuffer, input: dilatedMaskTexture, output: featheredMaskTexture, radius: Int(featherAmount))
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Step 3: Diffusion-based inpainting (multiple iterations)
for _ in 0..<diffusionIterations {
guard let iterBuffer = commandQueue.makeCommandBuffer() else {
throw PatchMatchError.commandBufferCreationFailed
}
encodeDiffuseInpaint(commandBuffer: iterBuffer, input: resultTexture, mask: dilatedMaskTexture, output: tempTexture)
iterBuffer.commit()
iterBuffer.waitUntilCompleted()
try copyTexture(from: tempTexture, to: resultTexture)
}
// Step 4: Edge-aware blending
guard let finalBuffer = commandQueue.makeCommandBuffer() else {
throw PatchMatchError.commandBufferCreationFailed
}
encodeEdgeAwareBlend(
commandBuffer: finalBuffer,
inpainted: resultTexture,
original: sourceTexture,
mask: featheredMaskTexture,
output: tempTexture,
blendRadius: featherAmount
)
finalBuffer.commit()
finalBuffer.waitUntilCompleted()
// Convert result to CGImage
return try extractCGImage(from: tempTexture)
}
// MARK: - Shader Encoding
private func encodeDilateMask(commandBuffer: MTLCommandBuffer, input: MTLTexture, output: MTLTexture, radius: Int) {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
encoder.setComputePipelineState(dilateMaskPipeline)
encoder.setTexture(input, index: 0)
encoder.setTexture(output, index: 1)
var radiusValue = Int32(radius)
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 0)
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
let threadgroups = MTLSize(
width: (input.width + 15) / 16,
height: (input.height + 15) / 16,
depth: 1
)
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding()
}
private func encodeGaussianBlur(commandBuffer: MTLCommandBuffer, input: MTLTexture, output: MTLTexture, radius: Int) {
// Generate Gaussian weights
let sigma = Float(radius) / 3.0
var weights: [Float] = []
for i in 0...radius {
let weight = exp(-Float(i * i) / (2 * sigma * sigma))
weights.append(weight)
}
guard let weightsBuffer = device.makeBuffer(bytes: weights, length: weights.count * MemoryLayout<Float>.size, options: []) else { return }
// Create temp texture for two-pass blur
let tempDescriptor = MTLTextureDescriptor.texture2DDescriptor(
pixelFormat: input.pixelFormat,
width: input.width,
height: input.height,
mipmapped: false
)
tempDescriptor.usage = [.shaderRead, .shaderWrite]
guard let tempTexture = device.makeTexture(descriptor: tempDescriptor) else { return }
// Horizontal pass
if let encoder = commandBuffer.makeComputeCommandEncoder() {
encoder.setComputePipelineState(gaussianBlurPipeline)
encoder.setTexture(input, index: 0)
encoder.setTexture(tempTexture, index: 1)
encoder.setBuffer(weightsBuffer, offset: 0, index: 0)
var radiusValue = Int32(radius)
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 1)
var horizontal: Bool = true
encoder.setBytes(&horizontal, length: MemoryLayout<Bool>.size, index: 2)
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
let threadgroups = MTLSize(
width: (input.width + 15) / 16,
height: (input.height + 15) / 16,
depth: 1
)
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding()
}
// Vertical pass
if let encoder = commandBuffer.makeComputeCommandEncoder() {
encoder.setComputePipelineState(gaussianBlurPipeline)
encoder.setTexture(tempTexture, index: 0)
encoder.setTexture(output, index: 1)
encoder.setBuffer(weightsBuffer, offset: 0, index: 0)
var radiusValue = Int32(radius)
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 1)
var horizontal: Bool = false
encoder.setBytes(&horizontal, length: MemoryLayout<Bool>.size, index: 2)
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
let threadgroups = MTLSize(
width: (input.width + 15) / 16,
height: (input.height + 15) / 16,
depth: 1
)
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding()
}
}
private func encodeDiffuseInpaint(commandBuffer: MTLCommandBuffer, input: MTLTexture, mask: MTLTexture, output: MTLTexture) {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
encoder.setComputePipelineState(diffuseInpaintPipeline)
encoder.setTexture(input, index: 0)
encoder.setTexture(mask, index: 1)
encoder.setTexture(output, index: 2)
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
let threadgroups = MTLSize(
width: (input.width + 15) / 16,
height: (input.height + 15) / 16,
depth: 1
)
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding()
}
private func encodeEdgeAwareBlend(
commandBuffer: MTLCommandBuffer,
inpainted: MTLTexture,
original: MTLTexture,
mask: MTLTexture,
output: MTLTexture,
blendRadius: Float
) {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
encoder.setComputePipelineState(edgeAwareBlendPipeline)
encoder.setTexture(inpainted, index: 0)
encoder.setTexture(original, index: 1)
encoder.setTexture(mask, index: 2)
encoder.setTexture(output, index: 3)
var radius = blendRadius
encoder.setBytes(&radius, length: MemoryLayout<Float>.size, index: 0)
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
let threadgroups = MTLSize(
width: (inpainted.width + 15) / 16,
height: (inpainted.height + 15) / 16,
depth: 1
)
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding()
}
// MARK: - Texture Utilities
private func loadCGImage(_ cgImage: CGImage, into texture: MTLTexture) throws {
let width = cgImage.width
let height = cgImage.height
guard let context = CGContext(
data: nil,
width: width,
height: height,
bitsPerComponent: 8,
bytesPerRow: width * 4,
space: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
) else {
throw PatchMatchError.contextCreationFailed
}
context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
guard let data = context.data else {
throw PatchMatchError.dataExtractionFailed
}
texture.replace(
region: MTLRegionMake2D(0, 0, width, height),
mipmapLevel: 0,
withBytes: data,
bytesPerRow: width * 4
)
}
private func loadCGImageGrayscale(_ cgImage: CGImage, into texture: MTLTexture) throws {
let width = cgImage.width
let height = cgImage.height
guard let context = CGContext(
data: nil,
width: width,
height: height,
bitsPerComponent: 8,
bytesPerRow: width,
space: CGColorSpaceCreateDeviceGray(),
bitmapInfo: CGImageAlphaInfo.none.rawValue
) else {
throw PatchMatchError.contextCreationFailed
}
context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
guard let data = context.data else {
throw PatchMatchError.dataExtractionFailed
}
texture.replace(
region: MTLRegionMake2D(0, 0, width, height),
mipmapLevel: 0,
withBytes: data,
bytesPerRow: width
)
}
private func copyTexture(from source: MTLTexture, to destination: MTLTexture) throws {
guard let commandBuffer = commandQueue.makeCommandBuffer(),
let blitEncoder = commandBuffer.makeBlitCommandEncoder() else {
throw PatchMatchError.commandBufferCreationFailed
}
blitEncoder.copy(
from: source,
sourceSlice: 0,
sourceLevel: 0,
sourceOrigin: MTLOrigin(x: 0, y: 0, z: 0),
sourceSize: MTLSize(width: source.width, height: source.height, depth: 1),
to: destination,
destinationSlice: 0,
destinationLevel: 0,
destinationOrigin: MTLOrigin(x: 0, y: 0, z: 0)
)
blitEncoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
}
private func extractCGImage(from texture: MTLTexture) throws -> CGImage {
let width = texture.width
let height = texture.height
let bytesPerRow = width * 4
var pixelData = [UInt8](repeating: 0, count: bytesPerRow * height)
texture.getBytes(
&pixelData,
bytesPerRow: bytesPerRow,
from: MTLRegionMake2D(0, 0, width, height),
mipmapLevel: 0
)
guard let context = CGContext(
data: &pixelData,
width: width,
height: height,
bitsPerComponent: 8,
bytesPerRow: bytesPerRow,
space: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
) else {
throw PatchMatchError.contextCreationFailed
}
guard let cgImage = context.makeImage() else {
throw PatchMatchError.imageCreationFailed
}
return cgImage
}
}
// MARK: - Errors
enum PatchMatchError: LocalizedError {
case commandQueueCreationFailed
case libraryNotFound
case functionNotFound(String)
case textureCreationFailed
case commandBufferCreationFailed
case contextCreationFailed
case dataExtractionFailed
case imageCreationFailed
var errorDescription: String? {
switch self {
case .commandQueueCreationFailed:
return "Failed to create Metal command queue"
case .libraryNotFound:
return "Metal shader library not found"
case .functionNotFound(let name):
return "Metal function '\(name)' not found"
case .textureCreationFailed:
return "Failed to create Metal texture"
case .commandBufferCreationFailed:
return "Failed to create command buffer"
case .contextCreationFailed:
return "Failed to create graphics context"
case .dataExtractionFailed:
return "Failed to extract image data"
case .imageCreationFailed:
return "Failed to create output image"
}
}
}

View File

@@ -0,0 +1,267 @@
//
// Shaders.metal
// CheapRetouch
//
// Metal shaders for exemplar-based inpainting.
//
#include <metal_stdlib>
using namespace metal;
// MARK: - Common Types
struct VertexOut {
float4 position [[position]];
float2 texCoord;
};
// MARK: - Mask Operations
// Dilate mask by expanding white regions
kernel void dilateMask(
texture2d<float, access::read> inMask [[texture(0)]],
texture2d<float, access::write> outMask [[texture(1)]],
constant int &radius [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]
) {
if (gid.x >= outMask.get_width() || gid.y >= outMask.get_height()) {
return;
}
float maxValue = 0.0;
for (int dy = -radius; dy <= radius; dy++) {
for (int dx = -radius; dx <= radius; dx++) {
int2 samplePos = int2(gid) + int2(dx, dy);
if (samplePos.x >= 0 && samplePos.x < int(inMask.get_width()) &&
samplePos.y >= 0 && samplePos.y < int(inMask.get_height())) {
float value = inMask.read(uint2(samplePos)).r;
maxValue = max(maxValue, value);
}
}
}
outMask.write(float4(maxValue, maxValue, maxValue, 1.0), gid);
}
// Gaussian blur for mask feathering
kernel void gaussianBlur(
texture2d<float, access::read> inTexture [[texture(0)]],
texture2d<float, access::write> outTexture [[texture(1)]],
constant float *weights [[buffer(0)]],
constant int &radius [[buffer(1)]],
constant bool &horizontal [[buffer(2)]],
uint2 gid [[thread_position_in_grid]]
) {
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height()) {
return;
}
float4 sum = float4(0.0);
float weightSum = 0.0;
for (int i = -radius; i <= radius; i++) {
int2 offset = horizontal ? int2(i, 0) : int2(0, i);
int2 samplePos = int2(gid) + offset;
if (samplePos.x >= 0 && samplePos.x < int(inTexture.get_width()) &&
samplePos.y >= 0 && samplePos.y < int(inTexture.get_height())) {
float weight = weights[abs(i)];
sum += inTexture.read(uint2(samplePos)) * weight;
weightSum += weight;
}
}
outTexture.write(sum / weightSum, gid);
}
// MARK: - Patch Matching
// Compute sum of squared differences between two patches
kernel void computePatchSSD(
texture2d<float, access::read> sourceTexture [[texture(0)]],
texture2d<float, access::read> maskTexture [[texture(1)]],
texture2d<float, access::write> ssdTexture [[texture(2)]],
constant int2 &targetPos [[buffer(0)]],
constant int &patchRadius [[buffer(1)]],
uint2 gid [[thread_position_in_grid]]
) {
int width = sourceTexture.get_width();
int height = sourceTexture.get_height();
if (gid.x >= uint(width) || gid.y >= uint(height)) {
return;
}
// Skip if this position is in the mask (unknown region)
float maskValue = maskTexture.read(gid).r;
if (maskValue > 0.5) {
ssdTexture.write(float4(1e10), gid);
return;
}
float ssd = 0.0;
int validPixels = 0;
for (int dy = -patchRadius; dy <= patchRadius; dy++) {
for (int dx = -patchRadius; dx <= patchRadius; dx++) {
int2 sourcePos = int2(gid) + int2(dx, dy);
int2 targetSamplePos = targetPos + int2(dx, dy);
// Check bounds
if (sourcePos.x < 0 || sourcePos.x >= width ||
sourcePos.y < 0 || sourcePos.y >= height ||
targetSamplePos.x < 0 || targetSamplePos.x >= width ||
targetSamplePos.y < 0 || targetSamplePos.y >= height) {
continue;
}
// Only compare known pixels in source
float sourceMask = maskTexture.read(uint2(sourcePos)).r;
if (sourceMask > 0.5) {
continue;
}
float4 sourceColor = sourceTexture.read(uint2(sourcePos));
float4 targetColor = sourceTexture.read(uint2(targetSamplePos));
float3 diff = sourceColor.rgb - targetColor.rgb;
ssd += dot(diff, diff);
validPixels++;
}
}
// Normalize by valid pixel count
if (validPixels > 0) {
ssd /= float(validPixels);
} else {
ssd = 1e10;
}
ssdTexture.write(float4(ssd), gid);
}
// MARK: - Inpainting
// Simple diffusion-based inpainting step
kernel void diffuseInpaint(
texture2d<float, access::read> inTexture [[texture(0)]],
texture2d<float, access::read> maskTexture [[texture(1)]],
texture2d<float, access::write> outTexture [[texture(2)]],
uint2 gid [[thread_position_in_grid]]
) {
int width = inTexture.get_width();
int height = inTexture.get_height();
if (gid.x >= uint(width) || gid.y >= uint(height)) {
return;
}
float maskValue = maskTexture.read(gid).r;
// If not in mask, copy original
if (maskValue < 0.5) {
outTexture.write(inTexture.read(gid), gid);
return;
}
// Diffuse from neighbors
float4 sum = float4(0.0);
float weightSum = 0.0;
int offsets[4][2] = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
for (int i = 0; i < 4; i++) {
int2 neighborPos = int2(gid) + int2(offsets[i][0], offsets[i][1]);
if (neighborPos.x >= 0 && neighborPos.x < width &&
neighborPos.y >= 0 && neighborPos.y < height) {
float neighborMask = maskTexture.read(uint2(neighborPos)).r;
float weight = (neighborMask < 0.5) ? 2.0 : 1.0;
sum += inTexture.read(uint2(neighborPos)) * weight;
weightSum += weight;
}
}
if (weightSum > 0.0) {
outTexture.write(sum / weightSum, gid);
} else {
outTexture.write(inTexture.read(gid), gid);
}
}
// Copy patch from source to target
kernel void copyPatch(
texture2d<float, access::read> sourceTexture [[texture(0)]],
texture2d<float, access::read> maskTexture [[texture(1)]],
texture2d<float, access::read_write> targetTexture [[texture(2)]],
constant int2 &sourcePos [[buffer(0)]],
constant int2 &targetPos [[buffer(1)]],
constant int &patchRadius [[buffer(2)]],
uint2 gid [[thread_position_in_grid]]
) {
int patchSize = patchRadius * 2 + 1;
if (gid.x >= uint(patchSize) || gid.y >= uint(patchSize)) {
return;
}
int2 offset = int2(gid) - int2(patchRadius);
int2 srcCoord = sourcePos + offset;
int2 dstCoord = targetPos + offset;
int width = sourceTexture.get_width();
int height = sourceTexture.get_height();
// Check bounds
if (srcCoord.x < 0 || srcCoord.x >= width ||
srcCoord.y < 0 || srcCoord.y >= height ||
dstCoord.x < 0 || dstCoord.x >= width ||
dstCoord.y < 0 || dstCoord.y >= height) {
return;
}
// Only write to masked (unknown) pixels
float maskValue = maskTexture.read(uint2(dstCoord)).r;
if (maskValue > 0.5) {
float4 sourceColor = sourceTexture.read(uint2(srcCoord));
targetTexture.write(sourceColor, uint2(dstCoord));
}
}
// MARK: - Edge-aware blending
kernel void edgeAwareBlend(
texture2d<float, access::read> inTexture [[texture(0)]],
texture2d<float, access::read> originalTexture [[texture(1)]],
texture2d<float, access::read> maskTexture [[texture(2)]],
texture2d<float, access::write> outTexture [[texture(3)]],
constant float &blendRadius [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]
) {
int width = inTexture.get_width();
int height = inTexture.get_height();
if (gid.x >= uint(width) || gid.y >= uint(height)) {
return;
}
float maskValue = maskTexture.read(gid).r;
float4 inpaintedColor = inTexture.read(gid);
float4 originalColor = originalTexture.read(gid);
// If far from mask boundary, use appropriate color directly
if (maskValue < 0.1) {
outTexture.write(originalColor, gid);
return;
}
if (maskValue > 0.9) {
outTexture.write(inpaintedColor, gid);
return;
}
// Blend in the transition zone
float4 result = mix(originalColor, inpaintedColor, maskValue);
outTexture.write(result, gid);
}