23 #include "llvm/Support/Regex.h"
26 #define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
27 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
// Pass class built on the generated impl::GpuSPIRVAttachTargetBase, which is
// parameterized on SPIRVAttachTarget itself.
// NOTE(review): this listing has gaps -- some member declarations and the
// closing brace of the struct are not visible here.
34 struct SPIRVAttachTarget
35 :
public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
// Pass entry point; declared here and defined out of line below.
38 void runOnOperation()
override;
// Adds spirv::SPIRVDialect to `registry`.
// NOTE(review): the signature of the member function enclosing this statement
// is not visible in this excerpt -- presumably the dependent-dialect hook;
// confirm against the full file.
41 registry.
insert<spirv::SPIRVDialect>();
// Out-of-line definition of the pass entry point. From the visible statements:
// it converts the pass's string options into SPIR-V enum symbols, collects
// capabilities and extensions, and then walks gpu::GPUModuleOp ops to update
// their `targets` array attribute.
// NOTE(review): this listing has gaps -- several guard conditions, local
// declarations (`capabilities`, `extensions`, `targets`, `builder`, `target`)
// and closing braces are not visible. The comments below describe only what
// the visible lines show.
46 void SPIRVAttachTarget::runOnOperation() {
// Convert the `spirvVersion` option into its enum symbol.
48 auto versionSymbol = symbolizeVersion(spirvVersion);
// Marks the pass as failed and returns.
// NOTE(review): the condition guarding this return is not shown; presumably a
// `!versionSymbol` check like the visible one for deviceTypeSymbol -- confirm.
50 return signalPassFailure();
// Convert the `clientApi` option into its enum symbol.
51 auto apiSymbol = symbolizeClientAPI(clientApi);
// NOTE(review): guard not visible; presumably `!apiSymbol` -- confirm.
53 return signalPassFailure();
// Convert the `deviceVendor` option into its enum symbol.
54 auto vendorSymbol = symbolizeVendor(deviceVendor);
// NOTE(review): guard not visible; presumably `!vendorSymbol` -- confirm.
56 return signalPassFailure();
// Convert the `deviceType` option into its enum symbol; fail the pass when the
// result is empty.
57 auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
58 if (!deviceTypeSymbol)
59 return signalPassFailure();
// Handles the case where the `deviceId` option holds no value.
// NOTE(review): the statement executed under this condition is not visible;
// presumably it assigns a default device id -- confirm.
61 if (!deviceId.hasValue())
// Unwrap the version symbol obtained above.
64 Version version = versionSymbol.value();
// Symbolize each entry of `spirvCapabilities` and append it to `capabilities`.
// NOTE(review): any check on `capSymbol` before `.value()` is not visible here.
67 for (
const auto &cap : spirvCapabilities) {
68 auto capSymbol = symbolizeCapability(cap);
70 capabilities.push_back(capSymbol.value());
// Symbolize each entry of `spirvExtensions` and append it to `extensions`.
// NOTE(review): any check on `extSymbol` before `.value()` is not visible here.
73 for (
const auto &ext : spirvExtensions) {
74 auto extSymbol = symbolizeExtension(ext);
76 extensions.push_back(extSymbol.value());
// Trailing arguments of a call whose beginning is not visible: the unwrapped
// client API, vendor and device type symbols, plus `deviceId`.
// NOTE(review): presumably the call that builds `target` -- confirm.
81 apiSymbol.value(), vendorSymbol.value(),
82 deviceTypeSymbol.value(), deviceId);
// Build a regex from the `moduleMatcher` option.
83 llvm::Regex matcher(moduleMatcher);
// Visit every gpu::GPUModuleOp reached by walking the current operation.
84 getOperation()->walk([&](gpu::GPUModuleOp gpuModule) {
// True when a matcher was given and the module name does not match it.
// NOTE(review): the action taken in that case is not visible; presumably the
// module is skipped -- confirm.
86 if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
// Seed `targets` with the module's existing targets, if it has any.
90 if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
91 targets.append(attrs->getValue().begin(), attrs->getValue().end());
// Append the new target.
92 targets.push_back(target);
// Drop duplicate entries.
// NOTE(review): assuming llvm::unique behaves like std::unique, only adjacent
// equal elements are collapsed; no sort is visible before this line, so
// non-adjacent duplicates may survive -- confirm this is intended.
94 targets.erase(llvm::unique(targets), targets.end());
// Write the combined list back to the module as an array attribute.
96 gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
static MLIRContext * getContext(OpFoldResult val)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class helps build Operations.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the Vulkan spec.
Include the generated interface declarations.