From 48ee7ecd7c6b6862a61820ef9c996dbeccc1ae02 Mon Sep 17 00:00:00 2001 From: jared Date: Fri, 23 Jan 2026 23:41:43 -0500 Subject: [PATCH] Add Metal-based inpainting and brush selection tools Implement GPU-accelerated inpainting using Metal compute shaders: - Add Shaders.metal with dilateMask, gaussianBlur, diffuseInpaint, edgeAwareBlend kernels - Add PatchMatchInpainter class for exemplar-based inpainting - Update InpaintEngine to use Metal with Accelerate fallback - Add BrushCanvasView for manual brush-based mask painting - Add LineBrushView for wire removal line drawing - Update CanvasView to integrate brush canvas overlay Co-Authored-By: Claude Opus 4.5 --- CheapRetouch.xcodeproj/project.pbxproj | 4 +- .../Features/Editor/BrushCanvasView.swift | 331 +++++++++++++ CheapRetouch/Features/Editor/CanvasView.swift | 46 ++ .../InpaintEngine/InpaintEngine.swift | 28 +- .../Services/InpaintEngine/PatchMatch.swift | 445 ++++++++++++++++++ .../Services/InpaintEngine/Shaders.metal | 267 +++++++++++ 6 files changed, 1114 insertions(+), 7 deletions(-) create mode 100644 CheapRetouch/Features/Editor/BrushCanvasView.swift create mode 100644 CheapRetouch/Services/InpaintEngine/PatchMatch.swift create mode 100644 CheapRetouch/Services/InpaintEngine/Shaders.metal diff --git a/CheapRetouch.xcodeproj/project.pbxproj b/CheapRetouch.xcodeproj/project.pbxproj index 874d2bf..d52fd3b 100644 --- a/CheapRetouch.xcodeproj/project.pbxproj +++ b/CheapRetouch.xcodeproj/project.pbxproj @@ -257,9 +257,9 @@ ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities"; INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library."; INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them."; - INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities"; INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; INFOPLIST_KEY_UILaunchScreen_Generation = YES; @@ -297,9 +297,9 @@ ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch; + INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities"; INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library."; INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them."; - INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities"; INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; INFOPLIST_KEY_UILaunchScreen_Generation = YES; diff --git a/CheapRetouch/Features/Editor/BrushCanvasView.swift b/CheapRetouch/Features/Editor/BrushCanvasView.swift new file mode 100644 index 0000000..dba468d --- /dev/null +++ b/CheapRetouch/Features/Editor/BrushCanvasView.swift @@ -0,0 +1,331 @@ +// +// BrushCanvasView.swift +// CheapRetouch +// +// Canvas overlay for brush-based manual selection. +// + +import SwiftUI +import UIKit + +struct BrushCanvasView: View { + @Bindable var viewModel: EditorViewModel + let imageSize: CGSize + let displayedImageFrame: CGRect + + @State private var currentStroke: [CGPoint] = [] + @State private var allStrokes: [[CGPoint]] = [] + @State private var isErasing = false + + var body: some View { + Canvas { context, size in + // Draw all completed strokes + for stroke in allStrokes { + drawStroke(stroke, in: &context, color: isErasing ? .black : .white) + } + + // Draw current stroke + if !currentStroke.isEmpty { + drawStroke(currentStroke, in: &context, color: isErasing ? .black : .white) + } + } + .gesture(drawingGesture) + .overlay(alignment: .bottom) { + brushControls + } + } + + private func drawStroke(_ points: [CGPoint], in context: inout GraphicsContext, color: Color) { + guard points.count >= 2 else { return } + + var path = Path() + path.move(to: points[0]) + + if points.count == 2 { + path.addLine(to: points[1]) + } else { + for i in 1..= 2 else { continue } + + let path = UIBezierPath() + let firstPoint = CGPoint( + x: (stroke[0].x - displayedImageFrame.minX) * scaleX, + y: (stroke[0].y - displayedImageFrame.minY) * scaleY + ) + path.move(to: firstPoint) + + for i in 1..= 2 else { return } + + var path = Path() + path.move(to: linePoints[0]) + + for point in linePoints.dropFirst() { + path.addLine(to: point) + } + + context.stroke( + path, + with: .color(.white.opacity(0.7)), + style: StrokeStyle( + lineWidth: viewModel.wireWidth, + lineCap: .round, + lineJoin: .round + ) + ) + } + .gesture(lineDrawingGesture) + .overlay(alignment: .bottom) { + lineControls + } + } + + private var lineDrawingGesture: some Gesture { + DragGesture(minimumDistance: 0) + .onChanged { value in + let point = value.location + if displayedImageFrame.contains(point) { + // For line brush, we sample less frequently for smoother lines + if linePoints.isEmpty || distance(from: linePoints.last!, to: point) > 5 { + linePoints.append(point) + } + } + } + .onEnded { _ in + // Line complete, ready to apply + } + } + + private var lineControls: some View { + HStack(spacing: 16) { + // Clear + Button { + linePoints.removeAll() + } label: { + Image(systemName: "trash") + .font(.title2) + .frame(width: 44, height: 44) + } + .disabled(linePoints.isEmpty) + .accessibilityLabel("Clear line") + + Spacer() + + // Cancel + Button { + linePoints.removeAll() + viewModel.selectedTool = .wire + } label: { + Text("Cancel") + .font(.headline) + .padding(.horizontal, 16) + .padding(.vertical, 10) + } + .buttonStyle(.bordered) + + // Done button + Button { + Task { + await applyLineMask() + } + } label: { + Text("Remove Line") + .font(.headline) + .padding(.horizontal, 16) + .padding(.vertical, 10) + } + .buttonStyle(.borderedProminent) + .disabled(linePoints.count < 2) + } + .padding() + .background(.ultraThinMaterial) + } + + private func applyLineMask() async { + guard linePoints.count >= 2 else { return } + + let renderer = UIGraphicsImageRenderer(size: imageSize) + let maskImage = renderer.image { ctx in + UIColor.black.setFill() + ctx.fill(CGRect(origin: .zero, size: imageSize)) + + UIColor.white.setStroke() + + let scaleX = imageSize.width / displayedImageFrame.width + let scaleY = imageSize.height / displayedImageFrame.height + + let path = UIBezierPath() + let firstPoint = CGPoint( + x: (linePoints[0].x - displayedImageFrame.minX) * scaleX, + y: (linePoints[0].y - displayedImageFrame.minY) * scaleY + ) + path.move(to: firstPoint) + + for point in linePoints.dropFirst() { + let scaledPoint = CGPoint( + x: (point.x - displayedImageFrame.minX) * scaleX, + y: (point.y - displayedImageFrame.minY) * scaleY + ) + path.addLine(to: scaledPoint) + } + + path.lineWidth = viewModel.wireWidth * scaleX + path.lineCapStyle = .round + path.lineJoinStyle = .round + path.stroke() + } + + if let cgImage = maskImage.cgImage { + await viewModel.applyBrushMask(cgImage) + } + + linePoints.removeAll() + } + + private func distance(from p1: CGPoint, to p2: CGPoint) -> CGFloat { + sqrt(pow(p2.x - p1.x, 2) + pow(p2.y - p1.y, 2)) + } +} + +#Preview { + let viewModel = EditorViewModel() + return BrushCanvasView( + viewModel: viewModel, + imageSize: CGSize(width: 1000, height: 1000), + displayedImageFrame: CGRect(x: 0, y: 0, width: 300, height: 300) + ) +} diff --git a/CheapRetouch/Features/Editor/CanvasView.swift b/CheapRetouch/Features/Editor/CanvasView.swift index 388827f..0f00033 100644 --- a/CheapRetouch/Features/Editor/CanvasView.swift +++ b/CheapRetouch/Features/Editor/CanvasView.swift @@ -44,6 +44,15 @@ struct CanvasView: View { .blendMode(.multiply) .colorMultiply(.red.opacity(0.5)) } + + // Brush canvas overlay + if viewModel.selectedTool == .brush && !viewModel.showingMaskConfirmation { + BrushCanvasView( + viewModel: viewModel, + imageSize: viewModel.imageSize, + displayedImageFrame: displayedImageFrame(in: geometry.size) + ) + } } } .contentShape(Rectangle()) @@ -64,6 +73,43 @@ struct CanvasView: View { .clipped() } + // MARK: - Computed Properties + + private func displayedImageFrame(in viewSize: CGSize) -> CGRect { + let imageSize = viewModel.imageSize + guard imageSize.width > 0, imageSize.height > 0 else { + return .zero + } + + let imageAspect = imageSize.width / imageSize.height + let viewAspect = viewSize.width / viewSize.height + + let displayedSize: CGSize + if imageAspect > viewAspect { + displayedSize = CGSize( + width: viewSize.width, + height: viewSize.width / imageAspect + ) + } else { + displayedSize = CGSize( + width: viewSize.height * imageAspect, + height: viewSize.height + ) + } + + let scaledSize = CGSize( + width: displayedSize.width * scale, + height: displayedSize.height * scale + ) + + let origin = CGPoint( + x: (viewSize.width - scaledSize.width) / 2 + offset.width, + y: (viewSize.height - scaledSize.height) / 2 + offset.height + ) + + return CGRect(origin: origin, size: scaledSize) + } + // MARK: - Gestures private func tapGesture(in geometry: GeometryProxy) -> some Gesture { diff --git a/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift b/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift index 6ec1b58..b2e2b8b 100644 --- a/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift +++ b/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift @@ -45,6 +45,8 @@ actor InpaintEngine { private let patchRadius: Int private let maxPreviewSize: Int = 2048 private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB + private let previewDiffusionIterations: Int = 30 + private let fullDiffusionIterations: Int = 100 init(patchRadius: Int = 9) { self.patchRadius = patchRadius @@ -97,12 +99,30 @@ actor InpaintEngine { // MARK: - Metal Implementation private func inpaintWithMetal(image: CGImage, mask: CGImage, isPreview: Bool) async throws -> CGImage { - guard let device = device, let commandQueue = commandQueue else { + guard let device = device else { throw InpaintError.metalNotAvailable } - // For now, fall back to Accelerate while Metal shaders are being developed - return try await inpaintWithAccelerate(image: image, mask: mask) + // Create inpainter with appropriate iteration count for preview vs full + let iterations = isPreview ? previewDiffusionIterations : fullDiffusionIterations + let featherAmount: Float = isPreview ? 2.0 : 4.0 + + do { + // Metal operations need to run on main actor due to GPU resource management + let result = try await MainActor.run { + let inpainter = try PatchMatchInpainter( + device: device, + patchRadius: 4, + diffusionIterations: iterations + ) + return try inpainter.inpaint(image: image, mask: mask, featherAmount: featherAmount) + } + return result + } catch { + // If Metal inpainting fails, fall back to Accelerate + print("Metal inpainting failed: \(error.localizedDescription), falling back to Accelerate") + return try await inpaintWithAccelerate(image: image, mask: mask) + } } // MARK: - Accelerate Fallback @@ -126,8 +146,6 @@ actor InpaintEngine { } // Simple inpainting: for each masked pixel, average nearby unmasked pixels - let patchSize = patchRadius * 2 + 1 - for y in 0.. CGImage { + let width = image.width + let height = image.height + + // Create textures + let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor( + pixelFormat: .rgba8Unorm, + width: width, + height: height, + mipmapped: false + ) + textureDescriptor.usage = [.shaderRead, .shaderWrite] + + guard let sourceTexture = device.makeTexture(descriptor: textureDescriptor), + let resultTexture = device.makeTexture(descriptor: textureDescriptor), + let tempTexture = device.makeTexture(descriptor: textureDescriptor) else { + throw PatchMatchError.textureCreationFailed + } + + // Mask texture (single channel) + let maskDescriptor = MTLTextureDescriptor.texture2DDescriptor( + pixelFormat: .r8Unorm, + width: width, + height: height, + mipmapped: false + ) + maskDescriptor.usage = [.shaderRead, .shaderWrite] + + guard let maskTexture = device.makeTexture(descriptor: maskDescriptor), + let dilatedMaskTexture = device.makeTexture(descriptor: maskDescriptor), + let featheredMaskTexture = device.makeTexture(descriptor: maskDescriptor) else { + throw PatchMatchError.textureCreationFailed + } + + // Load image data into texture + try loadCGImage(image, into: sourceTexture) + try loadCGImageGrayscale(mask, into: maskTexture) + + // Copy source to result for initial state + try copyTexture(from: sourceTexture, to: resultTexture) + + guard let commandBuffer = commandQueue.makeCommandBuffer() else { + throw PatchMatchError.commandBufferCreationFailed + } + + // Step 1: Dilate mask + encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: patchRadius) + + // Step 2: Feather mask + encodeGaussianBlur(commandBuffer: commandBuffer, input: dilatedMaskTexture, output: featheredMaskTexture, radius: Int(featherAmount)) + + commandBuffer.commit() + commandBuffer.waitUntilCompleted() + + // Step 3: Diffusion-based inpainting (multiple iterations) + for _ in 0...size, index: 0) + + let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + let threadgroups = MTLSize( + width: (input.width + 15) / 16, + height: (input.height + 15) / 16, + depth: 1 + ) + + encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize) + encoder.endEncoding() + } + + private func encodeGaussianBlur(commandBuffer: MTLCommandBuffer, input: MTLTexture, output: MTLTexture, radius: Int) { + // Generate Gaussian weights + let sigma = Float(radius) / 3.0 + var weights: [Float] = [] + for i in 0...radius { + let weight = exp(-Float(i * i) / (2 * sigma * sigma)) + weights.append(weight) + } + + guard let weightsBuffer = device.makeBuffer(bytes: weights, length: weights.count * MemoryLayout.size, options: []) else { return } + + // Create temp texture for two-pass blur + let tempDescriptor = MTLTextureDescriptor.texture2DDescriptor( + pixelFormat: input.pixelFormat, + width: input.width, + height: input.height, + mipmapped: false + ) + tempDescriptor.usage = [.shaderRead, .shaderWrite] + + guard let tempTexture = device.makeTexture(descriptor: tempDescriptor) else { return } + + // Horizontal pass + if let encoder = commandBuffer.makeComputeCommandEncoder() { + encoder.setComputePipelineState(gaussianBlurPipeline) + encoder.setTexture(input, index: 0) + encoder.setTexture(tempTexture, index: 1) + encoder.setBuffer(weightsBuffer, offset: 0, index: 0) + var radiusValue = Int32(radius) + encoder.setBytes(&radiusValue, length: MemoryLayout.size, index: 1) + var horizontal: Bool = true + encoder.setBytes(&horizontal, length: MemoryLayout.size, index: 2) + + let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + let threadgroups = MTLSize( + width: (input.width + 15) / 16, + height: (input.height + 15) / 16, + depth: 1 + ) + + encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize) + encoder.endEncoding() + } + + // Vertical pass + if let encoder = commandBuffer.makeComputeCommandEncoder() { + encoder.setComputePipelineState(gaussianBlurPipeline) + encoder.setTexture(tempTexture, index: 0) + encoder.setTexture(output, index: 1) + encoder.setBuffer(weightsBuffer, offset: 0, index: 0) + var radiusValue = Int32(radius) + encoder.setBytes(&radiusValue, length: MemoryLayout.size, index: 1) + var horizontal: Bool = false + encoder.setBytes(&horizontal, length: MemoryLayout.size, index: 2) + + let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + let threadgroups = MTLSize( + width: (input.width + 15) / 16, + height: (input.height + 15) / 16, + depth: 1 + ) + + encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize) + encoder.endEncoding() + } + } + + private func encodeDiffuseInpaint(commandBuffer: MTLCommandBuffer, input: MTLTexture, mask: MTLTexture, output: MTLTexture) { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return } + + encoder.setComputePipelineState(diffuseInpaintPipeline) + encoder.setTexture(input, index: 0) + encoder.setTexture(mask, index: 1) + encoder.setTexture(output, index: 2) + + let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + let threadgroups = MTLSize( + width: (input.width + 15) / 16, + height: (input.height + 15) / 16, + depth: 1 + ) + + encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize) + encoder.endEncoding() + } + + private func encodeEdgeAwareBlend( + commandBuffer: MTLCommandBuffer, + inpainted: MTLTexture, + original: MTLTexture, + mask: MTLTexture, + output: MTLTexture, + blendRadius: Float + ) { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return } + + encoder.setComputePipelineState(edgeAwareBlendPipeline) + encoder.setTexture(inpainted, index: 0) + encoder.setTexture(original, index: 1) + encoder.setTexture(mask, index: 2) + encoder.setTexture(output, index: 3) + + var radius = blendRadius + encoder.setBytes(&radius, length: MemoryLayout.size, index: 0) + + let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + let threadgroups = MTLSize( + width: (inpainted.width + 15) / 16, + height: (inpainted.height + 15) / 16, + depth: 1 + ) + + encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize) + encoder.endEncoding() + } + + // MARK: - Texture Utilities + + private func loadCGImage(_ cgImage: CGImage, into texture: MTLTexture) throws { + let width = cgImage.width + let height = cgImage.height + + guard let context = CGContext( + data: nil, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: width * 4, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue + ) else { + throw PatchMatchError.contextCreationFailed + } + + context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height)) + + guard let data = context.data else { + throw PatchMatchError.dataExtractionFailed + } + + texture.replace( + region: MTLRegionMake2D(0, 0, width, height), + mipmapLevel: 0, + withBytes: data, + bytesPerRow: width * 4 + ) + } + + private func loadCGImageGrayscale(_ cgImage: CGImage, into texture: MTLTexture) throws { + let width = cgImage.width + let height = cgImage.height + + guard let context = CGContext( + data: nil, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: width, + space: CGColorSpaceCreateDeviceGray(), + bitmapInfo: CGImageAlphaInfo.none.rawValue + ) else { + throw PatchMatchError.contextCreationFailed + } + + context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height)) + + guard let data = context.data else { + throw PatchMatchError.dataExtractionFailed + } + + texture.replace( + region: MTLRegionMake2D(0, 0, width, height), + mipmapLevel: 0, + withBytes: data, + bytesPerRow: width + ) + } + + private func copyTexture(from source: MTLTexture, to destination: MTLTexture) throws { + guard let commandBuffer = commandQueue.makeCommandBuffer(), + let blitEncoder = commandBuffer.makeBlitCommandEncoder() else { + throw PatchMatchError.commandBufferCreationFailed + } + + blitEncoder.copy( + from: source, + sourceSlice: 0, + sourceLevel: 0, + sourceOrigin: MTLOrigin(x: 0, y: 0, z: 0), + sourceSize: MTLSize(width: source.width, height: source.height, depth: 1), + to: destination, + destinationSlice: 0, + destinationLevel: 0, + destinationOrigin: MTLOrigin(x: 0, y: 0, z: 0) + ) + + blitEncoder.endEncoding() + commandBuffer.commit() + commandBuffer.waitUntilCompleted() + } + + private func extractCGImage(from texture: MTLTexture) throws -> CGImage { + let width = texture.width + let height = texture.height + let bytesPerRow = width * 4 + + var pixelData = [UInt8](repeating: 0, count: bytesPerRow * height) + + texture.getBytes( + &pixelData, + bytesPerRow: bytesPerRow, + from: MTLRegionMake2D(0, 0, width, height), + mipmapLevel: 0 + ) + + guard let context = CGContext( + data: &pixelData, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: bytesPerRow, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue + ) else { + throw PatchMatchError.contextCreationFailed + } + + guard let cgImage = context.makeImage() else { + throw PatchMatchError.imageCreationFailed + } + + return cgImage + } +} + +// MARK: - Errors + +enum PatchMatchError: LocalizedError { + case commandQueueCreationFailed + case libraryNotFound + case functionNotFound(String) + case textureCreationFailed + case commandBufferCreationFailed + case contextCreationFailed + case dataExtractionFailed + case imageCreationFailed + + var errorDescription: String? { + switch self { + case .commandQueueCreationFailed: + return "Failed to create Metal command queue" + case .libraryNotFound: + return "Metal shader library not found" + case .functionNotFound(let name): + return "Metal function '\(name)' not found" + case .textureCreationFailed: + return "Failed to create Metal texture" + case .commandBufferCreationFailed: + return "Failed to create command buffer" + case .contextCreationFailed: + return "Failed to create graphics context" + case .dataExtractionFailed: + return "Failed to extract image data" + case .imageCreationFailed: + return "Failed to create output image" + } + } +} diff --git a/CheapRetouch/Services/InpaintEngine/Shaders.metal b/CheapRetouch/Services/InpaintEngine/Shaders.metal new file mode 100644 index 0000000..7d2fe1b --- /dev/null +++ b/CheapRetouch/Services/InpaintEngine/Shaders.metal @@ -0,0 +1,267 @@ +// +// Shaders.metal +// CheapRetouch +// +// Metal shaders for exemplar-based inpainting. +// + +#include +using namespace metal; + +// MARK: - Common Types + +struct VertexOut { + float4 position [[position]]; + float2 texCoord; +}; + +// MARK: - Mask Operations + +// Dilate mask by expanding white regions +kernel void dilateMask( + texture2d inMask [[texture(0)]], + texture2d outMask [[texture(1)]], + constant int &radius [[buffer(0)]], + uint2 gid [[thread_position_in_grid]] +) { + if (gid.x >= outMask.get_width() || gid.y >= outMask.get_height()) { + return; + } + + float maxValue = 0.0; + + for (int dy = -radius; dy <= radius; dy++) { + for (int dx = -radius; dx <= radius; dx++) { + int2 samplePos = int2(gid) + int2(dx, dy); + + if (samplePos.x >= 0 && samplePos.x < int(inMask.get_width()) && + samplePos.y >= 0 && samplePos.y < int(inMask.get_height())) { + float value = inMask.read(uint2(samplePos)).r; + maxValue = max(maxValue, value); + } + } + } + + outMask.write(float4(maxValue, maxValue, maxValue, 1.0), gid); +} + +// Gaussian blur for mask feathering +kernel void gaussianBlur( + texture2d inTexture [[texture(0)]], + texture2d outTexture [[texture(1)]], + constant float *weights [[buffer(0)]], + constant int &radius [[buffer(1)]], + constant bool &horizontal [[buffer(2)]], + uint2 gid [[thread_position_in_grid]] +) { + if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height()) { + return; + } + + float4 sum = float4(0.0); + float weightSum = 0.0; + + for (int i = -radius; i <= radius; i++) { + int2 offset = horizontal ? int2(i, 0) : int2(0, i); + int2 samplePos = int2(gid) + offset; + + if (samplePos.x >= 0 && samplePos.x < int(inTexture.get_width()) && + samplePos.y >= 0 && samplePos.y < int(inTexture.get_height())) { + float weight = weights[abs(i)]; + sum += inTexture.read(uint2(samplePos)) * weight; + weightSum += weight; + } + } + + outTexture.write(sum / weightSum, gid); +} + +// MARK: - Patch Matching + +// Compute sum of squared differences between two patches +kernel void computePatchSSD( + texture2d sourceTexture [[texture(0)]], + texture2d maskTexture [[texture(1)]], + texture2d ssdTexture [[texture(2)]], + constant int2 &targetPos [[buffer(0)]], + constant int &patchRadius [[buffer(1)]], + uint2 gid [[thread_position_in_grid]] +) { + int width = sourceTexture.get_width(); + int height = sourceTexture.get_height(); + + if (gid.x >= uint(width) || gid.y >= uint(height)) { + return; + } + + // Skip if this position is in the mask (unknown region) + float maskValue = maskTexture.read(gid).r; + if (maskValue > 0.5) { + ssdTexture.write(float4(1e10), gid); + return; + } + + float ssd = 0.0; + int validPixels = 0; + + for (int dy = -patchRadius; dy <= patchRadius; dy++) { + for (int dx = -patchRadius; dx <= patchRadius; dx++) { + int2 sourcePos = int2(gid) + int2(dx, dy); + int2 targetSamplePos = targetPos + int2(dx, dy); + + // Check bounds + if (sourcePos.x < 0 || sourcePos.x >= width || + sourcePos.y < 0 || sourcePos.y >= height || + targetSamplePos.x < 0 || targetSamplePos.x >= width || + targetSamplePos.y < 0 || targetSamplePos.y >= height) { + continue; + } + + // Only compare known pixels in source + float sourceMask = maskTexture.read(uint2(sourcePos)).r; + if (sourceMask > 0.5) { + continue; + } + + float4 sourceColor = sourceTexture.read(uint2(sourcePos)); + float4 targetColor = sourceTexture.read(uint2(targetSamplePos)); + + float3 diff = sourceColor.rgb - targetColor.rgb; + ssd += dot(diff, diff); + validPixels++; + } + } + + // Normalize by valid pixel count + if (validPixels > 0) { + ssd /= float(validPixels); + } else { + ssd = 1e10; + } + + ssdTexture.write(float4(ssd), gid); +} + +// MARK: - Inpainting + +// Simple diffusion-based inpainting step +kernel void diffuseInpaint( + texture2d inTexture [[texture(0)]], + texture2d maskTexture [[texture(1)]], + texture2d outTexture [[texture(2)]], + uint2 gid [[thread_position_in_grid]] +) { + int width = inTexture.get_width(); + int height = inTexture.get_height(); + + if (gid.x >= uint(width) || gid.y >= uint(height)) { + return; + } + + float maskValue = maskTexture.read(gid).r; + + // If not in mask, copy original + if (maskValue < 0.5) { + outTexture.write(inTexture.read(gid), gid); + return; + } + + // Diffuse from neighbors + float4 sum = float4(0.0); + float weightSum = 0.0; + + int offsets[4][2] = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; + + for (int i = 0; i < 4; i++) { + int2 neighborPos = int2(gid) + int2(offsets[i][0], offsets[i][1]); + + if (neighborPos.x >= 0 && neighborPos.x < width && + neighborPos.y >= 0 && neighborPos.y < height) { + float neighborMask = maskTexture.read(uint2(neighborPos)).r; + float weight = (neighborMask < 0.5) ? 2.0 : 1.0; + sum += inTexture.read(uint2(neighborPos)) * weight; + weightSum += weight; + } + } + + if (weightSum > 0.0) { + outTexture.write(sum / weightSum, gid); + } else { + outTexture.write(inTexture.read(gid), gid); + } +} + +// Copy patch from source to target +kernel void copyPatch( + texture2d sourceTexture [[texture(0)]], + texture2d maskTexture [[texture(1)]], + texture2d targetTexture [[texture(2)]], + constant int2 &sourcePos [[buffer(0)]], + constant int2 &targetPos [[buffer(1)]], + constant int &patchRadius [[buffer(2)]], + uint2 gid [[thread_position_in_grid]] +) { + int patchSize = patchRadius * 2 + 1; + + if (gid.x >= uint(patchSize) || gid.y >= uint(patchSize)) { + return; + } + + int2 offset = int2(gid) - int2(patchRadius); + int2 srcCoord = sourcePos + offset; + int2 dstCoord = targetPos + offset; + + int width = sourceTexture.get_width(); + int height = sourceTexture.get_height(); + + // Check bounds + if (srcCoord.x < 0 || srcCoord.x >= width || + srcCoord.y < 0 || srcCoord.y >= height || + dstCoord.x < 0 || dstCoord.x >= width || + dstCoord.y < 0 || dstCoord.y >= height) { + return; + } + + // Only write to masked (unknown) pixels + float maskValue = maskTexture.read(uint2(dstCoord)).r; + if (maskValue > 0.5) { + float4 sourceColor = sourceTexture.read(uint2(srcCoord)); + targetTexture.write(sourceColor, uint2(dstCoord)); + } +} + +// MARK: - Edge-aware blending + +kernel void edgeAwareBlend( + texture2d inTexture [[texture(0)]], + texture2d originalTexture [[texture(1)]], + texture2d maskTexture [[texture(2)]], + texture2d outTexture [[texture(3)]], + constant float &blendRadius [[buffer(0)]], + uint2 gid [[thread_position_in_grid]] +) { + int width = inTexture.get_width(); + int height = inTexture.get_height(); + + if (gid.x >= uint(width) || gid.y >= uint(height)) { + return; + } + + float maskValue = maskTexture.read(gid).r; + float4 inpaintedColor = inTexture.read(gid); + float4 originalColor = originalTexture.read(gid); + + // If far from mask boundary, use appropriate color directly + if (maskValue < 0.1) { + outTexture.write(originalColor, gid); + return; + } + if (maskValue > 0.9) { + outTexture.write(inpaintedColor, gid); + return; + } + + // Blend in the transition zone + float4 result = mix(originalColor, inpaintedColor, maskValue); + outTexture.write(result, gid); +}