MLIR  20.0.0git
SPIRVAttachTarget.cpp
Go to the documentation of this file.
1 //===- SPIRVAttachTarget.cpp - Attach an SPIR-V target --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the `GPUSPIRVAttachTarget` pass, attaching
10 // `#spirv.target_env` attributes to GPU modules.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/Pass/Pass.h"
23 #include "llvm/Support/Regex.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
27 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::spirv;
32 
33 namespace {
34 struct SPIRVAttachTarget
35  : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
36  using Base::Base;
37 
38  void runOnOperation() override;
39 
40  void getDependentDialects(DialectRegistry &registry) const override {
41  registry.insert<spirv::SPIRVDialect>();
42  }
43 };
44 } // namespace
45 
46 void SPIRVAttachTarget::runOnOperation() {
47  OpBuilder builder(&getContext());
48  auto versionSymbol = symbolizeVersion(spirvVersion);
49  if (!versionSymbol)
50  return signalPassFailure();
51  auto apiSymbol = symbolizeClientAPI(clientApi);
52  if (!apiSymbol)
53  return signalPassFailure();
54  auto vendorSymbol = symbolizeVendor(deviceVendor);
55  if (!vendorSymbol)
56  return signalPassFailure();
57  auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
58  if (!deviceTypeSymbol)
59  return signalPassFailure();
60  // Set the default device ID if none was given
61  if (!deviceId.hasValue())
63 
64  Version version = versionSymbol.value();
65  SmallVector<Capability, 4> capabilities;
66  SmallVector<Extension, 8> extensions;
67  for (const auto &cap : spirvCapabilities) {
68  auto capSymbol = symbolizeCapability(cap);
69  if (capSymbol)
70  capabilities.push_back(capSymbol.value());
71  }
72  ArrayRef<Capability> caps(capabilities);
73  for (const auto &ext : spirvExtensions) {
74  auto extSymbol = symbolizeExtension(ext);
75  if (extSymbol)
76  extensions.push_back(extSymbol.value());
77  }
78  ArrayRef<Extension> exts(extensions);
79  VerCapExtAttr vce = VerCapExtAttr::get(version, caps, exts, &getContext());
81  apiSymbol.value(), vendorSymbol.value(),
82  deviceTypeSymbol.value(), deviceId);
83  llvm::Regex matcher(moduleMatcher);
84  getOperation()->walk([&](gpu::GPUModuleOp gpuModule) {
85  // Check if the name of the module matches.
86  if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
87  return;
88  // Create the target array.
89  SmallVector<Attribute> targets;
90  if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
91  targets.append(attrs->getValue().begin(), attrs->getValue().end());
92  targets.push_back(target);
93  // Remove any duplicate targets.
94  targets.erase(llvm::unique(targets), targets.end());
95  // Update the target attribute array.
96  gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
97  });
98 }
static MLIRContext * getContext(OpFoldResult val)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class helps build Operations.
Definition: Builders.h:211
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
ResourceLimitsAttr getDefaultResourceLimits(MLIRContext *context)
Returns a default resource limits attribute that uses numbers from "Table 46. Required Limits" of the...
Include the generated interface declarations.