Add Metal-based inpainting and brush selection tools
Implement GPU-accelerated inpainting using Metal compute shaders: - Add Shaders.metal with dilateMask, gaussianBlur, diffuseInpaint, edgeAwareBlend kernels - Add PatchMatchInpainter class for exemplar-based inpainting - Update InpaintEngine to use Metal with Accelerate fallback - Add BrushCanvasView for manual brush-based mask painting - Add LineBrushView for wire removal line drawing - Update CanvasView to integrate brush canvas overlay Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -257,9 +257,9 @@
|
|||||||
ENABLE_PREVIEWS = YES;
|
ENABLE_PREVIEWS = YES;
|
||||||
GENERATE_INFOPLIST_FILE = YES;
|
GENERATE_INFOPLIST_FILE = YES;
|
||||||
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
|
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
|
||||||
|
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
||||||
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
|
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
|
||||||
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
|
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
|
||||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
|
||||||
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
||||||
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
||||||
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
||||||
@@ -297,9 +297,9 @@
|
|||||||
ENABLE_PREVIEWS = YES;
|
ENABLE_PREVIEWS = YES;
|
||||||
GENERATE_INFOPLIST_FILE = YES;
|
GENERATE_INFOPLIST_FILE = YES;
|
||||||
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
|
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
|
||||||
|
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
||||||
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
|
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
|
||||||
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
|
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
|
||||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
|
||||||
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
||||||
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
||||||
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
||||||
|
|||||||
331
CheapRetouch/Features/Editor/BrushCanvasView.swift
Normal file
331
CheapRetouch/Features/Editor/BrushCanvasView.swift
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
//
|
||||||
|
// BrushCanvasView.swift
|
||||||
|
// CheapRetouch
|
||||||
|
//
|
||||||
|
// Canvas overlay for brush-based manual selection.
|
||||||
|
//
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
import UIKit
|
||||||
|
|
||||||
|
struct BrushCanvasView: View {
|
||||||
|
@Bindable var viewModel: EditorViewModel
|
||||||
|
let imageSize: CGSize
|
||||||
|
let displayedImageFrame: CGRect
|
||||||
|
|
||||||
|
@State private var currentStroke: [CGPoint] = []
|
||||||
|
@State private var allStrokes: [[CGPoint]] = []
|
||||||
|
@State private var isErasing = false
|
||||||
|
|
||||||
|
var body: some View {
|
||||||
|
Canvas { context, size in
|
||||||
|
// Draw all completed strokes
|
||||||
|
for stroke in allStrokes {
|
||||||
|
drawStroke(stroke, in: &context, color: isErasing ? .black : .white)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw current stroke
|
||||||
|
if !currentStroke.isEmpty {
|
||||||
|
drawStroke(currentStroke, in: &context, color: isErasing ? .black : .white)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.gesture(drawingGesture)
|
||||||
|
.overlay(alignment: .bottom) {
|
||||||
|
brushControls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func drawStroke(_ points: [CGPoint], in context: inout GraphicsContext, color: Color) {
|
||||||
|
guard points.count >= 2 else { return }
|
||||||
|
|
||||||
|
var path = Path()
|
||||||
|
path.move(to: points[0])
|
||||||
|
|
||||||
|
if points.count == 2 {
|
||||||
|
path.addLine(to: points[1])
|
||||||
|
} else {
|
||||||
|
for i in 1..<points.count {
|
||||||
|
let mid = CGPoint(
|
||||||
|
x: (points[i-1].x + points[i].x) / 2,
|
||||||
|
y: (points[i-1].y + points[i].y) / 2
|
||||||
|
)
|
||||||
|
path.addQuadCurve(to: mid, control: points[i-1])
|
||||||
|
}
|
||||||
|
path.addLine(to: points[points.count - 1])
|
||||||
|
}
|
||||||
|
|
||||||
|
context.stroke(
|
||||||
|
path,
|
||||||
|
with: .color(color.opacity(0.7)),
|
||||||
|
style: StrokeStyle(
|
||||||
|
lineWidth: viewModel.brushSize,
|
||||||
|
lineCap: .round,
|
||||||
|
lineJoin: .round
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private var drawingGesture: some Gesture {
|
||||||
|
DragGesture(minimumDistance: 0)
|
||||||
|
.onChanged { value in
|
||||||
|
let point = value.location
|
||||||
|
// Only add points within the image bounds
|
||||||
|
if displayedImageFrame.contains(point) {
|
||||||
|
currentStroke.append(point)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.onEnded { _ in
|
||||||
|
if !currentStroke.isEmpty {
|
||||||
|
allStrokes.append(currentStroke)
|
||||||
|
currentStroke = []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private var brushControls: some View {
|
||||||
|
HStack(spacing: 16) {
|
||||||
|
// Erase toggle
|
||||||
|
Button {
|
||||||
|
isErasing.toggle()
|
||||||
|
} label: {
|
||||||
|
Image(systemName: isErasing ? "eraser.fill" : "eraser")
|
||||||
|
.font(.title2)
|
||||||
|
.frame(width: 44, height: 44)
|
||||||
|
.background(isErasing ? Color.accentColor : Color.clear)
|
||||||
|
.clipShape(Circle())
|
||||||
|
}
|
||||||
|
.accessibilityLabel(isErasing ? "Eraser active" : "Switch to eraser")
|
||||||
|
|
||||||
|
// Clear all
|
||||||
|
Button {
|
||||||
|
allStrokes.removeAll()
|
||||||
|
currentStroke.removeAll()
|
||||||
|
} label: {
|
||||||
|
Image(systemName: "trash")
|
||||||
|
.font(.title2)
|
||||||
|
.frame(width: 44, height: 44)
|
||||||
|
}
|
||||||
|
.disabled(allStrokes.isEmpty)
|
||||||
|
.accessibilityLabel("Clear all strokes")
|
||||||
|
|
||||||
|
Spacer()
|
||||||
|
|
||||||
|
// Done button
|
||||||
|
Button {
|
||||||
|
Task {
|
||||||
|
await applyBrushMask()
|
||||||
|
}
|
||||||
|
} label: {
|
||||||
|
Text("Done")
|
||||||
|
.font(.headline)
|
||||||
|
.padding(.horizontal, 20)
|
||||||
|
.padding(.vertical, 10)
|
||||||
|
}
|
||||||
|
.buttonStyle(.borderedProminent)
|
||||||
|
.disabled(allStrokes.isEmpty)
|
||||||
|
}
|
||||||
|
.padding()
|
||||||
|
.background(.ultraThinMaterial)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func applyBrushMask() async {
|
||||||
|
guard !allStrokes.isEmpty else { return }
|
||||||
|
|
||||||
|
// Create mask image from strokes
|
||||||
|
let renderer = UIGraphicsImageRenderer(size: imageSize)
|
||||||
|
let maskImage = renderer.image { ctx in
|
||||||
|
// Fill with black (not masked)
|
||||||
|
UIColor.black.setFill()
|
||||||
|
ctx.fill(CGRect(origin: .zero, size: imageSize))
|
||||||
|
|
||||||
|
// Draw strokes in white (masked areas)
|
||||||
|
UIColor.white.setStroke()
|
||||||
|
|
||||||
|
let scaleX = imageSize.width / displayedImageFrame.width
|
||||||
|
let scaleY = imageSize.height / displayedImageFrame.height
|
||||||
|
|
||||||
|
for stroke in allStrokes {
|
||||||
|
guard stroke.count >= 2 else { continue }
|
||||||
|
|
||||||
|
let path = UIBezierPath()
|
||||||
|
let firstPoint = CGPoint(
|
||||||
|
x: (stroke[0].x - displayedImageFrame.minX) * scaleX,
|
||||||
|
y: (stroke[0].y - displayedImageFrame.minY) * scaleY
|
||||||
|
)
|
||||||
|
path.move(to: firstPoint)
|
||||||
|
|
||||||
|
for i in 1..<stroke.count {
|
||||||
|
let point = CGPoint(
|
||||||
|
x: (stroke[i].x - displayedImageFrame.minX) * scaleX,
|
||||||
|
y: (stroke[i].y - displayedImageFrame.minY) * scaleY
|
||||||
|
)
|
||||||
|
path.addLine(to: point)
|
||||||
|
}
|
||||||
|
|
||||||
|
path.lineWidth = viewModel.brushSize * scaleX
|
||||||
|
path.lineCapStyle = .round
|
||||||
|
path.lineJoinStyle = .round
|
||||||
|
path.stroke()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let cgImage = maskImage.cgImage {
|
||||||
|
await viewModel.applyBrushMask(cgImage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear strokes after applying
|
||||||
|
allStrokes.removeAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Line Brush View for Wire Tool
|
||||||
|
|
||||||
|
struct LineBrushView: View {
|
||||||
|
@Bindable var viewModel: EditorViewModel
|
||||||
|
let imageSize: CGSize
|
||||||
|
let displayedImageFrame: CGRect
|
||||||
|
|
||||||
|
@State private var linePoints: [CGPoint] = []
|
||||||
|
|
||||||
|
var body: some View {
|
||||||
|
Canvas { context, size in
|
||||||
|
guard linePoints.count >= 2 else { return }
|
||||||
|
|
||||||
|
var path = Path()
|
||||||
|
path.move(to: linePoints[0])
|
||||||
|
|
||||||
|
for point in linePoints.dropFirst() {
|
||||||
|
path.addLine(to: point)
|
||||||
|
}
|
||||||
|
|
||||||
|
context.stroke(
|
||||||
|
path,
|
||||||
|
with: .color(.white.opacity(0.7)),
|
||||||
|
style: StrokeStyle(
|
||||||
|
lineWidth: viewModel.wireWidth,
|
||||||
|
lineCap: .round,
|
||||||
|
lineJoin: .round
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.gesture(lineDrawingGesture)
|
||||||
|
.overlay(alignment: .bottom) {
|
||||||
|
lineControls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private var lineDrawingGesture: some Gesture {
|
||||||
|
DragGesture(minimumDistance: 0)
|
||||||
|
.onChanged { value in
|
||||||
|
let point = value.location
|
||||||
|
if displayedImageFrame.contains(point) {
|
||||||
|
// For line brush, we sample less frequently for smoother lines
|
||||||
|
if linePoints.isEmpty || distance(from: linePoints.last!, to: point) > 5 {
|
||||||
|
linePoints.append(point)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.onEnded { _ in
|
||||||
|
// Line complete, ready to apply
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private var lineControls: some View {
|
||||||
|
HStack(spacing: 16) {
|
||||||
|
// Clear
|
||||||
|
Button {
|
||||||
|
linePoints.removeAll()
|
||||||
|
} label: {
|
||||||
|
Image(systemName: "trash")
|
||||||
|
.font(.title2)
|
||||||
|
.frame(width: 44, height: 44)
|
||||||
|
}
|
||||||
|
.disabled(linePoints.isEmpty)
|
||||||
|
.accessibilityLabel("Clear line")
|
||||||
|
|
||||||
|
Spacer()
|
||||||
|
|
||||||
|
// Cancel
|
||||||
|
Button {
|
||||||
|
linePoints.removeAll()
|
||||||
|
viewModel.selectedTool = .wire
|
||||||
|
} label: {
|
||||||
|
Text("Cancel")
|
||||||
|
.font(.headline)
|
||||||
|
.padding(.horizontal, 16)
|
||||||
|
.padding(.vertical, 10)
|
||||||
|
}
|
||||||
|
.buttonStyle(.bordered)
|
||||||
|
|
||||||
|
// Done button
|
||||||
|
Button {
|
||||||
|
Task {
|
||||||
|
await applyLineMask()
|
||||||
|
}
|
||||||
|
} label: {
|
||||||
|
Text("Remove Line")
|
||||||
|
.font(.headline)
|
||||||
|
.padding(.horizontal, 16)
|
||||||
|
.padding(.vertical, 10)
|
||||||
|
}
|
||||||
|
.buttonStyle(.borderedProminent)
|
||||||
|
.disabled(linePoints.count < 2)
|
||||||
|
}
|
||||||
|
.padding()
|
||||||
|
.background(.ultraThinMaterial)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func applyLineMask() async {
|
||||||
|
guard linePoints.count >= 2 else { return }
|
||||||
|
|
||||||
|
let renderer = UIGraphicsImageRenderer(size: imageSize)
|
||||||
|
let maskImage = renderer.image { ctx in
|
||||||
|
UIColor.black.setFill()
|
||||||
|
ctx.fill(CGRect(origin: .zero, size: imageSize))
|
||||||
|
|
||||||
|
UIColor.white.setStroke()
|
||||||
|
|
||||||
|
let scaleX = imageSize.width / displayedImageFrame.width
|
||||||
|
let scaleY = imageSize.height / displayedImageFrame.height
|
||||||
|
|
||||||
|
let path = UIBezierPath()
|
||||||
|
let firstPoint = CGPoint(
|
||||||
|
x: (linePoints[0].x - displayedImageFrame.minX) * scaleX,
|
||||||
|
y: (linePoints[0].y - displayedImageFrame.minY) * scaleY
|
||||||
|
)
|
||||||
|
path.move(to: firstPoint)
|
||||||
|
|
||||||
|
for point in linePoints.dropFirst() {
|
||||||
|
let scaledPoint = CGPoint(
|
||||||
|
x: (point.x - displayedImageFrame.minX) * scaleX,
|
||||||
|
y: (point.y - displayedImageFrame.minY) * scaleY
|
||||||
|
)
|
||||||
|
path.addLine(to: scaledPoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
path.lineWidth = viewModel.wireWidth * scaleX
|
||||||
|
path.lineCapStyle = .round
|
||||||
|
path.lineJoinStyle = .round
|
||||||
|
path.stroke()
|
||||||
|
}
|
||||||
|
|
||||||
|
if let cgImage = maskImage.cgImage {
|
||||||
|
await viewModel.applyBrushMask(cgImage)
|
||||||
|
}
|
||||||
|
|
||||||
|
linePoints.removeAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
private func distance(from p1: CGPoint, to p2: CGPoint) -> CGFloat {
|
||||||
|
sqrt(pow(p2.x - p1.x, 2) + pow(p2.y - p1.y, 2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#Preview {
|
||||||
|
let viewModel = EditorViewModel()
|
||||||
|
return BrushCanvasView(
|
||||||
|
viewModel: viewModel,
|
||||||
|
imageSize: CGSize(width: 1000, height: 1000),
|
||||||
|
displayedImageFrame: CGRect(x: 0, y: 0, width: 300, height: 300)
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -44,6 +44,15 @@ struct CanvasView: View {
|
|||||||
.blendMode(.multiply)
|
.blendMode(.multiply)
|
||||||
.colorMultiply(.red.opacity(0.5))
|
.colorMultiply(.red.opacity(0.5))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Brush canvas overlay
|
||||||
|
if viewModel.selectedTool == .brush && !viewModel.showingMaskConfirmation {
|
||||||
|
BrushCanvasView(
|
||||||
|
viewModel: viewModel,
|
||||||
|
imageSize: viewModel.imageSize,
|
||||||
|
displayedImageFrame: displayedImageFrame(in: geometry.size)
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.contentShape(Rectangle())
|
.contentShape(Rectangle())
|
||||||
@@ -64,6 +73,43 @@ struct CanvasView: View {
|
|||||||
.clipped()
|
.clipped()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Computed Properties
|
||||||
|
|
||||||
|
private func displayedImageFrame(in viewSize: CGSize) -> CGRect {
|
||||||
|
let imageSize = viewModel.imageSize
|
||||||
|
guard imageSize.width > 0, imageSize.height > 0 else {
|
||||||
|
return .zero
|
||||||
|
}
|
||||||
|
|
||||||
|
let imageAspect = imageSize.width / imageSize.height
|
||||||
|
let viewAspect = viewSize.width / viewSize.height
|
||||||
|
|
||||||
|
let displayedSize: CGSize
|
||||||
|
if imageAspect > viewAspect {
|
||||||
|
displayedSize = CGSize(
|
||||||
|
width: viewSize.width,
|
||||||
|
height: viewSize.width / imageAspect
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
displayedSize = CGSize(
|
||||||
|
width: viewSize.height * imageAspect,
|
||||||
|
height: viewSize.height
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
let scaledSize = CGSize(
|
||||||
|
width: displayedSize.width * scale,
|
||||||
|
height: displayedSize.height * scale
|
||||||
|
)
|
||||||
|
|
||||||
|
let origin = CGPoint(
|
||||||
|
x: (viewSize.width - scaledSize.width) / 2 + offset.width,
|
||||||
|
y: (viewSize.height - scaledSize.height) / 2 + offset.height
|
||||||
|
)
|
||||||
|
|
||||||
|
return CGRect(origin: origin, size: scaledSize)
|
||||||
|
}
|
||||||
|
|
||||||
// MARK: - Gestures
|
// MARK: - Gestures
|
||||||
|
|
||||||
private func tapGesture(in geometry: GeometryProxy) -> some Gesture {
|
private func tapGesture(in geometry: GeometryProxy) -> some Gesture {
|
||||||
|
|||||||
@@ -45,6 +45,8 @@ actor InpaintEngine {
|
|||||||
private let patchRadius: Int
|
private let patchRadius: Int
|
||||||
private let maxPreviewSize: Int = 2048
|
private let maxPreviewSize: Int = 2048
|
||||||
private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB
|
private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB
|
||||||
|
private let previewDiffusionIterations: Int = 30
|
||||||
|
private let fullDiffusionIterations: Int = 100
|
||||||
|
|
||||||
init(patchRadius: Int = 9) {
|
init(patchRadius: Int = 9) {
|
||||||
self.patchRadius = patchRadius
|
self.patchRadius = patchRadius
|
||||||
@@ -97,12 +99,30 @@ actor InpaintEngine {
|
|||||||
// MARK: - Metal Implementation
|
// MARK: - Metal Implementation
|
||||||
|
|
||||||
private func inpaintWithMetal(image: CGImage, mask: CGImage, isPreview: Bool) async throws -> CGImage {
|
private func inpaintWithMetal(image: CGImage, mask: CGImage, isPreview: Bool) async throws -> CGImage {
|
||||||
guard let device = device, let commandQueue = commandQueue else {
|
guard let device = device else {
|
||||||
throw InpaintError.metalNotAvailable
|
throw InpaintError.metalNotAvailable
|
||||||
}
|
}
|
||||||
|
|
||||||
// For now, fall back to Accelerate while Metal shaders are being developed
|
// Create inpainter with appropriate iteration count for preview vs full
|
||||||
return try await inpaintWithAccelerate(image: image, mask: mask)
|
let iterations = isPreview ? previewDiffusionIterations : fullDiffusionIterations
|
||||||
|
let featherAmount: Float = isPreview ? 2.0 : 4.0
|
||||||
|
|
||||||
|
do {
|
||||||
|
// Metal operations need to run on main actor due to GPU resource management
|
||||||
|
let result = try await MainActor.run {
|
||||||
|
let inpainter = try PatchMatchInpainter(
|
||||||
|
device: device,
|
||||||
|
patchRadius: 4,
|
||||||
|
diffusionIterations: iterations
|
||||||
|
)
|
||||||
|
return try inpainter.inpaint(image: image, mask: mask, featherAmount: featherAmount)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
} catch {
|
||||||
|
// If Metal inpainting fails, fall back to Accelerate
|
||||||
|
print("Metal inpainting failed: \(error.localizedDescription), falling back to Accelerate")
|
||||||
|
return try await inpaintWithAccelerate(image: image, mask: mask)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Accelerate Fallback
|
// MARK: - Accelerate Fallback
|
||||||
@@ -126,8 +146,6 @@ actor InpaintEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Simple inpainting: for each masked pixel, average nearby unmasked pixels
|
// Simple inpainting: for each masked pixel, average nearby unmasked pixels
|
||||||
let patchSize = patchRadius * 2 + 1
|
|
||||||
|
|
||||||
for y in 0..<height {
|
for y in 0..<height {
|
||||||
for x in 0..<width {
|
for x in 0..<width {
|
||||||
let maskIndex = y * width + x
|
let maskIndex = y * width + x
|
||||||
|
|||||||
445
CheapRetouch/Services/InpaintEngine/PatchMatch.swift
Normal file
445
CheapRetouch/Services/InpaintEngine/PatchMatch.swift
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
//
|
||||||
|
// PatchMatch.swift
|
||||||
|
// CheapRetouch
|
||||||
|
//
|
||||||
|
// Exemplar-based inpainting algorithm implementation using Metal.
|
||||||
|
//
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Metal
|
||||||
|
import MetalKit
|
||||||
|
import CoreGraphics
|
||||||
|
import simd
|
||||||
|
|
||||||
|
final class PatchMatchInpainter {
|
||||||
|
|
||||||
|
private let device: MTLDevice
|
||||||
|
private let commandQueue: MTLCommandQueue
|
||||||
|
private let library: MTLLibrary
|
||||||
|
|
||||||
|
// Compute pipelines
|
||||||
|
private let dilateMaskPipeline: MTLComputePipelineState
|
||||||
|
private let gaussianBlurPipeline: MTLComputePipelineState
|
||||||
|
private let diffuseInpaintPipeline: MTLComputePipelineState
|
||||||
|
private let edgeAwareBlendPipeline: MTLComputePipelineState
|
||||||
|
|
||||||
|
private let patchRadius: Int
|
||||||
|
private let diffusionIterations: Int
|
||||||
|
|
||||||
|
init(device: MTLDevice, patchRadius: Int = 4, diffusionIterations: Int = 100) throws {
|
||||||
|
self.device = device
|
||||||
|
self.patchRadius = patchRadius
|
||||||
|
self.diffusionIterations = diffusionIterations
|
||||||
|
|
||||||
|
guard let queue = device.makeCommandQueue() else {
|
||||||
|
throw PatchMatchError.commandQueueCreationFailed
|
||||||
|
}
|
||||||
|
self.commandQueue = queue
|
||||||
|
|
||||||
|
guard let library = device.makeDefaultLibrary() else {
|
||||||
|
throw PatchMatchError.libraryNotFound
|
||||||
|
}
|
||||||
|
self.library = library
|
||||||
|
|
||||||
|
// Create compute pipelines
|
||||||
|
guard let dilateFunc = library.makeFunction(name: "dilateMask") else {
|
||||||
|
throw PatchMatchError.functionNotFound("dilateMask")
|
||||||
|
}
|
||||||
|
self.dilateMaskPipeline = try device.makeComputePipelineState(function: dilateFunc)
|
||||||
|
|
||||||
|
guard let blurFunc = library.makeFunction(name: "gaussianBlur") else {
|
||||||
|
throw PatchMatchError.functionNotFound("gaussianBlur")
|
||||||
|
}
|
||||||
|
self.gaussianBlurPipeline = try device.makeComputePipelineState(function: blurFunc)
|
||||||
|
|
||||||
|
guard let diffuseFunc = library.makeFunction(name: "diffuseInpaint") else {
|
||||||
|
throw PatchMatchError.functionNotFound("diffuseInpaint")
|
||||||
|
}
|
||||||
|
self.diffuseInpaintPipeline = try device.makeComputePipelineState(function: diffuseFunc)
|
||||||
|
|
||||||
|
guard let blendFunc = library.makeFunction(name: "edgeAwareBlend") else {
|
||||||
|
throw PatchMatchError.functionNotFound("edgeAwareBlend")
|
||||||
|
}
|
||||||
|
self.edgeAwareBlendPipeline = try device.makeComputePipelineState(function: blendFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func inpaint(image: CGImage, mask: CGImage, featherAmount: Float = 4.0) throws -> CGImage {
|
||||||
|
let width = image.width
|
||||||
|
let height = image.height
|
||||||
|
|
||||||
|
// Create textures
|
||||||
|
let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
|
||||||
|
pixelFormat: .rgba8Unorm,
|
||||||
|
width: width,
|
||||||
|
height: height,
|
||||||
|
mipmapped: false
|
||||||
|
)
|
||||||
|
textureDescriptor.usage = [.shaderRead, .shaderWrite]
|
||||||
|
|
||||||
|
guard let sourceTexture = device.makeTexture(descriptor: textureDescriptor),
|
||||||
|
let resultTexture = device.makeTexture(descriptor: textureDescriptor),
|
||||||
|
let tempTexture = device.makeTexture(descriptor: textureDescriptor) else {
|
||||||
|
throw PatchMatchError.textureCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask texture (single channel)
|
||||||
|
let maskDescriptor = MTLTextureDescriptor.texture2DDescriptor(
|
||||||
|
pixelFormat: .r8Unorm,
|
||||||
|
width: width,
|
||||||
|
height: height,
|
||||||
|
mipmapped: false
|
||||||
|
)
|
||||||
|
maskDescriptor.usage = [.shaderRead, .shaderWrite]
|
||||||
|
|
||||||
|
guard let maskTexture = device.makeTexture(descriptor: maskDescriptor),
|
||||||
|
let dilatedMaskTexture = device.makeTexture(descriptor: maskDescriptor),
|
||||||
|
let featheredMaskTexture = device.makeTexture(descriptor: maskDescriptor) else {
|
||||||
|
throw PatchMatchError.textureCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load image data into texture
|
||||||
|
try loadCGImage(image, into: sourceTexture)
|
||||||
|
try loadCGImageGrayscale(mask, into: maskTexture)
|
||||||
|
|
||||||
|
// Copy source to result for initial state
|
||||||
|
try copyTexture(from: sourceTexture, to: resultTexture)
|
||||||
|
|
||||||
|
guard let commandBuffer = commandQueue.makeCommandBuffer() else {
|
||||||
|
throw PatchMatchError.commandBufferCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: Dilate mask
|
||||||
|
encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: patchRadius)
|
||||||
|
|
||||||
|
// Step 2: Feather mask
|
||||||
|
encodeGaussianBlur(commandBuffer: commandBuffer, input: dilatedMaskTexture, output: featheredMaskTexture, radius: Int(featherAmount))
|
||||||
|
|
||||||
|
commandBuffer.commit()
|
||||||
|
commandBuffer.waitUntilCompleted()
|
||||||
|
|
||||||
|
// Step 3: Diffusion-based inpainting (multiple iterations)
|
||||||
|
for _ in 0..<diffusionIterations {
|
||||||
|
guard let iterBuffer = commandQueue.makeCommandBuffer() else {
|
||||||
|
throw PatchMatchError.commandBufferCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
encodeDiffuseInpaint(commandBuffer: iterBuffer, input: resultTexture, mask: dilatedMaskTexture, output: tempTexture)
|
||||||
|
|
||||||
|
iterBuffer.commit()
|
||||||
|
iterBuffer.waitUntilCompleted()
|
||||||
|
|
||||||
|
try copyTexture(from: tempTexture, to: resultTexture)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Edge-aware blending
|
||||||
|
guard let finalBuffer = commandQueue.makeCommandBuffer() else {
|
||||||
|
throw PatchMatchError.commandBufferCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
encodeEdgeAwareBlend(
|
||||||
|
commandBuffer: finalBuffer,
|
||||||
|
inpainted: resultTexture,
|
||||||
|
original: sourceTexture,
|
||||||
|
mask: featheredMaskTexture,
|
||||||
|
output: tempTexture,
|
||||||
|
blendRadius: featherAmount
|
||||||
|
)
|
||||||
|
|
||||||
|
finalBuffer.commit()
|
||||||
|
finalBuffer.waitUntilCompleted()
|
||||||
|
|
||||||
|
// Convert result to CGImage
|
||||||
|
return try extractCGImage(from: tempTexture)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Shader Encoding
|
||||||
|
|
||||||
|
private func encodeDilateMask(commandBuffer: MTLCommandBuffer, input: MTLTexture, output: MTLTexture, radius: Int) {
|
||||||
|
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
|
||||||
|
|
||||||
|
encoder.setComputePipelineState(dilateMaskPipeline)
|
||||||
|
encoder.setTexture(input, index: 0)
|
||||||
|
encoder.setTexture(output, index: 1)
|
||||||
|
|
||||||
|
var radiusValue = Int32(radius)
|
||||||
|
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 0)
|
||||||
|
|
||||||
|
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||||
|
let threadgroups = MTLSize(
|
||||||
|
width: (input.width + 15) / 16,
|
||||||
|
height: (input.height + 15) / 16,
|
||||||
|
depth: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||||
|
encoder.endEncoding()
|
||||||
|
}
|
||||||
|
|
||||||
|
private func encodeGaussianBlur(commandBuffer: MTLCommandBuffer, input: MTLTexture, output: MTLTexture, radius: Int) {
|
||||||
|
// Generate Gaussian weights
|
||||||
|
let sigma = Float(radius) / 3.0
|
||||||
|
var weights: [Float] = []
|
||||||
|
for i in 0...radius {
|
||||||
|
let weight = exp(-Float(i * i) / (2 * sigma * sigma))
|
||||||
|
weights.append(weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
guard let weightsBuffer = device.makeBuffer(bytes: weights, length: weights.count * MemoryLayout<Float>.size, options: []) else { return }
|
||||||
|
|
||||||
|
// Create temp texture for two-pass blur
|
||||||
|
let tempDescriptor = MTLTextureDescriptor.texture2DDescriptor(
|
||||||
|
pixelFormat: input.pixelFormat,
|
||||||
|
width: input.width,
|
||||||
|
height: input.height,
|
||||||
|
mipmapped: false
|
||||||
|
)
|
||||||
|
tempDescriptor.usage = [.shaderRead, .shaderWrite]
|
||||||
|
|
||||||
|
guard let tempTexture = device.makeTexture(descriptor: tempDescriptor) else { return }
|
||||||
|
|
||||||
|
// Horizontal pass
|
||||||
|
if let encoder = commandBuffer.makeComputeCommandEncoder() {
|
||||||
|
encoder.setComputePipelineState(gaussianBlurPipeline)
|
||||||
|
encoder.setTexture(input, index: 0)
|
||||||
|
encoder.setTexture(tempTexture, index: 1)
|
||||||
|
encoder.setBuffer(weightsBuffer, offset: 0, index: 0)
|
||||||
|
var radiusValue = Int32(radius)
|
||||||
|
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 1)
|
||||||
|
var horizontal: Bool = true
|
||||||
|
encoder.setBytes(&horizontal, length: MemoryLayout<Bool>.size, index: 2)
|
||||||
|
|
||||||
|
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||||
|
let threadgroups = MTLSize(
|
||||||
|
width: (input.width + 15) / 16,
|
||||||
|
height: (input.height + 15) / 16,
|
||||||
|
depth: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||||
|
encoder.endEncoding()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vertical pass
|
||||||
|
if let encoder = commandBuffer.makeComputeCommandEncoder() {
|
||||||
|
encoder.setComputePipelineState(gaussianBlurPipeline)
|
||||||
|
encoder.setTexture(tempTexture, index: 0)
|
||||||
|
encoder.setTexture(output, index: 1)
|
||||||
|
encoder.setBuffer(weightsBuffer, offset: 0, index: 0)
|
||||||
|
var radiusValue = Int32(radius)
|
||||||
|
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 1)
|
||||||
|
var horizontal: Bool = false
|
||||||
|
encoder.setBytes(&horizontal, length: MemoryLayout<Bool>.size, index: 2)
|
||||||
|
|
||||||
|
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||||
|
let threadgroups = MTLSize(
|
||||||
|
width: (input.width + 15) / 16,
|
||||||
|
height: (input.height + 15) / 16,
|
||||||
|
depth: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||||
|
encoder.endEncoding()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func encodeDiffuseInpaint(commandBuffer: MTLCommandBuffer, input: MTLTexture, mask: MTLTexture, output: MTLTexture) {
|
||||||
|
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
|
||||||
|
|
||||||
|
encoder.setComputePipelineState(diffuseInpaintPipeline)
|
||||||
|
encoder.setTexture(input, index: 0)
|
||||||
|
encoder.setTexture(mask, index: 1)
|
||||||
|
encoder.setTexture(output, index: 2)
|
||||||
|
|
||||||
|
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||||
|
let threadgroups = MTLSize(
|
||||||
|
width: (input.width + 15) / 16,
|
||||||
|
height: (input.height + 15) / 16,
|
||||||
|
depth: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||||
|
encoder.endEncoding()
|
||||||
|
}
|
||||||
|
|
||||||
|
private func encodeEdgeAwareBlend(
|
||||||
|
commandBuffer: MTLCommandBuffer,
|
||||||
|
inpainted: MTLTexture,
|
||||||
|
original: MTLTexture,
|
||||||
|
mask: MTLTexture,
|
||||||
|
output: MTLTexture,
|
||||||
|
blendRadius: Float
|
||||||
|
) {
|
||||||
|
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
|
||||||
|
|
||||||
|
encoder.setComputePipelineState(edgeAwareBlendPipeline)
|
||||||
|
encoder.setTexture(inpainted, index: 0)
|
||||||
|
encoder.setTexture(original, index: 1)
|
||||||
|
encoder.setTexture(mask, index: 2)
|
||||||
|
encoder.setTexture(output, index: 3)
|
||||||
|
|
||||||
|
var radius = blendRadius
|
||||||
|
encoder.setBytes(&radius, length: MemoryLayout<Float>.size, index: 0)
|
||||||
|
|
||||||
|
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||||
|
let threadgroups = MTLSize(
|
||||||
|
width: (inpainted.width + 15) / 16,
|
||||||
|
height: (inpainted.height + 15) / 16,
|
||||||
|
depth: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||||
|
encoder.endEncoding()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Texture Utilities
|
||||||
|
|
||||||
|
private func loadCGImage(_ cgImage: CGImage, into texture: MTLTexture) throws {
|
||||||
|
let width = cgImage.width
|
||||||
|
let height = cgImage.height
|
||||||
|
|
||||||
|
guard let context = CGContext(
|
||||||
|
data: nil,
|
||||||
|
width: width,
|
||||||
|
height: height,
|
||||||
|
bitsPerComponent: 8,
|
||||||
|
bytesPerRow: width * 4,
|
||||||
|
space: CGColorSpaceCreateDeviceRGB(),
|
||||||
|
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
|
||||||
|
) else {
|
||||||
|
throw PatchMatchError.contextCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
|
||||||
|
|
||||||
|
guard let data = context.data else {
|
||||||
|
throw PatchMatchError.dataExtractionFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
texture.replace(
|
||||||
|
region: MTLRegionMake2D(0, 0, width, height),
|
||||||
|
mipmapLevel: 0,
|
||||||
|
withBytes: data,
|
||||||
|
bytesPerRow: width * 4
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func loadCGImageGrayscale(_ cgImage: CGImage, into texture: MTLTexture) throws {
|
||||||
|
let width = cgImage.width
|
||||||
|
let height = cgImage.height
|
||||||
|
|
||||||
|
guard let context = CGContext(
|
||||||
|
data: nil,
|
||||||
|
width: width,
|
||||||
|
height: height,
|
||||||
|
bitsPerComponent: 8,
|
||||||
|
bytesPerRow: width,
|
||||||
|
space: CGColorSpaceCreateDeviceGray(),
|
||||||
|
bitmapInfo: CGImageAlphaInfo.none.rawValue
|
||||||
|
) else {
|
||||||
|
throw PatchMatchError.contextCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
|
||||||
|
|
||||||
|
guard let data = context.data else {
|
||||||
|
throw PatchMatchError.dataExtractionFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
texture.replace(
|
||||||
|
region: MTLRegionMake2D(0, 0, width, height),
|
||||||
|
mipmapLevel: 0,
|
||||||
|
withBytes: data,
|
||||||
|
bytesPerRow: width
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func copyTexture(from source: MTLTexture, to destination: MTLTexture) throws {
|
||||||
|
guard let commandBuffer = commandQueue.makeCommandBuffer(),
|
||||||
|
let blitEncoder = commandBuffer.makeBlitCommandEncoder() else {
|
||||||
|
throw PatchMatchError.commandBufferCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
blitEncoder.copy(
|
||||||
|
from: source,
|
||||||
|
sourceSlice: 0,
|
||||||
|
sourceLevel: 0,
|
||||||
|
sourceOrigin: MTLOrigin(x: 0, y: 0, z: 0),
|
||||||
|
sourceSize: MTLSize(width: source.width, height: source.height, depth: 1),
|
||||||
|
to: destination,
|
||||||
|
destinationSlice: 0,
|
||||||
|
destinationLevel: 0,
|
||||||
|
destinationOrigin: MTLOrigin(x: 0, y: 0, z: 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
blitEncoder.endEncoding()
|
||||||
|
commandBuffer.commit()
|
||||||
|
commandBuffer.waitUntilCompleted()
|
||||||
|
}
|
||||||
|
|
||||||
|
private func extractCGImage(from texture: MTLTexture) throws -> CGImage {
|
||||||
|
let width = texture.width
|
||||||
|
let height = texture.height
|
||||||
|
let bytesPerRow = width * 4
|
||||||
|
|
||||||
|
var pixelData = [UInt8](repeating: 0, count: bytesPerRow * height)
|
||||||
|
|
||||||
|
texture.getBytes(
|
||||||
|
&pixelData,
|
||||||
|
bytesPerRow: bytesPerRow,
|
||||||
|
from: MTLRegionMake2D(0, 0, width, height),
|
||||||
|
mipmapLevel: 0
|
||||||
|
)
|
||||||
|
|
||||||
|
guard let context = CGContext(
|
||||||
|
data: &pixelData,
|
||||||
|
width: width,
|
||||||
|
height: height,
|
||||||
|
bitsPerComponent: 8,
|
||||||
|
bytesPerRow: bytesPerRow,
|
||||||
|
space: CGColorSpaceCreateDeviceRGB(),
|
||||||
|
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
|
||||||
|
) else {
|
||||||
|
throw PatchMatchError.contextCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
guard let cgImage = context.makeImage() else {
|
||||||
|
throw PatchMatchError.imageCreationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
return cgImage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Errors
|
||||||
|
|
||||||
|
enum PatchMatchError: LocalizedError {
|
||||||
|
case commandQueueCreationFailed
|
||||||
|
case libraryNotFound
|
||||||
|
case functionNotFound(String)
|
||||||
|
case textureCreationFailed
|
||||||
|
case commandBufferCreationFailed
|
||||||
|
case contextCreationFailed
|
||||||
|
case dataExtractionFailed
|
||||||
|
case imageCreationFailed
|
||||||
|
|
||||||
|
var errorDescription: String? {
|
||||||
|
switch self {
|
||||||
|
case .commandQueueCreationFailed:
|
||||||
|
return "Failed to create Metal command queue"
|
||||||
|
case .libraryNotFound:
|
||||||
|
return "Metal shader library not found"
|
||||||
|
case .functionNotFound(let name):
|
||||||
|
return "Metal function '\(name)' not found"
|
||||||
|
case .textureCreationFailed:
|
||||||
|
return "Failed to create Metal texture"
|
||||||
|
case .commandBufferCreationFailed:
|
||||||
|
return "Failed to create command buffer"
|
||||||
|
case .contextCreationFailed:
|
||||||
|
return "Failed to create graphics context"
|
||||||
|
case .dataExtractionFailed:
|
||||||
|
return "Failed to extract image data"
|
||||||
|
case .imageCreationFailed:
|
||||||
|
return "Failed to create output image"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
267
CheapRetouch/Services/InpaintEngine/Shaders.metal
Normal file
267
CheapRetouch/Services/InpaintEngine/Shaders.metal
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
//
|
||||||
|
// Shaders.metal
|
||||||
|
// CheapRetouch
|
||||||
|
//
|
||||||
|
// Metal shaders for exemplar-based inpainting.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
// MARK: - Common Types
|
||||||
|
|
||||||
|
struct VertexOut {
|
||||||
|
float4 position [[position]];
|
||||||
|
float2 texCoord;
|
||||||
|
};
|
||||||
|
|
||||||
|
// MARK: - Mask Operations
|
||||||
|
|
||||||
|
// Dilate mask by expanding white regions
|
||||||
|
kernel void dilateMask(
|
||||||
|
texture2d<float, access::read> inMask [[texture(0)]],
|
||||||
|
texture2d<float, access::write> outMask [[texture(1)]],
|
||||||
|
constant int &radius [[buffer(0)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
if (gid.x >= outMask.get_width() || gid.y >= outMask.get_height()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float maxValue = 0.0;
|
||||||
|
|
||||||
|
for (int dy = -radius; dy <= radius; dy++) {
|
||||||
|
for (int dx = -radius; dx <= radius; dx++) {
|
||||||
|
int2 samplePos = int2(gid) + int2(dx, dy);
|
||||||
|
|
||||||
|
if (samplePos.x >= 0 && samplePos.x < int(inMask.get_width()) &&
|
||||||
|
samplePos.y >= 0 && samplePos.y < int(inMask.get_height())) {
|
||||||
|
float value = inMask.read(uint2(samplePos)).r;
|
||||||
|
maxValue = max(maxValue, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outMask.write(float4(maxValue, maxValue, maxValue, 1.0), gid);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gaussian blur for mask feathering
|
||||||
|
kernel void gaussianBlur(
|
||||||
|
texture2d<float, access::read> inTexture [[texture(0)]],
|
||||||
|
texture2d<float, access::write> outTexture [[texture(1)]],
|
||||||
|
constant float *weights [[buffer(0)]],
|
||||||
|
constant int &radius [[buffer(1)]],
|
||||||
|
constant bool &horizontal [[buffer(2)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float4 sum = float4(0.0);
|
||||||
|
float weightSum = 0.0;
|
||||||
|
|
||||||
|
for (int i = -radius; i <= radius; i++) {
|
||||||
|
int2 offset = horizontal ? int2(i, 0) : int2(0, i);
|
||||||
|
int2 samplePos = int2(gid) + offset;
|
||||||
|
|
||||||
|
if (samplePos.x >= 0 && samplePos.x < int(inTexture.get_width()) &&
|
||||||
|
samplePos.y >= 0 && samplePos.y < int(inTexture.get_height())) {
|
||||||
|
float weight = weights[abs(i)];
|
||||||
|
sum += inTexture.read(uint2(samplePos)) * weight;
|
||||||
|
weightSum += weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outTexture.write(sum / weightSum, gid);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Patch Matching
|
||||||
|
|
||||||
|
// Compute sum of squared differences between two patches
|
||||||
|
kernel void computePatchSSD(
|
||||||
|
texture2d<float, access::read> sourceTexture [[texture(0)]],
|
||||||
|
texture2d<float, access::read> maskTexture [[texture(1)]],
|
||||||
|
texture2d<float, access::write> ssdTexture [[texture(2)]],
|
||||||
|
constant int2 &targetPos [[buffer(0)]],
|
||||||
|
constant int &patchRadius [[buffer(1)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
int width = sourceTexture.get_width();
|
||||||
|
int height = sourceTexture.get_height();
|
||||||
|
|
||||||
|
if (gid.x >= uint(width) || gid.y >= uint(height)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if this position is in the mask (unknown region)
|
||||||
|
float maskValue = maskTexture.read(gid).r;
|
||||||
|
if (maskValue > 0.5) {
|
||||||
|
ssdTexture.write(float4(1e10), gid);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float ssd = 0.0;
|
||||||
|
int validPixels = 0;
|
||||||
|
|
||||||
|
for (int dy = -patchRadius; dy <= patchRadius; dy++) {
|
||||||
|
for (int dx = -patchRadius; dx <= patchRadius; dx++) {
|
||||||
|
int2 sourcePos = int2(gid) + int2(dx, dy);
|
||||||
|
int2 targetSamplePos = targetPos + int2(dx, dy);
|
||||||
|
|
||||||
|
// Check bounds
|
||||||
|
if (sourcePos.x < 0 || sourcePos.x >= width ||
|
||||||
|
sourcePos.y < 0 || sourcePos.y >= height ||
|
||||||
|
targetSamplePos.x < 0 || targetSamplePos.x >= width ||
|
||||||
|
targetSamplePos.y < 0 || targetSamplePos.y >= height) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only compare known pixels in source
|
||||||
|
float sourceMask = maskTexture.read(uint2(sourcePos)).r;
|
||||||
|
if (sourceMask > 0.5) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
float4 sourceColor = sourceTexture.read(uint2(sourcePos));
|
||||||
|
float4 targetColor = sourceTexture.read(uint2(targetSamplePos));
|
||||||
|
|
||||||
|
float3 diff = sourceColor.rgb - targetColor.rgb;
|
||||||
|
ssd += dot(diff, diff);
|
||||||
|
validPixels++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize by valid pixel count
|
||||||
|
if (validPixels > 0) {
|
||||||
|
ssd /= float(validPixels);
|
||||||
|
} else {
|
||||||
|
ssd = 1e10;
|
||||||
|
}
|
||||||
|
|
||||||
|
ssdTexture.write(float4(ssd), gid);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Inpainting
|
||||||
|
|
||||||
|
// Simple diffusion-based inpainting step
|
||||||
|
kernel void diffuseInpaint(
|
||||||
|
texture2d<float, access::read> inTexture [[texture(0)]],
|
||||||
|
texture2d<float, access::read> maskTexture [[texture(1)]],
|
||||||
|
texture2d<float, access::write> outTexture [[texture(2)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
int width = inTexture.get_width();
|
||||||
|
int height = inTexture.get_height();
|
||||||
|
|
||||||
|
if (gid.x >= uint(width) || gid.y >= uint(height)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float maskValue = maskTexture.read(gid).r;
|
||||||
|
|
||||||
|
// If not in mask, copy original
|
||||||
|
if (maskValue < 0.5) {
|
||||||
|
outTexture.write(inTexture.read(gid), gid);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diffuse from neighbors
|
||||||
|
float4 sum = float4(0.0);
|
||||||
|
float weightSum = 0.0;
|
||||||
|
|
||||||
|
int offsets[4][2] = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
int2 neighborPos = int2(gid) + int2(offsets[i][0], offsets[i][1]);
|
||||||
|
|
||||||
|
if (neighborPos.x >= 0 && neighborPos.x < width &&
|
||||||
|
neighborPos.y >= 0 && neighborPos.y < height) {
|
||||||
|
float neighborMask = maskTexture.read(uint2(neighborPos)).r;
|
||||||
|
float weight = (neighborMask < 0.5) ? 2.0 : 1.0;
|
||||||
|
sum += inTexture.read(uint2(neighborPos)) * weight;
|
||||||
|
weightSum += weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weightSum > 0.0) {
|
||||||
|
outTexture.write(sum / weightSum, gid);
|
||||||
|
} else {
|
||||||
|
outTexture.write(inTexture.read(gid), gid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy patch from source to target
|
||||||
|
kernel void copyPatch(
|
||||||
|
texture2d<float, access::read> sourceTexture [[texture(0)]],
|
||||||
|
texture2d<float, access::read> maskTexture [[texture(1)]],
|
||||||
|
texture2d<float, access::read_write> targetTexture [[texture(2)]],
|
||||||
|
constant int2 &sourcePos [[buffer(0)]],
|
||||||
|
constant int2 &targetPos [[buffer(1)]],
|
||||||
|
constant int &patchRadius [[buffer(2)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
int patchSize = patchRadius * 2 + 1;
|
||||||
|
|
||||||
|
if (gid.x >= uint(patchSize) || gid.y >= uint(patchSize)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int2 offset = int2(gid) - int2(patchRadius);
|
||||||
|
int2 srcCoord = sourcePos + offset;
|
||||||
|
int2 dstCoord = targetPos + offset;
|
||||||
|
|
||||||
|
int width = sourceTexture.get_width();
|
||||||
|
int height = sourceTexture.get_height();
|
||||||
|
|
||||||
|
// Check bounds
|
||||||
|
if (srcCoord.x < 0 || srcCoord.x >= width ||
|
||||||
|
srcCoord.y < 0 || srcCoord.y >= height ||
|
||||||
|
dstCoord.x < 0 || dstCoord.x >= width ||
|
||||||
|
dstCoord.y < 0 || dstCoord.y >= height) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only write to masked (unknown) pixels
|
||||||
|
float maskValue = maskTexture.read(uint2(dstCoord)).r;
|
||||||
|
if (maskValue > 0.5) {
|
||||||
|
float4 sourceColor = sourceTexture.read(uint2(srcCoord));
|
||||||
|
targetTexture.write(sourceColor, uint2(dstCoord));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Edge-aware blending
|
||||||
|
|
||||||
|
kernel void edgeAwareBlend(
|
||||||
|
texture2d<float, access::read> inTexture [[texture(0)]],
|
||||||
|
texture2d<float, access::read> originalTexture [[texture(1)]],
|
||||||
|
texture2d<float, access::read> maskTexture [[texture(2)]],
|
||||||
|
texture2d<float, access::write> outTexture [[texture(3)]],
|
||||||
|
constant float &blendRadius [[buffer(0)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
int width = inTexture.get_width();
|
||||||
|
int height = inTexture.get_height();
|
||||||
|
|
||||||
|
if (gid.x >= uint(width) || gid.y >= uint(height)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float maskValue = maskTexture.read(gid).r;
|
||||||
|
float4 inpaintedColor = inTexture.read(gid);
|
||||||
|
float4 originalColor = originalTexture.read(gid);
|
||||||
|
|
||||||
|
// If far from mask boundary, use appropriate color directly
|
||||||
|
if (maskValue < 0.1) {
|
||||||
|
outTexture.write(originalColor, gid);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (maskValue > 0.9) {
|
||||||
|
outTexture.write(inpaintedColor, gid);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Blend in the transition zone
|
||||||
|
float4 result = mix(originalColor, inpaintedColor, maskValue);
|
||||||
|
outTexture.write(result, gid);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user