23#include "llvm/Support/Regex.h"
26#define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
27#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
34struct SPIRVAttachTarget
38 void runOnOperation()
override;
40 void getDependentDialects(DialectRegistry &registry)
const override {
41 registry.
insert<spirv::SPIRVDialect>();
// Pass entry point: turns the pass options (spirvVersion, clientApi,
// deviceVendor, deviceType, spirvCapabilities, spirvExtensions, deviceId)
// into SPIR-V enum symbols and then rewrites the `targets` attribute of the
// gpu::GPUModuleOp operations visited by the walk below.
// NOTE(review): the embedded line numbers in this listing skip (46, 48, 50,
// ...), so several statements -- e.g. the guards in front of the bare
// `return signalPassFailure();` lines and the closing braces -- are not
// visible here. Confirm against the full source before relying on details.
46void SPIRVAttachTarget::runOnOperation() {
// Each textual option is converted with its matching symbolize* helper; the
// pass is failed via signalPassFailure() on the visible error paths.
48 auto versionSymbol = symbolizeVersion(spirvVersion);
50 return signalPassFailure();
51 auto apiSymbol = symbolizeClientAPI(clientApi);
53 return signalPassFailure();
54 auto vendorSymbol = symbolizeVendor(deviceVendor);
56 return signalPassFailure();
57 auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
58 if (!deviceTypeSymbol)
59 return signalPassFailure();
// NOTE(review): the statement controlled by this check is not shown in the
// listing -- presumably a fallback device id is assigned; TODO confirm.
61 if (!deviceId.hasValue())
64 Version version = versionSymbol.value();
// Gather the capability and extension enum values named by the option lists.
65 SmallVector<Capability, 4> capabilities;
66 SmallVector<Extension, 8> extensions;
67 for (
const auto &cap : spirvCapabilities) {
68 auto capSymbol = symbolizeCapability(cap);
70 capabilities.push_back(capSymbol.value());
72 ArrayRef<Capability> caps(capabilities);
73 for (
const auto &ext : spirvExtensions) {
74 auto extSymbol = symbolizeExtension(ext);
76 extensions.push_back(extSymbol.value());
78 ArrayRef<Extension> exts(extensions);
// Trailing arguments of a call whose beginning is not visible in this listing.
81 apiSymbol.value(), vendorSymbol.value(),
82 deviceTypeSymbol.value(), deviceId);
// moduleMatcher is used as a regular-expression pattern that is tested
// against each GPU module's name.
83 llvm::Regex matcher(moduleMatcher);
84 getOperation()->walk([&](gpu::GPUModuleOp gpuModule) {
// An empty moduleMatcher disables the name filter.
// NOTE(review): the statement taken when the name does not match is not
// shown -- presumably the module is skipped; TODO confirm.
86 if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
// Start from the targets already attached to the module, if any.
89 SmallVector<Attribute> targets;
90 if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
91 targets.append(attrs->getValue().begin(), attrs->getValue().end());
// Drop the duplicate entries identified by llvm::unique (assumed to follow
// std::unique semantics, i.e. consecutive duplicates -- TODO confirm).
94 targets.erase(llvm::unique(targets), targets.end());
// Store the resulting list back on the module as its targets attribute.
96 gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
Include the generated interface declarations.