Add Metal-based inpainting and brush selection tools
Implement GPU-accelerated inpainting using Metal compute shaders: - Add Shaders.metal with dilateMask, gaussianBlur, diffuseInpaint, edgeAwareBlend kernels - Add PatchMatchInpainter class for exemplar-based inpainting - Update InpaintEngine to use Metal with Accelerate fallback - Add BrushCanvasView for manual brush-based mask painting - Add LineBrushView for wire removal line drawing - Update CanvasView to integrate brush canvas overlay Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -257,9 +257,9 @@
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
|
||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
||||
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
|
||||
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
|
||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
||||
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
||||
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
||||
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
||||
@@ -297,9 +297,9 @@
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_KEY_CFBundleDisplayName = CheapRetouch;
|
||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
||||
INFOPLIST_KEY_NSPhotoLibraryAddUsageDescription = "CheapRetouch needs permission to save edited photos to your library.";
|
||||
INFOPLIST_KEY_NSPhotoLibraryUsageDescription = "CheapRetouch needs access to your photos to edit and remove unwanted elements from them.";
|
||||
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.utilities";
|
||||
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
||||
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
||||
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
||||
|
||||
331
CheapRetouch/Features/Editor/BrushCanvasView.swift
Normal file
331
CheapRetouch/Features/Editor/BrushCanvasView.swift
Normal file
@@ -0,0 +1,331 @@
|
||||
//
|
||||
// BrushCanvasView.swift
|
||||
// CheapRetouch
|
||||
//
|
||||
// Canvas overlay for brush-based manual selection.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
import UIKit
|
||||
|
||||
struct BrushCanvasView: View {
|
||||
@Bindable var viewModel: EditorViewModel
|
||||
let imageSize: CGSize
|
||||
let displayedImageFrame: CGRect
|
||||
|
||||
@State private var currentStroke: [CGPoint] = []
|
||||
@State private var allStrokes: [[CGPoint]] = []
|
||||
@State private var isErasing = false
|
||||
|
||||
var body: some View {
|
||||
Canvas { context, size in
|
||||
// Draw all completed strokes
|
||||
for stroke in allStrokes {
|
||||
drawStroke(stroke, in: &context, color: isErasing ? .black : .white)
|
||||
}
|
||||
|
||||
// Draw current stroke
|
||||
if !currentStroke.isEmpty {
|
||||
drawStroke(currentStroke, in: &context, color: isErasing ? .black : .white)
|
||||
}
|
||||
}
|
||||
.gesture(drawingGesture)
|
||||
.overlay(alignment: .bottom) {
|
||||
brushControls
|
||||
}
|
||||
}
|
||||
|
||||
private func drawStroke(_ points: [CGPoint], in context: inout GraphicsContext, color: Color) {
|
||||
guard points.count >= 2 else { return }
|
||||
|
||||
var path = Path()
|
||||
path.move(to: points[0])
|
||||
|
||||
if points.count == 2 {
|
||||
path.addLine(to: points[1])
|
||||
} else {
|
||||
for i in 1..<points.count {
|
||||
let mid = CGPoint(
|
||||
x: (points[i-1].x + points[i].x) / 2,
|
||||
y: (points[i-1].y + points[i].y) / 2
|
||||
)
|
||||
path.addQuadCurve(to: mid, control: points[i-1])
|
||||
}
|
||||
path.addLine(to: points[points.count - 1])
|
||||
}
|
||||
|
||||
context.stroke(
|
||||
path,
|
||||
with: .color(color.opacity(0.7)),
|
||||
style: StrokeStyle(
|
||||
lineWidth: viewModel.brushSize,
|
||||
lineCap: .round,
|
||||
lineJoin: .round
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private var drawingGesture: some Gesture {
|
||||
DragGesture(minimumDistance: 0)
|
||||
.onChanged { value in
|
||||
let point = value.location
|
||||
// Only add points within the image bounds
|
||||
if displayedImageFrame.contains(point) {
|
||||
currentStroke.append(point)
|
||||
}
|
||||
}
|
||||
.onEnded { _ in
|
||||
if !currentStroke.isEmpty {
|
||||
allStrokes.append(currentStroke)
|
||||
currentStroke = []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var brushControls: some View {
|
||||
HStack(spacing: 16) {
|
||||
// Erase toggle
|
||||
Button {
|
||||
isErasing.toggle()
|
||||
} label: {
|
||||
Image(systemName: isErasing ? "eraser.fill" : "eraser")
|
||||
.font(.title2)
|
||||
.frame(width: 44, height: 44)
|
||||
.background(isErasing ? Color.accentColor : Color.clear)
|
||||
.clipShape(Circle())
|
||||
}
|
||||
.accessibilityLabel(isErasing ? "Eraser active" : "Switch to eraser")
|
||||
|
||||
// Clear all
|
||||
Button {
|
||||
allStrokes.removeAll()
|
||||
currentStroke.removeAll()
|
||||
} label: {
|
||||
Image(systemName: "trash")
|
||||
.font(.title2)
|
||||
.frame(width: 44, height: 44)
|
||||
}
|
||||
.disabled(allStrokes.isEmpty)
|
||||
.accessibilityLabel("Clear all strokes")
|
||||
|
||||
Spacer()
|
||||
|
||||
// Done button
|
||||
Button {
|
||||
Task {
|
||||
await applyBrushMask()
|
||||
}
|
||||
} label: {
|
||||
Text("Done")
|
||||
.font(.headline)
|
||||
.padding(.horizontal, 20)
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
.buttonStyle(.borderedProminent)
|
||||
.disabled(allStrokes.isEmpty)
|
||||
}
|
||||
.padding()
|
||||
.background(.ultraThinMaterial)
|
||||
}
|
||||
|
||||
private func applyBrushMask() async {
|
||||
guard !allStrokes.isEmpty else { return }
|
||||
|
||||
// Create mask image from strokes
|
||||
let renderer = UIGraphicsImageRenderer(size: imageSize)
|
||||
let maskImage = renderer.image { ctx in
|
||||
// Fill with black (not masked)
|
||||
UIColor.black.setFill()
|
||||
ctx.fill(CGRect(origin: .zero, size: imageSize))
|
||||
|
||||
// Draw strokes in white (masked areas)
|
||||
UIColor.white.setStroke()
|
||||
|
||||
let scaleX = imageSize.width / displayedImageFrame.width
|
||||
let scaleY = imageSize.height / displayedImageFrame.height
|
||||
|
||||
for stroke in allStrokes {
|
||||
guard stroke.count >= 2 else { continue }
|
||||
|
||||
let path = UIBezierPath()
|
||||
let firstPoint = CGPoint(
|
||||
x: (stroke[0].x - displayedImageFrame.minX) * scaleX,
|
||||
y: (stroke[0].y - displayedImageFrame.minY) * scaleY
|
||||
)
|
||||
path.move(to: firstPoint)
|
||||
|
||||
for i in 1..<stroke.count {
|
||||
let point = CGPoint(
|
||||
x: (stroke[i].x - displayedImageFrame.minX) * scaleX,
|
||||
y: (stroke[i].y - displayedImageFrame.minY) * scaleY
|
||||
)
|
||||
path.addLine(to: point)
|
||||
}
|
||||
|
||||
path.lineWidth = viewModel.brushSize * scaleX
|
||||
path.lineCapStyle = .round
|
||||
path.lineJoinStyle = .round
|
||||
path.stroke()
|
||||
}
|
||||
}
|
||||
|
||||
if let cgImage = maskImage.cgImage {
|
||||
await viewModel.applyBrushMask(cgImage)
|
||||
}
|
||||
|
||||
// Clear strokes after applying
|
||||
allStrokes.removeAll()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Line Brush View for Wire Tool
|
||||
|
||||
struct LineBrushView: View {
|
||||
@Bindable var viewModel: EditorViewModel
|
||||
let imageSize: CGSize
|
||||
let displayedImageFrame: CGRect
|
||||
|
||||
@State private var linePoints: [CGPoint] = []
|
||||
|
||||
var body: some View {
|
||||
Canvas { context, size in
|
||||
guard linePoints.count >= 2 else { return }
|
||||
|
||||
var path = Path()
|
||||
path.move(to: linePoints[0])
|
||||
|
||||
for point in linePoints.dropFirst() {
|
||||
path.addLine(to: point)
|
||||
}
|
||||
|
||||
context.stroke(
|
||||
path,
|
||||
with: .color(.white.opacity(0.7)),
|
||||
style: StrokeStyle(
|
||||
lineWidth: viewModel.wireWidth,
|
||||
lineCap: .round,
|
||||
lineJoin: .round
|
||||
)
|
||||
)
|
||||
}
|
||||
.gesture(lineDrawingGesture)
|
||||
.overlay(alignment: .bottom) {
|
||||
lineControls
|
||||
}
|
||||
}
|
||||
|
||||
private var lineDrawingGesture: some Gesture {
|
||||
DragGesture(minimumDistance: 0)
|
||||
.onChanged { value in
|
||||
let point = value.location
|
||||
if displayedImageFrame.contains(point) {
|
||||
// For line brush, we sample less frequently for smoother lines
|
||||
if linePoints.isEmpty || distance(from: linePoints.last!, to: point) > 5 {
|
||||
linePoints.append(point)
|
||||
}
|
||||
}
|
||||
}
|
||||
.onEnded { _ in
|
||||
// Line complete, ready to apply
|
||||
}
|
||||
}
|
||||
|
||||
private var lineControls: some View {
|
||||
HStack(spacing: 16) {
|
||||
// Clear
|
||||
Button {
|
||||
linePoints.removeAll()
|
||||
} label: {
|
||||
Image(systemName: "trash")
|
||||
.font(.title2)
|
||||
.frame(width: 44, height: 44)
|
||||
}
|
||||
.disabled(linePoints.isEmpty)
|
||||
.accessibilityLabel("Clear line")
|
||||
|
||||
Spacer()
|
||||
|
||||
// Cancel
|
||||
Button {
|
||||
linePoints.removeAll()
|
||||
viewModel.selectedTool = .wire
|
||||
} label: {
|
||||
Text("Cancel")
|
||||
.font(.headline)
|
||||
.padding(.horizontal, 16)
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
|
||||
// Done button
|
||||
Button {
|
||||
Task {
|
||||
await applyLineMask()
|
||||
}
|
||||
} label: {
|
||||
Text("Remove Line")
|
||||
.font(.headline)
|
||||
.padding(.horizontal, 16)
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
.buttonStyle(.borderedProminent)
|
||||
.disabled(linePoints.count < 2)
|
||||
}
|
||||
.padding()
|
||||
.background(.ultraThinMaterial)
|
||||
}
|
||||
|
||||
private func applyLineMask() async {
|
||||
guard linePoints.count >= 2 else { return }
|
||||
|
||||
let renderer = UIGraphicsImageRenderer(size: imageSize)
|
||||
let maskImage = renderer.image { ctx in
|
||||
UIColor.black.setFill()
|
||||
ctx.fill(CGRect(origin: .zero, size: imageSize))
|
||||
|
||||
UIColor.white.setStroke()
|
||||
|
||||
let scaleX = imageSize.width / displayedImageFrame.width
|
||||
let scaleY = imageSize.height / displayedImageFrame.height
|
||||
|
||||
let path = UIBezierPath()
|
||||
let firstPoint = CGPoint(
|
||||
x: (linePoints[0].x - displayedImageFrame.minX) * scaleX,
|
||||
y: (linePoints[0].y - displayedImageFrame.minY) * scaleY
|
||||
)
|
||||
path.move(to: firstPoint)
|
||||
|
||||
for point in linePoints.dropFirst() {
|
||||
let scaledPoint = CGPoint(
|
||||
x: (point.x - displayedImageFrame.minX) * scaleX,
|
||||
y: (point.y - displayedImageFrame.minY) * scaleY
|
||||
)
|
||||
path.addLine(to: scaledPoint)
|
||||
}
|
||||
|
||||
path.lineWidth = viewModel.wireWidth * scaleX
|
||||
path.lineCapStyle = .round
|
||||
path.lineJoinStyle = .round
|
||||
path.stroke()
|
||||
}
|
||||
|
||||
if let cgImage = maskImage.cgImage {
|
||||
await viewModel.applyBrushMask(cgImage)
|
||||
}
|
||||
|
||||
linePoints.removeAll()
|
||||
}
|
||||
|
||||
private func distance(from p1: CGPoint, to p2: CGPoint) -> CGFloat {
|
||||
sqrt(pow(p2.x - p1.x, 2) + pow(p2.y - p1.y, 2))
|
||||
}
|
||||
}
|
||||
|
||||
#Preview {
|
||||
let viewModel = EditorViewModel()
|
||||
return BrushCanvasView(
|
||||
viewModel: viewModel,
|
||||
imageSize: CGSize(width: 1000, height: 1000),
|
||||
displayedImageFrame: CGRect(x: 0, y: 0, width: 300, height: 300)
|
||||
)
|
||||
}
|
||||
@@ -44,6 +44,15 @@ struct CanvasView: View {
|
||||
.blendMode(.multiply)
|
||||
.colorMultiply(.red.opacity(0.5))
|
||||
}
|
||||
|
||||
// Brush canvas overlay
|
||||
if viewModel.selectedTool == .brush && !viewModel.showingMaskConfirmation {
|
||||
BrushCanvasView(
|
||||
viewModel: viewModel,
|
||||
imageSize: viewModel.imageSize,
|
||||
displayedImageFrame: displayedImageFrame(in: geometry.size)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
.contentShape(Rectangle())
|
||||
@@ -64,6 +73,43 @@ struct CanvasView: View {
|
||||
.clipped()
|
||||
}
|
||||
|
||||
// MARK: - Computed Properties
|
||||
|
||||
private func displayedImageFrame(in viewSize: CGSize) -> CGRect {
|
||||
let imageSize = viewModel.imageSize
|
||||
guard imageSize.width > 0, imageSize.height > 0 else {
|
||||
return .zero
|
||||
}
|
||||
|
||||
let imageAspect = imageSize.width / imageSize.height
|
||||
let viewAspect = viewSize.width / viewSize.height
|
||||
|
||||
let displayedSize: CGSize
|
||||
if imageAspect > viewAspect {
|
||||
displayedSize = CGSize(
|
||||
width: viewSize.width,
|
||||
height: viewSize.width / imageAspect
|
||||
)
|
||||
} else {
|
||||
displayedSize = CGSize(
|
||||
width: viewSize.height * imageAspect,
|
||||
height: viewSize.height
|
||||
)
|
||||
}
|
||||
|
||||
let scaledSize = CGSize(
|
||||
width: displayedSize.width * scale,
|
||||
height: displayedSize.height * scale
|
||||
)
|
||||
|
||||
let origin = CGPoint(
|
||||
x: (viewSize.width - scaledSize.width) / 2 + offset.width,
|
||||
y: (viewSize.height - scaledSize.height) / 2 + offset.height
|
||||
)
|
||||
|
||||
return CGRect(origin: origin, size: scaledSize)
|
||||
}
|
||||
|
||||
// MARK: - Gestures
|
||||
|
||||
private func tapGesture(in geometry: GeometryProxy) -> some Gesture {
|
||||
|
||||
@@ -45,6 +45,8 @@ actor InpaintEngine {
|
||||
private let patchRadius: Int
|
||||
private let maxPreviewSize: Int = 2048
|
||||
private let maxMemoryBytes: Int = 1_500_000_000 // 1.5GB
|
||||
private let previewDiffusionIterations: Int = 30
|
||||
private let fullDiffusionIterations: Int = 100
|
||||
|
||||
init(patchRadius: Int = 9) {
|
||||
self.patchRadius = patchRadius
|
||||
@@ -97,12 +99,30 @@ actor InpaintEngine {
|
||||
// MARK: - Metal Implementation
|
||||
|
||||
private func inpaintWithMetal(image: CGImage, mask: CGImage, isPreview: Bool) async throws -> CGImage {
|
||||
guard let device = device, let commandQueue = commandQueue else {
|
||||
guard let device = device else {
|
||||
throw InpaintError.metalNotAvailable
|
||||
}
|
||||
|
||||
// For now, fall back to Accelerate while Metal shaders are being developed
|
||||
return try await inpaintWithAccelerate(image: image, mask: mask)
|
||||
// Create inpainter with appropriate iteration count for preview vs full
|
||||
let iterations = isPreview ? previewDiffusionIterations : fullDiffusionIterations
|
||||
let featherAmount: Float = isPreview ? 2.0 : 4.0
|
||||
|
||||
do {
|
||||
// Metal operations need to run on main actor due to GPU resource management
|
||||
let result = try await MainActor.run {
|
||||
let inpainter = try PatchMatchInpainter(
|
||||
device: device,
|
||||
patchRadius: 4,
|
||||
diffusionIterations: iterations
|
||||
)
|
||||
return try inpainter.inpaint(image: image, mask: mask, featherAmount: featherAmount)
|
||||
}
|
||||
return result
|
||||
} catch {
|
||||
// If Metal inpainting fails, fall back to Accelerate
|
||||
print("Metal inpainting failed: \(error.localizedDescription), falling back to Accelerate")
|
||||
return try await inpaintWithAccelerate(image: image, mask: mask)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Accelerate Fallback
|
||||
@@ -126,8 +146,6 @@ actor InpaintEngine {
|
||||
}
|
||||
|
||||
// Simple inpainting: for each masked pixel, average nearby unmasked pixels
|
||||
let patchSize = patchRadius * 2 + 1
|
||||
|
||||
for y in 0..<height {
|
||||
for x in 0..<width {
|
||||
let maskIndex = y * width + x
|
||||
|
||||
445
CheapRetouch/Services/InpaintEngine/PatchMatch.swift
Normal file
445
CheapRetouch/Services/InpaintEngine/PatchMatch.swift
Normal file
@@ -0,0 +1,445 @@
|
||||
//
|
||||
// PatchMatch.swift
|
||||
// CheapRetouch
|
||||
//
|
||||
// Exemplar-based inpainting algorithm implementation using Metal.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import Metal
|
||||
import MetalKit
|
||||
import CoreGraphics
|
||||
import simd
|
||||
|
||||
final class PatchMatchInpainter {
|
||||
|
||||
private let device: MTLDevice
|
||||
private let commandQueue: MTLCommandQueue
|
||||
private let library: MTLLibrary
|
||||
|
||||
// Compute pipelines
|
||||
private let dilateMaskPipeline: MTLComputePipelineState
|
||||
private let gaussianBlurPipeline: MTLComputePipelineState
|
||||
private let diffuseInpaintPipeline: MTLComputePipelineState
|
||||
private let edgeAwareBlendPipeline: MTLComputePipelineState
|
||||
|
||||
private let patchRadius: Int
|
||||
private let diffusionIterations: Int
|
||||
|
||||
init(device: MTLDevice, patchRadius: Int = 4, diffusionIterations: Int = 100) throws {
|
||||
self.device = device
|
||||
self.patchRadius = patchRadius
|
||||
self.diffusionIterations = diffusionIterations
|
||||
|
||||
guard let queue = device.makeCommandQueue() else {
|
||||
throw PatchMatchError.commandQueueCreationFailed
|
||||
}
|
||||
self.commandQueue = queue
|
||||
|
||||
guard let library = device.makeDefaultLibrary() else {
|
||||
throw PatchMatchError.libraryNotFound
|
||||
}
|
||||
self.library = library
|
||||
|
||||
// Create compute pipelines
|
||||
guard let dilateFunc = library.makeFunction(name: "dilateMask") else {
|
||||
throw PatchMatchError.functionNotFound("dilateMask")
|
||||
}
|
||||
self.dilateMaskPipeline = try device.makeComputePipelineState(function: dilateFunc)
|
||||
|
||||
guard let blurFunc = library.makeFunction(name: "gaussianBlur") else {
|
||||
throw PatchMatchError.functionNotFound("gaussianBlur")
|
||||
}
|
||||
self.gaussianBlurPipeline = try device.makeComputePipelineState(function: blurFunc)
|
||||
|
||||
guard let diffuseFunc = library.makeFunction(name: "diffuseInpaint") else {
|
||||
throw PatchMatchError.functionNotFound("diffuseInpaint")
|
||||
}
|
||||
self.diffuseInpaintPipeline = try device.makeComputePipelineState(function: diffuseFunc)
|
||||
|
||||
guard let blendFunc = library.makeFunction(name: "edgeAwareBlend") else {
|
||||
throw PatchMatchError.functionNotFound("edgeAwareBlend")
|
||||
}
|
||||
self.edgeAwareBlendPipeline = try device.makeComputePipelineState(function: blendFunc)
|
||||
}
|
||||
|
||||
func inpaint(image: CGImage, mask: CGImage, featherAmount: Float = 4.0) throws -> CGImage {
|
||||
let width = image.width
|
||||
let height = image.height
|
||||
|
||||
// Create textures
|
||||
let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
|
||||
pixelFormat: .rgba8Unorm,
|
||||
width: width,
|
||||
height: height,
|
||||
mipmapped: false
|
||||
)
|
||||
textureDescriptor.usage = [.shaderRead, .shaderWrite]
|
||||
|
||||
guard let sourceTexture = device.makeTexture(descriptor: textureDescriptor),
|
||||
let resultTexture = device.makeTexture(descriptor: textureDescriptor),
|
||||
let tempTexture = device.makeTexture(descriptor: textureDescriptor) else {
|
||||
throw PatchMatchError.textureCreationFailed
|
||||
}
|
||||
|
||||
// Mask texture (single channel)
|
||||
let maskDescriptor = MTLTextureDescriptor.texture2DDescriptor(
|
||||
pixelFormat: .r8Unorm,
|
||||
width: width,
|
||||
height: height,
|
||||
mipmapped: false
|
||||
)
|
||||
maskDescriptor.usage = [.shaderRead, .shaderWrite]
|
||||
|
||||
guard let maskTexture = device.makeTexture(descriptor: maskDescriptor),
|
||||
let dilatedMaskTexture = device.makeTexture(descriptor: maskDescriptor),
|
||||
let featheredMaskTexture = device.makeTexture(descriptor: maskDescriptor) else {
|
||||
throw PatchMatchError.textureCreationFailed
|
||||
}
|
||||
|
||||
// Load image data into texture
|
||||
try loadCGImage(image, into: sourceTexture)
|
||||
try loadCGImageGrayscale(mask, into: maskTexture)
|
||||
|
||||
// Copy source to result for initial state
|
||||
try copyTexture(from: sourceTexture, to: resultTexture)
|
||||
|
||||
guard let commandBuffer = commandQueue.makeCommandBuffer() else {
|
||||
throw PatchMatchError.commandBufferCreationFailed
|
||||
}
|
||||
|
||||
// Step 1: Dilate mask
|
||||
encodeDilateMask(commandBuffer: commandBuffer, input: maskTexture, output: dilatedMaskTexture, radius: patchRadius)
|
||||
|
||||
// Step 2: Feather mask
|
||||
encodeGaussianBlur(commandBuffer: commandBuffer, input: dilatedMaskTexture, output: featheredMaskTexture, radius: Int(featherAmount))
|
||||
|
||||
commandBuffer.commit()
|
||||
commandBuffer.waitUntilCompleted()
|
||||
|
||||
// Step 3: Diffusion-based inpainting (multiple iterations)
|
||||
for _ in 0..<diffusionIterations {
|
||||
guard let iterBuffer = commandQueue.makeCommandBuffer() else {
|
||||
throw PatchMatchError.commandBufferCreationFailed
|
||||
}
|
||||
|
||||
encodeDiffuseInpaint(commandBuffer: iterBuffer, input: resultTexture, mask: dilatedMaskTexture, output: tempTexture)
|
||||
|
||||
iterBuffer.commit()
|
||||
iterBuffer.waitUntilCompleted()
|
||||
|
||||
try copyTexture(from: tempTexture, to: resultTexture)
|
||||
}
|
||||
|
||||
// Step 4: Edge-aware blending
|
||||
guard let finalBuffer = commandQueue.makeCommandBuffer() else {
|
||||
throw PatchMatchError.commandBufferCreationFailed
|
||||
}
|
||||
|
||||
encodeEdgeAwareBlend(
|
||||
commandBuffer: finalBuffer,
|
||||
inpainted: resultTexture,
|
||||
original: sourceTexture,
|
||||
mask: featheredMaskTexture,
|
||||
output: tempTexture,
|
||||
blendRadius: featherAmount
|
||||
)
|
||||
|
||||
finalBuffer.commit()
|
||||
finalBuffer.waitUntilCompleted()
|
||||
|
||||
// Convert result to CGImage
|
||||
return try extractCGImage(from: tempTexture)
|
||||
}
|
||||
|
||||
// MARK: - Shader Encoding
|
||||
|
||||
private func encodeDilateMask(commandBuffer: MTLCommandBuffer, input: MTLTexture, output: MTLTexture, radius: Int) {
|
||||
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
|
||||
|
||||
encoder.setComputePipelineState(dilateMaskPipeline)
|
||||
encoder.setTexture(input, index: 0)
|
||||
encoder.setTexture(output, index: 1)
|
||||
|
||||
var radiusValue = Int32(radius)
|
||||
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 0)
|
||||
|
||||
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||
let threadgroups = MTLSize(
|
||||
width: (input.width + 15) / 16,
|
||||
height: (input.height + 15) / 16,
|
||||
depth: 1
|
||||
)
|
||||
|
||||
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||
encoder.endEncoding()
|
||||
}
|
||||
|
||||
private func encodeGaussianBlur(commandBuffer: MTLCommandBuffer, input: MTLTexture, output: MTLTexture, radius: Int) {
|
||||
// Generate Gaussian weights
|
||||
let sigma = Float(radius) / 3.0
|
||||
var weights: [Float] = []
|
||||
for i in 0...radius {
|
||||
let weight = exp(-Float(i * i) / (2 * sigma * sigma))
|
||||
weights.append(weight)
|
||||
}
|
||||
|
||||
guard let weightsBuffer = device.makeBuffer(bytes: weights, length: weights.count * MemoryLayout<Float>.size, options: []) else { return }
|
||||
|
||||
// Create temp texture for two-pass blur
|
||||
let tempDescriptor = MTLTextureDescriptor.texture2DDescriptor(
|
||||
pixelFormat: input.pixelFormat,
|
||||
width: input.width,
|
||||
height: input.height,
|
||||
mipmapped: false
|
||||
)
|
||||
tempDescriptor.usage = [.shaderRead, .shaderWrite]
|
||||
|
||||
guard let tempTexture = device.makeTexture(descriptor: tempDescriptor) else { return }
|
||||
|
||||
// Horizontal pass
|
||||
if let encoder = commandBuffer.makeComputeCommandEncoder() {
|
||||
encoder.setComputePipelineState(gaussianBlurPipeline)
|
||||
encoder.setTexture(input, index: 0)
|
||||
encoder.setTexture(tempTexture, index: 1)
|
||||
encoder.setBuffer(weightsBuffer, offset: 0, index: 0)
|
||||
var radiusValue = Int32(radius)
|
||||
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 1)
|
||||
var horizontal: Bool = true
|
||||
encoder.setBytes(&horizontal, length: MemoryLayout<Bool>.size, index: 2)
|
||||
|
||||
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||
let threadgroups = MTLSize(
|
||||
width: (input.width + 15) / 16,
|
||||
height: (input.height + 15) / 16,
|
||||
depth: 1
|
||||
)
|
||||
|
||||
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||
encoder.endEncoding()
|
||||
}
|
||||
|
||||
// Vertical pass
|
||||
if let encoder = commandBuffer.makeComputeCommandEncoder() {
|
||||
encoder.setComputePipelineState(gaussianBlurPipeline)
|
||||
encoder.setTexture(tempTexture, index: 0)
|
||||
encoder.setTexture(output, index: 1)
|
||||
encoder.setBuffer(weightsBuffer, offset: 0, index: 0)
|
||||
var radiusValue = Int32(radius)
|
||||
encoder.setBytes(&radiusValue, length: MemoryLayout<Int32>.size, index: 1)
|
||||
var horizontal: Bool = false
|
||||
encoder.setBytes(&horizontal, length: MemoryLayout<Bool>.size, index: 2)
|
||||
|
||||
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||
let threadgroups = MTLSize(
|
||||
width: (input.width + 15) / 16,
|
||||
height: (input.height + 15) / 16,
|
||||
depth: 1
|
||||
)
|
||||
|
||||
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||
encoder.endEncoding()
|
||||
}
|
||||
}
|
||||
|
||||
private func encodeDiffuseInpaint(commandBuffer: MTLCommandBuffer, input: MTLTexture, mask: MTLTexture, output: MTLTexture) {
|
||||
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
|
||||
|
||||
encoder.setComputePipelineState(diffuseInpaintPipeline)
|
||||
encoder.setTexture(input, index: 0)
|
||||
encoder.setTexture(mask, index: 1)
|
||||
encoder.setTexture(output, index: 2)
|
||||
|
||||
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||
let threadgroups = MTLSize(
|
||||
width: (input.width + 15) / 16,
|
||||
height: (input.height + 15) / 16,
|
||||
depth: 1
|
||||
)
|
||||
|
||||
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||
encoder.endEncoding()
|
||||
}
|
||||
|
||||
private func encodeEdgeAwareBlend(
|
||||
commandBuffer: MTLCommandBuffer,
|
||||
inpainted: MTLTexture,
|
||||
original: MTLTexture,
|
||||
mask: MTLTexture,
|
||||
output: MTLTexture,
|
||||
blendRadius: Float
|
||||
) {
|
||||
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return }
|
||||
|
||||
encoder.setComputePipelineState(edgeAwareBlendPipeline)
|
||||
encoder.setTexture(inpainted, index: 0)
|
||||
encoder.setTexture(original, index: 1)
|
||||
encoder.setTexture(mask, index: 2)
|
||||
encoder.setTexture(output, index: 3)
|
||||
|
||||
var radius = blendRadius
|
||||
encoder.setBytes(&radius, length: MemoryLayout<Float>.size, index: 0)
|
||||
|
||||
let threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||
let threadgroups = MTLSize(
|
||||
width: (inpainted.width + 15) / 16,
|
||||
height: (inpainted.height + 15) / 16,
|
||||
depth: 1
|
||||
)
|
||||
|
||||
encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
|
||||
encoder.endEncoding()
|
||||
}
|
||||
|
||||
// MARK: - Texture Utilities
|
||||
|
||||
private func loadCGImage(_ cgImage: CGImage, into texture: MTLTexture) throws {
|
||||
let width = cgImage.width
|
||||
let height = cgImage.height
|
||||
|
||||
guard let context = CGContext(
|
||||
data: nil,
|
||||
width: width,
|
||||
height: height,
|
||||
bitsPerComponent: 8,
|
||||
bytesPerRow: width * 4,
|
||||
space: CGColorSpaceCreateDeviceRGB(),
|
||||
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
|
||||
) else {
|
||||
throw PatchMatchError.contextCreationFailed
|
||||
}
|
||||
|
||||
context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
|
||||
|
||||
guard let data = context.data else {
|
||||
throw PatchMatchError.dataExtractionFailed
|
||||
}
|
||||
|
||||
texture.replace(
|
||||
region: MTLRegionMake2D(0, 0, width, height),
|
||||
mipmapLevel: 0,
|
||||
withBytes: data,
|
||||
bytesPerRow: width * 4
|
||||
)
|
||||
}
|
||||
|
||||
private func loadCGImageGrayscale(_ cgImage: CGImage, into texture: MTLTexture) throws {
|
||||
let width = cgImage.width
|
||||
let height = cgImage.height
|
||||
|
||||
guard let context = CGContext(
|
||||
data: nil,
|
||||
width: width,
|
||||
height: height,
|
||||
bitsPerComponent: 8,
|
||||
bytesPerRow: width,
|
||||
space: CGColorSpaceCreateDeviceGray(),
|
||||
bitmapInfo: CGImageAlphaInfo.none.rawValue
|
||||
) else {
|
||||
throw PatchMatchError.contextCreationFailed
|
||||
}
|
||||
|
||||
context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
|
||||
|
||||
guard let data = context.data else {
|
||||
throw PatchMatchError.dataExtractionFailed
|
||||
}
|
||||
|
||||
texture.replace(
|
||||
region: MTLRegionMake2D(0, 0, width, height),
|
||||
mipmapLevel: 0,
|
||||
withBytes: data,
|
||||
bytesPerRow: width
|
||||
)
|
||||
}
|
||||
|
||||
private func copyTexture(from source: MTLTexture, to destination: MTLTexture) throws {
|
||||
guard let commandBuffer = commandQueue.makeCommandBuffer(),
|
||||
let blitEncoder = commandBuffer.makeBlitCommandEncoder() else {
|
||||
throw PatchMatchError.commandBufferCreationFailed
|
||||
}
|
||||
|
||||
blitEncoder.copy(
|
||||
from: source,
|
||||
sourceSlice: 0,
|
||||
sourceLevel: 0,
|
||||
sourceOrigin: MTLOrigin(x: 0, y: 0, z: 0),
|
||||
sourceSize: MTLSize(width: source.width, height: source.height, depth: 1),
|
||||
to: destination,
|
||||
destinationSlice: 0,
|
||||
destinationLevel: 0,
|
||||
destinationOrigin: MTLOrigin(x: 0, y: 0, z: 0)
|
||||
)
|
||||
|
||||
blitEncoder.endEncoding()
|
||||
commandBuffer.commit()
|
||||
commandBuffer.waitUntilCompleted()
|
||||
}
|
||||
|
||||
private func extractCGImage(from texture: MTLTexture) throws -> CGImage {
|
||||
let width = texture.width
|
||||
let height = texture.height
|
||||
let bytesPerRow = width * 4
|
||||
|
||||
var pixelData = [UInt8](repeating: 0, count: bytesPerRow * height)
|
||||
|
||||
texture.getBytes(
|
||||
&pixelData,
|
||||
bytesPerRow: bytesPerRow,
|
||||
from: MTLRegionMake2D(0, 0, width, height),
|
||||
mipmapLevel: 0
|
||||
)
|
||||
|
||||
guard let context = CGContext(
|
||||
data: &pixelData,
|
||||
width: width,
|
||||
height: height,
|
||||
bitsPerComponent: 8,
|
||||
bytesPerRow: bytesPerRow,
|
||||
space: CGColorSpaceCreateDeviceRGB(),
|
||||
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
|
||||
) else {
|
||||
throw PatchMatchError.contextCreationFailed
|
||||
}
|
||||
|
||||
guard let cgImage = context.makeImage() else {
|
||||
throw PatchMatchError.imageCreationFailed
|
||||
}
|
||||
|
||||
return cgImage
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Errors
|
||||
|
||||
enum PatchMatchError: LocalizedError {
|
||||
case commandQueueCreationFailed
|
||||
case libraryNotFound
|
||||
case functionNotFound(String)
|
||||
case textureCreationFailed
|
||||
case commandBufferCreationFailed
|
||||
case contextCreationFailed
|
||||
case dataExtractionFailed
|
||||
case imageCreationFailed
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .commandQueueCreationFailed:
|
||||
return "Failed to create Metal command queue"
|
||||
case .libraryNotFound:
|
||||
return "Metal shader library not found"
|
||||
case .functionNotFound(let name):
|
||||
return "Metal function '\(name)' not found"
|
||||
case .textureCreationFailed:
|
||||
return "Failed to create Metal texture"
|
||||
case .commandBufferCreationFailed:
|
||||
return "Failed to create command buffer"
|
||||
case .contextCreationFailed:
|
||||
return "Failed to create graphics context"
|
||||
case .dataExtractionFailed:
|
||||
return "Failed to extract image data"
|
||||
case .imageCreationFailed:
|
||||
return "Failed to create output image"
|
||||
}
|
||||
}
|
||||
}
|
||||
267
CheapRetouch/Services/InpaintEngine/Shaders.metal
Normal file
267
CheapRetouch/Services/InpaintEngine/Shaders.metal
Normal file
@@ -0,0 +1,267 @@
|
||||
//
|
||||
// Shaders.metal
|
||||
// CheapRetouch
|
||||
//
|
||||
// Metal shaders for exemplar-based inpainting.
|
||||
//
|
||||
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
// MARK: - Common Types
|
||||
|
||||
struct VertexOut {
|
||||
float4 position [[position]];
|
||||
float2 texCoord;
|
||||
};
|
||||
|
||||
// MARK: - Mask Operations
|
||||
|
||||
// Dilate mask by expanding white regions
|
||||
kernel void dilateMask(
|
||||
texture2d<float, access::read> inMask [[texture(0)]],
|
||||
texture2d<float, access::write> outMask [[texture(1)]],
|
||||
constant int &radius [[buffer(0)]],
|
||||
uint2 gid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (gid.x >= outMask.get_width() || gid.y >= outMask.get_height()) {
|
||||
return;
|
||||
}
|
||||
|
||||
float maxValue = 0.0;
|
||||
|
||||
for (int dy = -radius; dy <= radius; dy++) {
|
||||
for (int dx = -radius; dx <= radius; dx++) {
|
||||
int2 samplePos = int2(gid) + int2(dx, dy);
|
||||
|
||||
if (samplePos.x >= 0 && samplePos.x < int(inMask.get_width()) &&
|
||||
samplePos.y >= 0 && samplePos.y < int(inMask.get_height())) {
|
||||
float value = inMask.read(uint2(samplePos)).r;
|
||||
maxValue = max(maxValue, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outMask.write(float4(maxValue, maxValue, maxValue, 1.0), gid);
|
||||
}
|
||||
|
||||
// Gaussian blur for mask feathering
|
||||
kernel void gaussianBlur(
|
||||
texture2d<float, access::read> inTexture [[texture(0)]],
|
||||
texture2d<float, access::write> outTexture [[texture(1)]],
|
||||
constant float *weights [[buffer(0)]],
|
||||
constant int &radius [[buffer(1)]],
|
||||
constant bool &horizontal [[buffer(2)]],
|
||||
uint2 gid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height()) {
|
||||
return;
|
||||
}
|
||||
|
||||
float4 sum = float4(0.0);
|
||||
float weightSum = 0.0;
|
||||
|
||||
for (int i = -radius; i <= radius; i++) {
|
||||
int2 offset = horizontal ? int2(i, 0) : int2(0, i);
|
||||
int2 samplePos = int2(gid) + offset;
|
||||
|
||||
if (samplePos.x >= 0 && samplePos.x < int(inTexture.get_width()) &&
|
||||
samplePos.y >= 0 && samplePos.y < int(inTexture.get_height())) {
|
||||
float weight = weights[abs(i)];
|
||||
sum += inTexture.read(uint2(samplePos)) * weight;
|
||||
weightSum += weight;
|
||||
}
|
||||
}
|
||||
|
||||
outTexture.write(sum / weightSum, gid);
|
||||
}
|
||||
|
||||
// MARK: - Patch Matching
|
||||
|
||||
// Compute sum of squared differences between two patches
|
||||
kernel void computePatchSSD(
|
||||
texture2d<float, access::read> sourceTexture [[texture(0)]],
|
||||
texture2d<float, access::read> maskTexture [[texture(1)]],
|
||||
texture2d<float, access::write> ssdTexture [[texture(2)]],
|
||||
constant int2 &targetPos [[buffer(0)]],
|
||||
constant int &patchRadius [[buffer(1)]],
|
||||
uint2 gid [[thread_position_in_grid]]
|
||||
) {
|
||||
int width = sourceTexture.get_width();
|
||||
int height = sourceTexture.get_height();
|
||||
|
||||
if (gid.x >= uint(width) || gid.y >= uint(height)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Skip if this position is in the mask (unknown region)
|
||||
float maskValue = maskTexture.read(gid).r;
|
||||
if (maskValue > 0.5) {
|
||||
ssdTexture.write(float4(1e10), gid);
|
||||
return;
|
||||
}
|
||||
|
||||
float ssd = 0.0;
|
||||
int validPixels = 0;
|
||||
|
||||
for (int dy = -patchRadius; dy <= patchRadius; dy++) {
|
||||
for (int dx = -patchRadius; dx <= patchRadius; dx++) {
|
||||
int2 sourcePos = int2(gid) + int2(dx, dy);
|
||||
int2 targetSamplePos = targetPos + int2(dx, dy);
|
||||
|
||||
// Check bounds
|
||||
if (sourcePos.x < 0 || sourcePos.x >= width ||
|
||||
sourcePos.y < 0 || sourcePos.y >= height ||
|
||||
targetSamplePos.x < 0 || targetSamplePos.x >= width ||
|
||||
targetSamplePos.y < 0 || targetSamplePos.y >= height) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only compare known pixels in source
|
||||
float sourceMask = maskTexture.read(uint2(sourcePos)).r;
|
||||
if (sourceMask > 0.5) {
|
||||
continue;
|
||||
}
|
||||
|
||||
float4 sourceColor = sourceTexture.read(uint2(sourcePos));
|
||||
float4 targetColor = sourceTexture.read(uint2(targetSamplePos));
|
||||
|
||||
float3 diff = sourceColor.rgb - targetColor.rgb;
|
||||
ssd += dot(diff, diff);
|
||||
validPixels++;
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize by valid pixel count
|
||||
if (validPixels > 0) {
|
||||
ssd /= float(validPixels);
|
||||
} else {
|
||||
ssd = 1e10;
|
||||
}
|
||||
|
||||
ssdTexture.write(float4(ssd), gid);
|
||||
}
|
||||
|
||||
// MARK: - Inpainting
|
||||
|
||||
// Simple diffusion-based inpainting step
|
||||
kernel void diffuseInpaint(
|
||||
texture2d<float, access::read> inTexture [[texture(0)]],
|
||||
texture2d<float, access::read> maskTexture [[texture(1)]],
|
||||
texture2d<float, access::write> outTexture [[texture(2)]],
|
||||
uint2 gid [[thread_position_in_grid]]
|
||||
) {
|
||||
int width = inTexture.get_width();
|
||||
int height = inTexture.get_height();
|
||||
|
||||
if (gid.x >= uint(width) || gid.y >= uint(height)) {
|
||||
return;
|
||||
}
|
||||
|
||||
float maskValue = maskTexture.read(gid).r;
|
||||
|
||||
// If not in mask, copy original
|
||||
if (maskValue < 0.5) {
|
||||
outTexture.write(inTexture.read(gid), gid);
|
||||
return;
|
||||
}
|
||||
|
||||
// Diffuse from neighbors
|
||||
float4 sum = float4(0.0);
|
||||
float weightSum = 0.0;
|
||||
|
||||
int offsets[4][2] = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int2 neighborPos = int2(gid) + int2(offsets[i][0], offsets[i][1]);
|
||||
|
||||
if (neighborPos.x >= 0 && neighborPos.x < width &&
|
||||
neighborPos.y >= 0 && neighborPos.y < height) {
|
||||
float neighborMask = maskTexture.read(uint2(neighborPos)).r;
|
||||
float weight = (neighborMask < 0.5) ? 2.0 : 1.0;
|
||||
sum += inTexture.read(uint2(neighborPos)) * weight;
|
||||
weightSum += weight;
|
||||
}
|
||||
}
|
||||
|
||||
if (weightSum > 0.0) {
|
||||
outTexture.write(sum / weightSum, gid);
|
||||
} else {
|
||||
outTexture.write(inTexture.read(gid), gid);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy patch from source to target
|
||||
kernel void copyPatch(
|
||||
texture2d<float, access::read> sourceTexture [[texture(0)]],
|
||||
texture2d<float, access::read> maskTexture [[texture(1)]],
|
||||
texture2d<float, access::read_write> targetTexture [[texture(2)]],
|
||||
constant int2 &sourcePos [[buffer(0)]],
|
||||
constant int2 &targetPos [[buffer(1)]],
|
||||
constant int &patchRadius [[buffer(2)]],
|
||||
uint2 gid [[thread_position_in_grid]]
|
||||
) {
|
||||
int patchSize = patchRadius * 2 + 1;
|
||||
|
||||
if (gid.x >= uint(patchSize) || gid.y >= uint(patchSize)) {
|
||||
return;
|
||||
}
|
||||
|
||||
int2 offset = int2(gid) - int2(patchRadius);
|
||||
int2 srcCoord = sourcePos + offset;
|
||||
int2 dstCoord = targetPos + offset;
|
||||
|
||||
int width = sourceTexture.get_width();
|
||||
int height = sourceTexture.get_height();
|
||||
|
||||
// Check bounds
|
||||
if (srcCoord.x < 0 || srcCoord.x >= width ||
|
||||
srcCoord.y < 0 || srcCoord.y >= height ||
|
||||
dstCoord.x < 0 || dstCoord.x >= width ||
|
||||
dstCoord.y < 0 || dstCoord.y >= height) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only write to masked (unknown) pixels
|
||||
float maskValue = maskTexture.read(uint2(dstCoord)).r;
|
||||
if (maskValue > 0.5) {
|
||||
float4 sourceColor = sourceTexture.read(uint2(srcCoord));
|
||||
targetTexture.write(sourceColor, uint2(dstCoord));
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Edge-aware blending
|
||||
|
||||
kernel void edgeAwareBlend(
|
||||
texture2d<float, access::read> inTexture [[texture(0)]],
|
||||
texture2d<float, access::read> originalTexture [[texture(1)]],
|
||||
texture2d<float, access::read> maskTexture [[texture(2)]],
|
||||
texture2d<float, access::write> outTexture [[texture(3)]],
|
||||
constant float &blendRadius [[buffer(0)]],
|
||||
uint2 gid [[thread_position_in_grid]]
|
||||
) {
|
||||
int width = inTexture.get_width();
|
||||
int height = inTexture.get_height();
|
||||
|
||||
if (gid.x >= uint(width) || gid.y >= uint(height)) {
|
||||
return;
|
||||
}
|
||||
|
||||
float maskValue = maskTexture.read(gid).r;
|
||||
float4 inpaintedColor = inTexture.read(gid);
|
||||
float4 originalColor = originalTexture.read(gid);
|
||||
|
||||
// If far from mask boundary, use appropriate color directly
|
||||
if (maskValue < 0.1) {
|
||||
outTexture.write(originalColor, gid);
|
||||
return;
|
||||
}
|
||||
if (maskValue > 0.9) {
|
||||
outTexture.write(inpaintedColor, gid);
|
||||
return;
|
||||
}
|
||||
|
||||
// Blend in the transition zone
|
||||
float4 result = mix(originalColor, inpaintedColor, maskValue);
|
||||
outTexture.write(result, gid);
|
||||
}
|
||||
Reference in New Issue
Block a user