我有一个张量流图,我想转换为CoreML,但它使用了一些缺少的操作,我将不得不实现为自定义图层 .
我现在关注的两个操作是 Sin
和 FloorDiv
.
Sin
非常简单,我可以关注this tutorial,我有一个工作的Swift类和 Metal
内核来完成这项工作,我用玩具coreml文件测试过:
import Foundation
import CoreML
import Accelerate
@objc(Sin) class Sin: NSObject, MLCustomLayer {
let sinPipeline: MTLComputePipelineState
required init(parameters: [String : Any]) throws {
print(#function, parameters)
let sinFunction = GPUDispatch.sharedInstance.library.makeFunction(name: "sin")!
sinPipeline = try! GPUDispatch.sharedInstance.device.makeComputePipelineState(
function: sinFunction)
super.init()
}
func setWeightData(_ weights: [Data]) throws {
print(#function, weights)
}
func outputShapes(forInputShapes inputShapes: [[NSNumber]]) throws
-> [[NSNumber]] {
print(#function, inputShapes)
return inputShapes
}
func evaluate(inputs: [MLMultiArray], outputs: [MLMultiArray]) throws {
for i in 0..<inputs.count {
let input = inputs[i]
let output = outputs[i]
var count = Int32(input.count)
let iptr = UnsafeMutablePointer<Float>(OpaquePointer(input.dataPointer))
let optr = UnsafeMutablePointer<Float>(OpaquePointer(output.dataPointer))
vvsinf(optr, iptr, &count)
}
}
func encode(commandBuffer: MTLCommandBuffer,
inputs: [MTLTexture], outputs: [MTLTexture]) throws {
if let encoder = commandBuffer.makeComputeCommandEncoder() {
for i in 0..<inputs.count {
encoder.setTexture(inputs[i], index: 0)
encoder.setTexture(outputs[i], index: 1)
encoder.dispatch(pipeline: sinPipeline, texture: inputs[i])
encoder.endEncoding()
}
}
}
}
并在 Sin.metal
:
kernel void sin(
texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
ushort3 gid [[thread_position_in_grid]])
{
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) {
return;
}
const float4 x = float4(inTexture.read(gid.xy, gid.z));
const float4 y = sin(x);
outTexture.write(half4(y), gid.xy, gid.z);
}
我不明白的是,如果自定义图层有两个输入,这将是如何工作的,例如我需要 FloorDiv
,它返回 floor(x / y)
.
我如何调整我提供的 Sin
类来生成像 sin(x*y)
这样的东西,即使它只是在CPU上?这类东西还有其他好的教程吗?
1 回答
这种模式与我的预期不同,但现在很明显我已经使用了代码了 .
这是一个实现
FloorDiv
的类:这是金属内核: