MLIR  19.0.0git
GPUToLLVMSPV.cpp
Go to the documentation of this file.
1 //===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/SymbolTable.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
28 
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/FormatVariadic.h"
31 
32 using namespace mlir;
33 
34 namespace mlir {
35 #define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
36 #include "mlir/Conversion/Passes.h.inc"
37 } // namespace mlir
38 
39 //===----------------------------------------------------------------------===//
40 // Helper Functions
41 //===----------------------------------------------------------------------===//
42 
43 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
44  StringRef name,
45  ArrayRef<Type> paramTypes,
46  Type resultType,
47  bool isConvergent = false) {
48  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
49  SymbolTable::lookupSymbolIn(symbolTable, name));
50  if (!func) {
51  OpBuilder b(symbolTable->getRegion(0));
52  func = b.create<LLVM::LLVMFuncOp>(
53  symbolTable->getLoc(), name,
54  LLVM::LLVMFunctionType::get(resultType, paramTypes));
55  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
56  func.setConvergent(isConvergent);
57  }
58  return func;
59 }
60 
61 static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
62  ConversionPatternRewriter &rewriter,
63  LLVM::LLVMFuncOp func,
64  ValueRange args) {
65  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
66  call.setCConv(func.getCConv());
67  return call;
68 }
69 
70 namespace {
71 //===----------------------------------------------------------------------===//
72 // Barriers
73 //===----------------------------------------------------------------------===//
74 
75 /// Replace `gpu.barrier` with an `llvm.call` to `barrier` with
76 /// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
77 /// ```
78 /// // gpu.barrier
79 /// %c1 = llvm.mlir.constant(1: i32) : i32
80 /// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
81 /// ```
82 struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
84 
85  LogicalResult
86  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
87  ConversionPatternRewriter &rewriter) const final {
88  constexpr StringLiteral funcName = "_Z7barrierj";
89 
91  assert(moduleOp && "Expecting module");
92  Type flagTy = rewriter.getI32Type();
93  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
94  LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
95  moduleOp, funcName, flagTy, voidTy, /*isConvergent=*/true);
96 
97  // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
98  // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
99  constexpr int64_t localMemFenceFlag = 1;
100  Location loc = op->getLoc();
101  Value flag =
102  rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
103  rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
104  return success();
105  }
106 };
107 
108 //===----------------------------------------------------------------------===//
109 // SPIR-V Builtins
110 //===----------------------------------------------------------------------===//
111 
112 /// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin with
113 /// a constant argument for the `dimension` attribute. Return type will depend
114 /// on index width option:
115 /// ```
116 /// // %thread_id_y = gpu.thread_id y
117 /// %c1 = llvm.mlir.constant(1: i32) : i32
118 /// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
119 /// ```
120 struct LaunchConfigConversion : ConvertToLLVMPattern {
121  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
122  MLIRContext *context,
123  const LLVMTypeConverter &typeConverter,
124  PatternBenefit benefit)
125  : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
126  funcName(funcName) {}
127 
128  virtual gpu::Dimension getDimension(Operation *op) const = 0;
129 
130  LogicalResult
131  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
132  ConversionPatternRewriter &rewriter) const final {
134  assert(moduleOp && "Expecting module");
135  Type dimTy = rewriter.getI32Type();
136  Type indexTy = getTypeConverter()->getIndexType();
137  LLVM::LLVMFuncOp func =
138  lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy, indexTy);
139 
140  Location loc = op->getLoc();
141  gpu::Dimension dim = getDimension(op);
142  Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
143  static_cast<int64_t>(dim));
144  rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
145  return success();
146  }
147 
148  StringRef funcName;
149 };
150 
151 template <typename SourceOp>
152 struct LaunchConfigOpConversion final : LaunchConfigConversion {
153  static StringRef getFuncName();
154 
155  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
156  PatternBenefit benefit = 1)
157  : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
158  &typeConverter.getContext(), typeConverter,
159  benefit) {}
160 
161  gpu::Dimension getDimension(Operation *op) const final {
162  return cast<SourceOp>(op).getDimension();
163  }
164 };
165 
166 template <>
167 StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
168  return "_Z12get_group_idj";
169 }
170 
171 template <>
172 StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
173  return "_Z14get_num_groupsj";
174 }
175 
176 template <>
177 StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
178  return "_Z14get_local_sizej";
179 }
180 
181 template <>
182 StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
183  return "_Z12get_local_idj";
184 }
185 
186 template <>
187 StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
188  return "_Z13get_global_idj";
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // Shuffles
193 //===----------------------------------------------------------------------===//
194 
195 /// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
196 /// builtin for `shuffleResult`, keeping `value` and `offset` arguments, and a
197 /// `true` constant for the `valid` result type. Conversion will only take place
198 /// if `width` is constant and equal to the `subgroup` pass option:
199 /// ```
200 /// // %0 = gpu.shuffle idx %value, %offset, %width : f64
201 /// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
202 /// : (f64, i32) -> f64
203 /// ```
204 struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
206 
207  static StringRef getBaseName(gpu::ShuffleMode mode) {
208  switch (mode) {
209  case gpu::ShuffleMode::IDX:
210  return "sub_group_shuffle";
211  case gpu::ShuffleMode::XOR:
212  return "sub_group_shuffle_xor";
213  case gpu::ShuffleMode::UP:
214  return "sub_group_shuffle_up";
215  case gpu::ShuffleMode::DOWN:
216  return "sub_group_shuffle_down";
217  }
218  llvm_unreachable("Unhandled shuffle mode");
219  }
220 
221  static StringRef getTypeMangling(Type type) {
222  return TypeSwitch<Type, StringRef>(type)
223  .Case<Float32Type>([](auto) { return "fj"; })
224  .Case<Float64Type>([](auto) { return "dj"; })
225  .Case<IntegerType>([](auto intTy) {
226  switch (intTy.getWidth()) {
227  case 32:
228  return "ij";
229  case 64:
230  return "lj";
231  }
232  llvm_unreachable("Invalid integer width");
233  });
234  }
235 
236  static std::string getFuncName(gpu::ShuffleOp op) {
237  StringRef baseName = getBaseName(op.getMode());
238  StringRef typeMangling = getTypeMangling(op.getType(0));
239  return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
240  typeMangling);
241  }
242 
243  /// Get the subgroup size from the target or return a default.
244  static int getSubgroupSize(Operation *op) {
247  .getSubgroupSize();
248  }
249 
250  static bool hasValidWidth(gpu::ShuffleOp op) {
251  llvm::APInt val;
252  Value width = op.getWidth();
253  return matchPattern(width, m_ConstantInt(&val)) &&
254  val == getSubgroupSize(op);
255  }
256 
257  LogicalResult
258  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
259  ConversionPatternRewriter &rewriter) const final {
260  if (!hasValidWidth(op))
261  return rewriter.notifyMatchFailure(
262  op, "shuffle width and subgroup size mismatch");
263 
264  std::string funcName = getFuncName(op);
265 
267  assert(moduleOp && "Expecting module");
268  Type valueType = adaptor.getValue().getType();
269  Type offsetType = adaptor.getOffset().getType();
270  Type resultType = valueType;
271  LLVM::LLVMFuncOp func =
272  lookupOrCreateSPIRVFn(moduleOp, funcName, {valueType, offsetType},
273  resultType, /*isConvergent=*/true);
274 
275  Location loc = op->getLoc();
276  std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
277  Value result =
278  createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
279  Value trueVal =
280  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
281  rewriter.replaceOp(op, {result, trueVal});
282  return success();
283  }
284 };
285 
286 //===----------------------------------------------------------------------===//
287 // GPU To LLVM-SPV Pass.
288 //===----------------------------------------------------------------------===//
289 
290 struct GPUToLLVMSPVConversionPass final
291  : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
292  using Base::Base;
293 
294  void runOnOperation() final {
295  MLIRContext *context = &getContext();
296  RewritePatternSet patterns(context);
297 
298  LowerToLLVMOptions options(context);
299  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
300  options.overrideIndexBitwidth(indexBitwidth);
301 
302  LLVMTypeConverter converter(context, options);
303  LLVMConversionTarget target(*context);
304 
305  target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
306  gpu::GlobalIdOp, gpu::GridDimOp, gpu::ShuffleOp,
307  gpu::ThreadIdOp>();
308 
309  populateGpuToLLVMSPVConversionPatterns(converter, patterns);
310 
311  if (failed(applyPartialConversion(getOperation(), target,
312  std::move(patterns))))
313  signalPassFailure();
314  }
315 };
316 } // namespace
317 
318 //===----------------------------------------------------------------------===//
319 // GPU To LLVM-SPV Patterns.
320 //===----------------------------------------------------------------------===//
321 
322 namespace mlir {
324  RewritePatternSet &patterns) {
325  patterns.add<GPUBarrierConversion, GPUShuffleConversion,
326  LaunchConfigOpConversion<gpu::BlockIdOp>,
327  LaunchConfigOpConversion<gpu::GridDimOp>,
328  LaunchConfigOpConversion<gpu::BlockDimOp>,
329  LaunchConfigOpConversion<gpu::ThreadIdOp>,
330  LaunchConfigOpConversion<gpu::GlobalIdOp>>(typeConverter);
331 }
332 } // namespace mlir
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isConvergent=false)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:41
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of the symbol table operation `op`.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op, or returns the default target environment (as returned by getDefaultTargetEnv()) if none is provided.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:438
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the LLVM data layout.
void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.