Fix object selection, inpainting, and image display issues

- Fix object selection by correctly calculating bytesPerPixel in mask buffer
- Improve inpainting with 16-direction gradient fill and inverse-square blending
- Add 100 smoothing iterations to eliminate seam artifacts
- Increase mask dilation to 20px to exclude object edge remnants
- Fix brush tool gesture conflict - disable pan while brushing
- Fix image orientation by normalizing EXIF rotation on load

Modified files:
- MaskingService.swift: Fix pixel offset calculation for mask buffers
- Shaders.metal: Add gradient fill with 16 directions and multi-pixel sampling
- PatchMatch.swift: Use gradient fill + smoothing, add encodeGradientFill
- InpaintEngine.swift: Increase iterations to 500, patch radius to 8
- CanvasView.swift: Disable pan gesture when brush tool selected
- EditorViewModel.swift: Normalize image orientation on load
This commit is contained in:
2026-01-24 14:05:50 -05:00
parent 7d3794767f
commit 14e2502cf4
6 changed files with 328 additions and 18 deletions

View File

@@ -205,6 +205,9 @@ struct CanvasView: View {
private func dragGesture(in geometry: GeometryProxy) -> some Gesture { private func dragGesture(in geometry: GeometryProxy) -> some Gesture {
DragGesture() DragGesture()
.onChanged { value in .onChanged { value in
// Don't pan when brush tool is selected - let brush drawing take priority
guard viewModel.selectedTool != .brush else { return }
if scale > 1.0 { if scale > 1.0 {
offset = CGSize( offset = CGSize(
width: lastOffset.width + value.translation.width, width: lastOffset.width + value.translation.width,
@@ -213,6 +216,9 @@ struct CanvasView: View {
} }
} }
.onEnded { _ in .onEnded { _ in
// Don't update offset if brush tool is selected
guard viewModel.selectedTool != .brush else { return }
lastOffset = offset lastOffset = offset
withAnimation(.spring(duration: 0.3)) { withAnimation(.spring(duration: 0.3)) {
clampOffset(in: geometry.size) clampOffset(in: geometry.size)

View File

@@ -75,7 +75,11 @@ final class EditorViewModel {
func loadImage(_ uiImage: UIImage, localIdentifier: String? = nil) { func loadImage(_ uiImage: UIImage, localIdentifier: String? = nil) {
DebugLogger.action("loadImage called") DebugLogger.action("loadImage called")
guard let cgImage = uiImage.cgImage else {
// Normalize image orientation - fixes photos that appear rotated
let normalizedImage = normalizeImageOrientation(uiImage)
guard let cgImage = normalizedImage.cgImage else {
DebugLogger.error("Failed to get CGImage from UIImage") DebugLogger.error("Failed to get CGImage from UIImage")
return return
} }
@@ -102,13 +106,27 @@ final class EditorViewModel {
if let identifier = localIdentifier { if let identifier = localIdentifier {
imageSource = .photoLibrary(localIdentifier: identifier) imageSource = .photoLibrary(localIdentifier: identifier)
} else { } else {
let imageData = uiImage.jpegData(compressionQuality: 0.9) ?? Data() let imageData = normalizedImage.jpegData(compressionQuality: 0.9) ?? Data()
imageSource = .embedded(data: imageData) imageSource = .embedded(data: imageData)
} }
project = Project(imageSource: imageSource) project = Project(imageSource: imageSource)
announceForVoiceOver("Photo loaded") announceForVoiceOver("Photo loaded")
} }
/// Normalizes image orientation by redrawing with correct transform
private func normalizeImageOrientation(_ image: UIImage) -> UIImage {
// If orientation is already up, no need to redraw
guard image.imageOrientation != .up else { return image }
// Redraw image with correct orientation applied
UIGraphicsBeginImageContextWithOptions(image.size, false, image.scale)
image.draw(in: CGRect(origin: .zero, size: image.size))
let normalizedImage = UIGraphicsGetImageFromCurrentImageContext()
UIGraphicsEndImageContext()
return normalizedImage ?? image
}
// MARK: - Tap Handling // MARK: - Tap Handling

View File

@@ -46,7 +46,7 @@ actor InpaintEngine {
private let maxPreviewSize: Int = 2048 private let maxPreviewSize: Int = 2048
private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB
private let previewDiffusionIterations: Int = 30 private let previewDiffusionIterations: Int = 30
private let fullDiffusionIterations: Int = 100 private let fullDiffusionIterations: Int = 500
init(patchRadius: Int = 9) { init(patchRadius: Int = 9) {
self.patchRadius = patchRadius self.patchRadius = patchRadius
@@ -120,7 +120,7 @@ actor InpaintEngine {
let result = try await MainActor.run { let result = try await MainActor.run {
let inpainter = try PatchMatchInpainter( let inpainter = try PatchMatchInpainter(
device: device, device: device,
patchRadius: 4, patchRadius: 8,
diffusionIterations: iterations diffusionIterations: iterations
) )
return try inpainter.inpaint(image: image, mask: mask, featherAmount: featherAmount) return try inpainter.inpaint(image: image, mask: mask, featherAmount: featherAmount)

View File

@@ -22,6 +22,7 @@ final class PatchMatchInpainter {
private let gaussianBlurPipeline: MTLComputePipelineState private let gaussianBlurPipeline: MTLComputePipelineState
private let diffuseInpaintPipeline: MTLComputePipelineState private let diffuseInpaintPipeline: MTLComputePipelineState
private let edgeAwareBlendPipeline: MTLComputePipelineState private let edgeAwareBlendPipeline: MTLComputePipelineState
private let gradientFillPipeline: MTLComputePipelineState
private let patchRadius: Int private let patchRadius: Int
private let diffusionIterations: Int private let diffusionIterations: Int
@@ -61,6 +62,11 @@ final class PatchMatchInpainter {
throw PatchMatchError.functionNotFound("edgeAwareBlend") throw PatchMatchError.functionNotFound("edgeAwareBlend")
} }
self.edgeAwareBlendPipeline = try device.makeComputePipelineState(function: blendFunc) self.edgeAwareBlendPipeline = try device.makeComputePipelineState(function: blendFunc)
guard let gradientFunc = library.makeFunction(name: "gradientFill") else {
throw PatchMatchError.functionNotFound("gradientFill")
}
self.gradientFillPipeline = try device.makeComputePipelineState(function: gradientFunc)
} }
func inpaint(image: CGImage, mask: CGImage, featherAmount: Float = 4.0) throws -> CGImage { func inpaint(image: CGImage, mask: CGImage, featherAmount: Float = 4.0) throws -> CGImage {
@@ -115,8 +121,11 @@ final class PatchMatchInpainter {
throw PatchMatchError.commandBufferCreationFailed throw PatchMatchError.commandBufferCreationFailed
} }
// Step 1: Dilate mask // Step 1: Dilate mask significantly to exclude object edge pixels
encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: patchRadius) // Use 20px dilation to ensure we sample from clean background, not object edges
let dilationRadius = max(patchRadius, 20)
DebugLogger.log("Dilating mask by \(dilationRadius) pixels...")
encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: dilationRadius)
// Step 2: Feather mask // Step 2: Feather mask
encodeGaussianBlur(commandBuffer: commandBuffer, input: dilatedMaskTexture, output: featheredMaskTexture, radius: Int(featherAmount)) encodeGaussianBlur(commandBuffer: commandBuffer, input: dilatedMaskTexture, output: featheredMaskTexture, radius: Int(featherAmount))
@@ -124,8 +133,29 @@ final class PatchMatchInpainter {
commandBuffer.commit() commandBuffer.commit()
commandBuffer.waitUntilCompleted() commandBuffer.waitUntilCompleted()
// Step 3: Diffusion-based inpainting (multiple iterations) // Step 3: Gradient-based fill (much faster than diffusion for uniform backgrounds)
for _ in 0..<diffusionIterations { // This samples colors from 8 directions and blends based on distance
DebugLogger.log("Starting gradient fill...")
guard let gradientBuffer = commandQueue.makeCommandBuffer() else {
throw PatchMatchError.commandBufferCreationFailed
}
encodeGradientFill(
commandBuffer: gradientBuffer,
source: sourceTexture,
mask: dilatedMaskTexture,
output: resultTexture,
maxDistance: Float(max(width, height))
)
gradientBuffer.commit()
gradientBuffer.waitUntilCompleted()
DebugLogger.log("Gradient fill complete")
// Step 4: Apply diffusion iterations to smooth the result and eliminate seams
let smoothingIterations = min(diffusionIterations, 100)
DebugLogger.log("Applying \(smoothingIterations) smoothing iterations...")
for _ in 0..<smoothingIterations {
guard let iterBuffer = commandQueue.makeCommandBuffer() else { guard let iterBuffer = commandQueue.makeCommandBuffer() else {
throw PatchMatchError.commandBufferCreationFailed throw PatchMatchError.commandBufferCreationFailed
} }
@@ -137,8 +167,9 @@ final class PatchMatchInpainter {
try copyTexture(from: tempTexture, to: resultTexture) try copyTexture(from: tempTexture, to: resultTexture)
} }
DebugLogger.log("Smoothing complete")
// Step 4: Edge-aware blending // Step 5: Edge-aware blending
guard let finalBuffer = commandQueue.makeCommandBuffer() else { guard let finalBuffer = commandQueue.makeCommandBuffer() else {
throw PatchMatchError.commandBufferCreationFailed throw PatchMatchError.commandBufferCreationFailed
} }
@@ -297,6 +328,36 @@ final class PatchMatchInpainter {
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize) encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding() encoder.endEncoding()
} }
private func encodeGradientFill(
commandBuffer: MTLCommandBuffer,
source: MTLTexture,
mask: MTLTexture,
output: MTLTexture,
maxDistance: Float
) {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
encoder.setComputePipelineState(gradientFillPipeline)
encoder.setTexture(source, index: 0)
encoder.setTexture(mask, index: 1)
// We don't have a distance texture, so we'll pass mask twice (distanceTexture is unused in our simplified version)
encoder.setTexture(mask, index: 2)
encoder.setTexture(output, index: 3)
var maxDist = maxDistance
encoder.setBytes(&maxDist, length: MemoryLayout<Float>.size, index: 0)
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
let threadgroups = MTLSize(
width: (source.width + 15) / 16,
height: (source.height + 15) / 16,
depth: 1
)
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
encoder.endEncoding()
}
// MARK: - Texture Utilities // MARK: - Texture Utilities
@@ -357,6 +418,19 @@ final class PatchMatchInpainter {
guard let data = context.data else { guard let data = context.data else {
throw PatchMatchError.dataExtractionFailed throw PatchMatchError.dataExtractionFailed
} }
// Debug: count non-zero pixels to verify mask has content
let buffer = data.bindMemory(to: UInt8.self, capacity: width * height)
var whitePixelCount = 0
var totalValue: Int = 0
for i in 0..<(width * height) {
let value = buffer[i]
totalValue += Int(value)
if value > 127 {
whitePixelCount += 1
}
}
DebugLogger.log("Mask loaded: \(width)x\(height), whitePixels=\(whitePixelCount) (\(String(format: "%.2f", Double(whitePixelCount) / Double(width * height) * 100))%), avgValue=\(totalValue / (width * height))")
texture.replace( texture.replace(
region: MTLRegionMake2D(0, 0, width, height), region: MTLRegionMake2D(0, 0, width, height),

View File

@@ -265,3 +265,189 @@ kernel void edgeAwareBlend(
float4 result = mix(originalColor, inpaintedColor, maskValue); float4 result = mix(originalColor, inpaintedColor, maskValue);
outTexture.write(result, gid); outTexture.write(result, gid);
} }
// MARK: - Uniform Background Fill
// Fill masked region with uniform background color using gradient blending from edges
kernel void uniformFill(
texture2d<float, access::read> sourceTexture [[texture(0)]],
texture2d<float, access::read> maskTexture [[texture(1)]],
texture2d<float, access::write> outTexture [[texture(2)]],
constant float4 &avgColor [[buffer(0)]],
constant float &maxDistance [[buffer(1)]],
uint2 gid [[thread_position_in_grid]]
) {
int width = sourceTexture.get_width();
int height = sourceTexture.get_height();
if (gid.x >= uint(width) || gid.y >= uint(height)) {
return;
}
float maskValue = maskTexture.read(gid).r;
// If not in mask, copy original
if (maskValue < 0.5) {
outTexture.write(sourceTexture.read(gid), gid);
return;
}
// For masked pixels, fill with average color
outTexture.write(avgColor, gid);
}
// Sample border pixels and compute average color and variance
kernel void sampleBorder(
texture2d<float, access::read> sourceTexture [[texture(0)]],
texture2d<float, access::read> maskTexture [[texture(1)]],
device atomic_uint *sumR [[buffer(0)]],
device atomic_uint *sumG [[buffer(1)]],
device atomic_uint *sumB [[buffer(2)]],
device atomic_uint *count [[buffer(3)]],
device atomic_uint *varSum [[buffer(4)]],
constant int &borderWidth [[buffer(5)]],
uint2 gid [[thread_position_in_grid]]
) {
int width = sourceTexture.get_width();
int height = sourceTexture.get_height();
if (gid.x >= uint(width) || gid.y >= uint(height)) {
return;
}
float maskValue = maskTexture.read(gid).r;
// Only sample pixels just outside the mask (border pixels)
if (maskValue >= 0.5) {
return;
}
// Check if this pixel is adjacent to a masked pixel
bool isBorder = false;
for (int dy = -borderWidth; dy <= borderWidth && !isBorder; dy++) {
for (int dx = -borderWidth; dx <= borderWidth && !isBorder; dx++) {
int2 neighborPos = int2(gid) + int2(dx, dy);
if (neighborPos.x >= 0 && neighborPos.x < width &&
neighborPos.y >= 0 && neighborPos.y < height) {
if (maskTexture.read(uint2(neighborPos)).r >= 0.5) {
isBorder = true;
}
}
}
}
if (!isBorder) {
return;
}
float4 color = sourceTexture.read(gid);
// Accumulate color values (scaled to avoid precision issues)
uint r = uint(color.r * 1000.0);
uint g = uint(color.g * 1000.0);
uint b = uint(color.b * 1000.0);
atomic_fetch_add_explicit(sumR, r, memory_order_relaxed);
atomic_fetch_add_explicit(sumG, g, memory_order_relaxed);
atomic_fetch_add_explicit(sumB, b, memory_order_relaxed);
atomic_fetch_add_explicit(count, 1, memory_order_relaxed);
}
// Gradient fill from edges - samples from multiple directions with smooth blending
kernel void gradientFill(
texture2d<float, access::read> sourceTexture [[texture(0)]],
texture2d<float, access::read> maskTexture [[texture(1)]],
texture2d<float, access::read> distanceTexture [[texture(2)]],
texture2d<float, access::write> outTexture [[texture(3)]],
constant float &maxDist [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]
) {
int width = sourceTexture.get_width();
int height = sourceTexture.get_height();
if (gid.x >= uint(width) || gid.y >= uint(height)) {
return;
}
float maskValue = maskTexture.read(gid).r;
// If not in mask, copy original
if (maskValue < 0.5) {
outTexture.write(sourceTexture.read(gid), gid);
return;
}
// Sample colors from 16 directions for smoother blending
float4 colorSum = float4(0.0);
float weightSum = 0.0;
// 16 directions: cardinal, diagonal, and intermediate angles
float2 directions[16] = {
float2(-1.0, 0.0), // left
float2(1.0, 0.0), // right
float2(0.0, -1.0), // up
float2(0.0, 1.0), // down
float2(-1.0, -1.0), // top-left
float2(1.0, -1.0), // top-right
float2(-1.0, 1.0), // bottom-left
float2(1.0, 1.0), // bottom-right
float2(-2.0, -1.0), // intermediate angles
float2(-1.0, -2.0),
float2(1.0, -2.0),
float2(2.0, -1.0),
float2(2.0, 1.0),
float2(1.0, 2.0),
float2(-1.0, 2.0),
float2(-2.0, 1.0)
};
for (int d = 0; d < 16; d++) {
float2 dir = normalize(directions[d]);
float2 pos = float2(gid);
float distance = 0.0;
// Walk in this direction until we find a non-masked pixel
for (int i = 0; i < int(maxDist); i++) {
pos += dir;
distance += 1.0;
int2 ipos = int2(pos);
if (ipos.x < 0 || ipos.x >= width || ipos.y < 0 || ipos.y >= height) {
break;
}
float m = maskTexture.read(uint2(ipos)).r;
if (m < 0.5) {
// Found edge pixel - sample it with inverse-square distance weight
// This creates smoother blending than linear
float4 color = sourceTexture.read(uint2(ipos));
float weight = 1.0 / ((distance * distance) + 1.0);
colorSum += color * weight;
weightSum += weight;
// Also sample a few more pixels in this direction for averaging
for (int j = 1; j <= 3; j++) {
int2 extraPos = ipos + int2(dir * float(j));
if (extraPos.x >= 0 && extraPos.x < width &&
extraPos.y >= 0 && extraPos.y < height) {
float em = maskTexture.read(uint2(extraPos)).r;
if (em < 0.5) {
float4 extraColor = sourceTexture.read(uint2(extraPos));
float extraWeight = 0.5 / ((distance * distance) + float(j * j) + 1.0);
colorSum += extraColor * extraWeight;
weightSum += extraWeight;
}
}
}
break;
}
}
}
if (weightSum > 0.0) {
outTexture.write(colorSum / weightSum, gid);
} else {
// Fallback - shouldn't happen
outTexture.write(float4(0.5, 0.5, 0.5, 1.0), gid);
}
}

View File

@@ -246,24 +246,50 @@ actor MaskingService {
CVPixelBufferLockBaseAddress(pixelBuffer, .readOnly) CVPixelBufferLockBaseAddress(pixelBuffer, .readOnly)
defer { CVPixelBufferUnlockBaseAddress(pixelBuffer, .readOnly) } defer { CVPixelBufferUnlockBaseAddress(pixelBuffer, .readOnly) }
let width = CVPixelBufferGetWidth(pixelBuffer) let maskWidth = CVPixelBufferGetWidth(pixelBuffer)
let height = CVPixelBufferGetHeight(pixelBuffer) let maskHeight = CVPixelBufferGetHeight(pixelBuffer)
let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer)
let bytesPerPixel = bytesPerRow / maskWidth
let x = Int(point.x * CGFloat(width)) DebugLogger.log("isPoint check: visionPoint=\(point), maskSize=\(maskWidth)x\(maskHeight), bytesPerPixel=\(bytesPerPixel)")
let y = Int((1.0 - point.y) * CGFloat(height))
guard x >= 0, x < width, y >= 0, y < height else { // point is in Vision coordinates (0-1 normalized, origin bottom-left)
// Mask pixel buffer has origin at top-left
// Scale point to mask dimensions and flip Y from Vision coords to image/mask coords
let x = Int(point.x * CGFloat(maskWidth))
let y = Int((1.0 - point.y) * CGFloat(maskHeight))
DebugLogger.log("Pixel coords: x=\(x), y=\(y) (from visionY=\(point.y) -> imageY=\(1.0 - point.y))")
guard x >= 0, x < maskWidth, y >= 0, y < maskHeight else {
DebugLogger.log("Pixel coords out of bounds")
return false return false
} }
guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else { guard let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer) else {
DebugLogger.log("Failed to get pixel buffer base address")
return false return false
} }
let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer) // Calculate offset accounting for bytes per pixel
let pixelOffset = y * bytesPerRow + x let pixelOffset = y * bytesPerRow + x * bytesPerPixel
let pixelValue = baseAddress.load(fromByteOffset: pixelOffset, as: UInt8.self) let pixelValue = baseAddress.load(fromByteOffset: pixelOffset, as: UInt8.self)
// Also sample nearby pixels to see if we're close to the mask
var nearbyValues: [UInt8] = []
for dy in -2...2 {
for dx in -2...2 {
let nx = x + dx
let ny = y + dy
if nx >= 0, nx < maskWidth, ny >= 0, ny < maskHeight {
let offset = ny * bytesPerRow + nx * bytesPerPixel
nearbyValues.append(baseAddress.load(fromByteOffset: offset, as: UInt8.self))
}
}
}
let maxNearby = nearbyValues.max() ?? 0
DebugLogger.log("Pixel value at (\(x), \(y)): \(pixelValue), maxNearby=\(maxNearby) (threshold 127)")
return pixelValue > 127 return pixelValue > 127
} }