diff --git a/CheapRetouch/Features/Editor/CanvasView.swift b/CheapRetouch/Features/Editor/CanvasView.swift index 89c3f9d..cfe972e 100644 --- a/CheapRetouch/Features/Editor/CanvasView.swift +++ b/CheapRetouch/Features/Editor/CanvasView.swift @@ -205,6 +205,9 @@ struct CanvasView: View { private func dragGesture(in geometry: GeometryProxy) -> some Gesture { DragGesture() .onChanged { value in + // Don't pan when brush tool is selected - let brush drawing take priority + guard viewModel.selectedTool != .brush else { return } + if scale > 1.0 { offset = CGSize( width: lastOffset.width + value.translation.width, @@ -213,6 +216,9 @@ struct CanvasView: View { } } .onEnded { _ in + // Don't update offset if brush tool is selected + guard viewModel.selectedTool != .brush else { return } + lastOffset = offset withAnimation(.spring(duration: 0.3)) { clampOffset(in: geometry.size) diff --git a/CheapRetouch/Features/Editor/EditorViewModel.swift b/CheapRetouch/Features/Editor/EditorViewModel.swift index f7bf3a2..4346ad8 100644 --- a/CheapRetouch/Features/Editor/EditorViewModel.swift +++ b/CheapRetouch/Features/Editor/EditorViewModel.swift @@ -75,7 +75,11 @@ final class EditorViewModel { func loadImage(_ uiImage: UIImage, localIdentifier: String? = nil) { DebugLogger.action("loadImage called") - guard let cgImage = uiImage.cgImage else { + + // Normalize image orientation - fixes photos that appear rotated + let normalizedImage = normalizeImageOrientation(uiImage) + + guard let cgImage = normalizedImage.cgImage else { DebugLogger.error("Failed to get CGImage from UIImage") return } @@ -102,13 +106,27 @@ final class EditorViewModel { if let identifier = localIdentifier { imageSource = .photoLibrary(localIdentifier: identifier) } else { - let imageData = uiImage.jpegData(compressionQuality: 0.9) ?? Data() + let imageData = normalizedImage.jpegData(compressionQuality: 0.9) ?? Data() imageSource = .embedded(data: imageData) } project = Project(imageSource: imageSource) announceForVoiceOver("Photo loaded") } + + /// Normalizes image orientation by redrawing with correct transform + private func normalizeImageOrientation(_ image: UIImage) -> UIImage { + // If orientation is already up, no need to redraw + guard image.imageOrientation != .up else { return image } + + // Redraw image with correct orientation applied + UIGraphicsBeginImageContextWithOptions(image.size, false, image.scale) + image.draw(in: CGRect(origin: .zero, size: image.size)) + let normalizedImage = UIGraphicsGetImageFromCurrentImageContext() + UIGraphicsEndImageContext() + + return normalizedImage ?? image + } // MARK: - Tap Handling diff --git a/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift b/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift index 5b32a8c..d9277e5 100644 --- a/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift +++ b/CheapRetouch/Services/InpaintEngine/InpaintEngine.swift @@ -46,7 +46,7 @@ actor InpaintEngine { private let maxPreviewSize: Int = 2048 private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB private let previewDiffusionIterations: Int = 30 - private let fullDiffusionIterations: Int = 100 + private let fullDiffusionIterations: Int = 500 init(patchRadius: Int = 9) { self.patchRadius = patchRadius @@ -120,7 +120,7 @@ actor InpaintEngine { let result = try await MainActor.run { let inpainter = try PatchMatchInpainter( device: device, - patchRadius: 4, + patchRadius: 8, diffusionIterations: iterations ) return try inpainter.inpaint(image: image, mask: mask, featherAmount: featherAmount) diff --git a/CheapRetouch/Services/InpaintEngine/PatchMatch.swift b/CheapRetouch/Services/InpaintEngine/PatchMatch.swift index 46574db..bd548f8 100644 --- a/CheapRetouch/Services/InpaintEngine/PatchMatch.swift +++ b/CheapRetouch/Services/InpaintEngine/PatchMatch.swift @@ -22,6 +22,7 @@ final class PatchMatchInpainter { private let gaussianBlurPipeline: MTLComputePipelineState private let diffuseInpaintPipeline: MTLComputePipelineState private let edgeAwareBlendPipeline: MTLComputePipelineState + private let gradientFillPipeline: MTLComputePipelineState private let patchRadius: Int private let diffusionIterations: Int @@ -61,6 +62,11 @@ final class PatchMatchInpainter { throw PatchMatchError.functionNotFound("edgeAwareBlend") } self.edgeAwareBlendPipeline = try device.makeComputePipelineState(function: blendFunc) + + guard let gradientFunc = library.makeFunction(name: "gradientFill") else { + throw PatchMatchError.functionNotFound("gradientFill") + } + self.gradientFillPipeline = try device.makeComputePipelineState(function: gradientFunc) } func inpaint(image: CGImage, mask: CGImage, featherAmount: Float = 4.0) throws -> CGImage { @@ -115,8 +121,11 @@ final class PatchMatchInpainter { throw PatchMatchError.commandBufferCreationFailed } - // Step 1: Dilate mask - encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: patchRadius) + // Step 1: Dilate mask significantly to exclude object edge pixels + // Use 20px dilation to ensure we sample from clean background, not object edges + let dilationRadius = max(patchRadius, 20) + DebugLogger.log("Dilating mask by \(dilationRadius) pixels...") + encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: dilationRadius) // Step 2: Feather mask encodeGaussianBlur(commandBuffer: commandBuffer, input: dilatedMaskTexture, output: featheredMaskTexture, radius: Int(featherAmount)) @@ -124,8 +133,29 @@ final class PatchMatchInpainter { commandBuffer.commit() commandBuffer.waitUntilCompleted() - // Step 3: Diffusion-based inpainting (multiple iterations) - for _ in 0...size, index: 0) + + let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + let threadgroups = MTLSize( + width: (source.width + 15) / 16, + height: (source.height + 15) / 16, + depth: 1 + ) + + encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize) + encoder.endEncoding() + } // MARK: - Texture Utilities @@ -357,6 +418,19 @@ final class PatchMatchInpainter { guard let data = context.data else { throw PatchMatchError.dataExtractionFailed } + + // Debug: count non-zero pixels to verify mask has content + let buffer = data.bindMemory(to: UInt8.self, capacity: width * height) + var whitePixelCount = 0 + var totalValue: Int = 0 + for i in 0..<(width * height) { + let value = buffer[i] + totalValue += Int(value) + if value > 127 { + whitePixelCount += 1 + } + } + DebugLogger.log("Mask loaded: \(width)x\(height), whitePixels=\(whitePixelCount) (\(String(format: "%.2f", Double(whitePixelCount) / Double(width * height) * 100))%), avgValue=\(totalValue / (width * height))") texture.replace( region: MTLRegionMake2D(0, 0, width, height), diff --git a/CheapRetouch/Services/InpaintEngine/Shaders.metal b/CheapRetouch/Services/InpaintEngine/Shaders.metal index 7d2fe1b..0471622 100644 --- a/CheapRetouch/Services/InpaintEngine/Shaders.metal +++ b/CheapRetouch/Services/InpaintEngine/Shaders.metal @@ -265,3 +265,189 @@ kernel void edgeAwareBlend( float4 result = mix(originalColor, inpaintedColor, maskValue); outTexture.write(result, gid); } + +// MARK: - Uniform Background Fill + +// Fill masked region with uniform background color using gradient blending from edges +kernel void uniformFill( + texture2d sourceTexture [[texture(0)]], + texture2d maskTexture [[texture(1)]], + texture2d outTexture [[texture(2)]], + constant float4 &avgColor [[buffer(0)]], + constant float &maxDistance [[buffer(1)]], + uint2 gid [[thread_position_in_grid]] +) { + int width = sourceTexture.get_width(); + int height = sourceTexture.get_height(); + + if (gid.x >= uint(width) || gid.y >= uint(height)) { + return; + } + + float maskValue = maskTexture.read(gid).r; + + // If not in mask, copy original + if (maskValue < 0.5) { + outTexture.write(sourceTexture.read(gid), gid); + return; + } + + // For masked pixels, fill with average color + outTexture.write(avgColor, gid); +} + +// Sample border pixels and compute average color and variance +kernel void sampleBorder( + texture2d sourceTexture [[texture(0)]], + texture2d maskTexture [[texture(1)]], + device atomic_uint *sumR [[buffer(0)]], + device atomic_uint *sumG [[buffer(1)]], + device atomic_uint *sumB [[buffer(2)]], + device atomic_uint *count [[buffer(3)]], + device atomic_uint *varSum [[buffer(4)]], + constant int &borderWidth [[buffer(5)]], + uint2 gid [[thread_position_in_grid]] +) { + int width = sourceTexture.get_width(); + int height = sourceTexture.get_height(); + + if (gid.x >= uint(width) || gid.y >= uint(height)) { + return; + } + + float maskValue = maskTexture.read(gid).r; + + // Only sample pixels just outside the mask (border pixels) + if (maskValue >= 0.5) { + return; + } + + // Check if this pixel is adjacent to a masked pixel + bool isBorder = false; + for (int dy = -borderWidth; dy <= borderWidth && !isBorder; dy++) { + for (int dx = -borderWidth; dx <= borderWidth && !isBorder; dx++) { + int2 neighborPos = int2(gid) + int2(dx, dy); + if (neighborPos.x >= 0 && neighborPos.x < width && + neighborPos.y >= 0 && neighborPos.y < height) { + if (maskTexture.read(uint2(neighborPos)).r >= 0.5) { + isBorder = true; + } + } + } + } + + if (!isBorder) { + return; + } + + float4 color = sourceTexture.read(gid); + + // Accumulate color values (scaled to avoid precision issues) + uint r = uint(color.r * 1000.0); + uint g = uint(color.g * 1000.0); + uint b = uint(color.b * 1000.0); + + atomic_fetch_add_explicit(sumR, r, memory_order_relaxed); + atomic_fetch_add_explicit(sumG, g, memory_order_relaxed); + atomic_fetch_add_explicit(sumB, b, memory_order_relaxed); + atomic_fetch_add_explicit(count, 1, memory_order_relaxed); +} + +// Gradient fill from edges - samples from multiple directions with smooth blending +kernel void gradientFill( + texture2d sourceTexture [[texture(0)]], + texture2d maskTexture [[texture(1)]], + texture2d distanceTexture [[texture(2)]], + texture2d outTexture [[texture(3)]], + constant float &maxDist [[buffer(0)]], + uint2 gid [[thread_position_in_grid]] +) { + int width = sourceTexture.get_width(); + int height = sourceTexture.get_height(); + + if (gid.x >= uint(width) || gid.y >= uint(height)) { + return; + } + + float maskValue = maskTexture.read(gid).r; + + // If not in mask, copy original + if (maskValue < 0.5) { + outTexture.write(sourceTexture.read(gid), gid); + return; + } + + // Sample colors from 16 directions for smoother blending + float4 colorSum = float4(0.0); + float weightSum = 0.0; + + // 16 directions: cardinal, diagonal, and intermediate angles + float2 directions[16] = { + float2(-1.0, 0.0), // left + float2(1.0, 0.0), // right + float2(0.0, -1.0), // up + float2(0.0, 1.0), // down + float2(-1.0, -1.0), // top-left + float2(1.0, -1.0), // top-right + float2(-1.0, 1.0), // bottom-left + float2(1.0, 1.0), // bottom-right + float2(-2.0, -1.0), // intermediate angles + float2(-1.0, -2.0), + float2(1.0, -2.0), + float2(2.0, -1.0), + float2(2.0, 1.0), + float2(1.0, 2.0), + float2(-1.0, 2.0), + float2(-2.0, 1.0) + }; + + for (int d = 0; d < 16; d++) { + float2 dir = normalize(directions[d]); + float2 pos = float2(gid); + float distance = 0.0; + + // Walk in this direction until we find a non-masked pixel + for (int i = 0; i < int(maxDist); i++) { + pos += dir; + distance += 1.0; + + int2 ipos = int2(pos); + if (ipos.x < 0 || ipos.x >= width || ipos.y < 0 || ipos.y >= height) { + break; + } + + float m = maskTexture.read(uint2(ipos)).r; + if (m < 0.5) { + // Found edge pixel - sample it with inverse-square distance weight + // This creates smoother blending than linear + float4 color = sourceTexture.read(uint2(ipos)); + float weight = 1.0 / ((distance * distance) + 1.0); + colorSum += color * weight; + weightSum += weight; + + // Also sample a few more pixels in this direction for averaging + for (int j = 1; j <= 3; j++) { + int2 extraPos = ipos + int2(dir * float(j)); + if (extraPos.x >= 0 && extraPos.x < width && + extraPos.y >= 0 && extraPos.y < height) { + float em = maskTexture.read(uint2(extraPos)).r; + if (em < 0.5) { + float4 extraColor = sourceTexture.read(uint2(extraPos)); + float extraWeight = 0.5 / ((distance * distance) + float(j * j) + 1.0); + colorSum += extraColor * extraWeight; + weightSum += extraWeight; + } + } + } + break; + } + } + } + + if (weightSum > 0.0) { + outTexture.write(colorSum / weightSum, gid); + } else { + // Fallback - shouldn't happen + outTexture.write(float4(0.5, 0.5, 0.5, 1.0), gid); + } +} diff --git a/CheapRetouch/Services/MaskingService.swift b/CheapRetouch/Services/MaskingService.swift index 3cffc44..172438d 100644 --- a/CheapRetouch/Services/MaskingService.swift +++ b/CheapRetouch/Services/MaskingService.swift @@ -246,24 +246,50 @@ actor MaskingService { CVPixelBufferLockBaseAddress(pixelBuffer, .readOnly) defer { CVPixelBufferUnlockBaseAddress(pixelBuffer, .readOnly) } - let width = CVPixelBufferGetWidth(pixelBuffer) - let height = CVPixelBufferGetHeight(pixelBuffer) + let maskWidth = CVPixelBufferGetWidth(pixelBuffer) + let maskHeight = CVPixelBufferGetHeight(pixelBuffer) + let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer) + let bytesPerPixel = bytesPerRow / maskWidth - let x = Int(point.x * CGFloat(width)) - let y = Int((1.0 - point.y) * CGFloat(height)) + DebugLogger.log("isPoint check: visionPoint=\(point), maskSize=\(maskWidth)x\(maskHeight), bytesPerPixel=\(bytesPerPixel)") - guard x >= 0, x < width, y >= 0, y < height else { + // point is in Vision coordinates (0-1 normalized, origin bottom-left) + // Mask pixel buffer has origin at top-left + // Scale point to mask dimensions and flip Y from Vision coords to image/mask coords + let x = Int(point.x * CGFloat(maskWidth)) + let y = Int((1.0 - point.y) * CGFloat(maskHeight)) + + DebugLogger.log("Pixel coords: x=\(x), y=\(y) (from visionY=\(point.y) -> imageY=\(1.0 - point.y))") + + guard x >= 0, x < maskWidth, y >= 0, y < maskHeight else { + DebugLogger.log("Pixel coords out of bounds") return false } guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else { + DebugLogger.log("Failed to get pixel buffer base address") return false } - - let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer) - let pixelOffset = y * bytesPerRow + x - + + // Calculate offset accounting for bytes per pixel + let pixelOffset = y * bytesPerRow + x * bytesPerPixel let pixelValue = baseAddress.load(fromByteOffset: pixelOffset, as: UInt8.self) + + // Also sample nearby pixels to see if we're close to the mask + var nearbyValues: [UInt8] = [] + for dy in -2...2 { + for dx in -2...2 { + let nx = x + dx + let ny = y + dy + if nx >= 0, nx < maskWidth, ny >= 0, ny < maskHeight { + let offset = ny * bytesPerRow + nx * bytesPerPixel + nearbyValues.append(baseAddress.load(fromByteOffset: offset, as: UInt8.self)) + } + } + } + let maxNearby = nearbyValues.max() ?? 0 + + DebugLogger.log("Pixel value at (\(x), \(y)): \(pixelValue), maxNearby=\(maxNearby) (threshold 127)") return pixelValue > 127 }