MLIR  22.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Returns true if the given type is a signed integer or vector type.
35 static bool isSignedIntegerOrVector(Type type) {
36  if (type.isSignedInteger())
37  return true;
38  if (auto vecType = dyn_cast<VectorType>(type))
39  return vecType.getElementType().isSignedInteger();
40  return false;
41 }
42 
43 /// Returns true if the given type is an unsigned integer or vector type
44 static bool isUnsignedIntegerOrVector(Type type) {
45  if (type.isUnsignedInteger())
46  return true;
47  if (auto vecType = dyn_cast<VectorType>(type))
48  return vecType.getElementType().isUnsignedInteger();
49  return false;
50 }
51 
52 /// Returns the width of an integer or of the element type of an integer vector,
53 /// if applicable.
54 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
55  if (auto intType = dyn_cast<IntegerType>(type))
56  return intType.getWidth();
57  if (auto vecType = dyn_cast<VectorType>(type))
58  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59  return intType.getWidth();
60  return std::nullopt;
61 }
62 
63 /// Returns the bit width of integer, float or vector of float or integer values
64 static unsigned getBitWidth(Type type) {
65  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
66  "bitwidth is not supported for this type");
67  if (type.isIntOrFloat())
68  return type.getIntOrFloatBitWidth();
69  auto vecType = dyn_cast<VectorType>(type);
70  auto elementType = vecType.getElementType();
71  assert(elementType.isIntOrFloat() &&
72  "only integers and floats have a bitwidth");
73  return elementType.getIntOrFloatBitWidth();
74 }
75 
76 /// Returns the bit width of LLVMType integer or vector.
77 static unsigned getLLVMTypeBitWidth(Type type) {
78  if (auto vecTy = dyn_cast<VectorType>(type))
79  type = vecTy.getElementType();
80  return cast<IntegerType>(type).getWidth();
81 }
82 
83 /// Creates `IntegerAttribute` with all bits set for given type
84 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
85  if (auto vecType = dyn_cast<VectorType>(type)) {
86  auto integerType = cast<IntegerType>(vecType.getElementType());
87  return builder.getIntegerAttr(integerType, -1);
88  }
89  auto integerType = cast<IntegerType>(type);
90  return builder.getIntegerAttr(integerType, -1);
91 }
92 
93 /// Creates `llvm.mlir.constant` with all bits set for the given type.
94 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
95  PatternRewriter &rewriter) {
96  if (isa<VectorType>(srcType)) {
97  return rewriter.create<LLVM::ConstantOp>(
98  loc, dstType,
99  SplatElementsAttr::get(cast<ShapedType>(srcType),
100  minusOneIntegerAttribute(srcType, rewriter)));
101  }
102  return rewriter.create<LLVM::ConstantOp>(
103  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
104 }
105 
106 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
107 static Value createFPConstant(Location loc, Type srcType, Type dstType,
108  PatternRewriter &rewriter, double value) {
109  if (auto vecType = dyn_cast<VectorType>(srcType)) {
110  auto floatType = cast<FloatType>(vecType.getElementType());
111  return rewriter.create<LLVM::ConstantOp>(
112  loc, dstType,
113  SplatElementsAttr::get(vecType,
114  rewriter.getFloatAttr(floatType, value)));
115  }
116  auto floatType = cast<FloatType>(srcType);
117  return rewriter.create<LLVM::ConstantOp>(
118  loc, dstType, rewriter.getFloatAttr(floatType, value));
119 }
120 
121 /// Utility function for bitfield ops:
122 /// - `BitFieldInsert`
123 /// - `BitFieldSExtract`
124 /// - `BitFieldUExtract`
125 /// Truncates or extends the value. If the bitwidth of the value is the same as
126 /// `llvmType` bitwidth, the value remains unchanged.
128  Type llvmType,
129  PatternRewriter &rewriter) {
130  auto srcType = value.getType();
131  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
132  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
133  ? getLLVMTypeBitWidth(srcType)
134  : getBitWidth(srcType);
135 
136  if (valueBitWidth < targetBitWidth)
137  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
138  // If the bit widths of `Count` and `Offset` are greater than the bit width
139  // of the target type, they are truncated. Truncation is safe since `Count`
140  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
141  // both values can be expressed in 8 bits.
142  if (valueBitWidth > targetBitWidth)
143  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
144  return value;
145 }
146 
147 /// Broadcasts the value to vector with `numElements` number of elements.
148 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
149  const TypeConverter &typeConverter,
150  ConversionPatternRewriter &rewriter) {
151  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
152  auto llvmVectorType = typeConverter.convertType(vectorType);
153  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
154  Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
155  for (unsigned i = 0; i < numElements; ++i) {
156  auto index = rewriter.create<LLVM::ConstantOp>(
157  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
158  broadcasted = rewriter.create<LLVM::InsertElementOp>(
159  loc, llvmVectorType, broadcasted, toBroadcast, index);
160  }
161  return broadcasted;
162 }
163 
164 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
165 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
166  const TypeConverter &typeConverter,
167  ConversionPatternRewriter &rewriter) {
168  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
169  unsigned numElements = vectorType.getNumElements();
170  return broadcast(loc, value, numElements, typeConverter, rewriter);
171  }
172  return value;
173 }
174 
175 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
176 /// `BitFieldUExtract`.
177 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
178 /// a vector type, construct a vector that has:
179 /// - same number of elements as `Base`
180 /// - each element has the type that is the same as the type of `Offset` or
181 /// `Count`
182 /// - each element has the same value as `Offset` or `Count`
183 /// Then cast `Offset` and `Count` if their bit width is different
184 /// from `Base` bit width.
185 static Value processCountOrOffset(Location loc, Value value, Type srcType,
186  Type dstType, const TypeConverter &converter,
187  ConversionPatternRewriter &rewriter) {
188  Value broadcasted =
189  optionallyBroadcast(loc, value, srcType, converter, rewriter);
190  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
191 }
192 
193 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
194 /// offset to LLVM struct. Otherwise, the conversion is not supported.
196  const TypeConverter &converter) {
197  if (type != VulkanLayoutUtils::decorateType(type))
198  return nullptr;
199 
200  SmallVector<Type> elementsVector;
201  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
202  return nullptr;
203  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
204  /*isPacked=*/false);
205 }
206 
207 /// Converts SPIR-V struct with no offset to packed LLVM struct.
209  const TypeConverter &converter) {
210  SmallVector<Type> elementsVector;
211  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
212  return nullptr;
213  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
214  /*isPacked=*/true);
215 }
216 
217 /// Creates LLVM dialect constant with the given value.
219  unsigned value) {
220  return rewriter.create<LLVM::ConstantOp>(
221  loc, IntegerType::get(rewriter.getContext(), 32),
222  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
223 }
224 
225 /// Utility for `spirv.Load` and `spirv.Store` conversion.
226 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
227  ConversionPatternRewriter &rewriter,
228  const TypeConverter &typeConverter,
229  unsigned alignment, bool isVolatile,
230  bool isNonTemporal) {
231  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232  auto dstType = typeConverter.convertType(loadOp.getType());
233  if (!dstType)
234  return rewriter.notifyMatchFailure(op, "type conversion failed");
235  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237  isVolatile, isNonTemporal);
238  return success();
239  }
240  auto storeOp = cast<spirv::StoreOp>(op);
241  spirv::StoreOpAdaptor adaptor(operands);
242  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
243  adaptor.getPtr(), alignment,
244  isVolatile, isNonTemporal);
245  return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // Type conversion
250 //===----------------------------------------------------------------------===//
251 
252 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
253 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
254 /// when converting ops that manipulate array types.
255 static std::optional<Type> convertArrayType(spirv::ArrayType type,
256  TypeConverter &converter) {
257  unsigned stride = type.getArrayStride();
258  Type elementType = type.getElementType();
259  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
261  return std::nullopt;
262 
263  auto llvmElementType = converter.convertType(elementType);
264  unsigned numElements = type.getNumElements();
265  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
266 }
267 
268 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
269 /// modelled at the moment.
271  const TypeConverter &converter,
272  spirv::ClientAPI clientAPI) {
273  unsigned addressSpace =
274  storageClassToAddressSpace(clientAPI, type.getStorageClass());
275  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
276 }
277 
278 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
279 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
280 /// no modelling of array stride at the moment.
281 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
282  TypeConverter &converter) {
283  if (type.getArrayStride() != 0)
284  return std::nullopt;
285  auto elementType = converter.convertType(type.getElementType());
286  return LLVM::LLVMArrayType::get(elementType, 0);
287 }
288 
289 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
290 /// member decorations. Also, only natural offset is supported.
292  const TypeConverter &converter) {
294  type.getMemberDecorations(memberDecorations);
295  if (!memberDecorations.empty())
296  return nullptr;
297  if (type.hasOffset())
298  return convertStructTypeWithOffset(type, converter);
299  return convertStructTypePacked(type, converter);
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // Operation conversion
304 //===----------------------------------------------------------------------===//
305 
306 namespace {
307 
308 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
309 public:
311 
312  LogicalResult
313  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314  ConversionPatternRewriter &rewriter) const override {
315  auto dstType =
316  getTypeConverter()->convertType(op.getComponentPtr().getType());
317  if (!dstType)
318  return rewriter.notifyMatchFailure(op, "type conversion failed");
319  // To use GEP we need to add a first 0 index to go through the pointer.
320  auto indices = llvm::to_vector<4>(adaptor.getIndices());
321  Type indexType = op.getIndices().front().getType();
322  auto llvmIndexType = getTypeConverter()->convertType(indexType);
323  if (!llvmIndexType)
324  return rewriter.notifyMatchFailure(op, "type conversion failed");
325  Value zero = rewriter.create<LLVM::ConstantOp>(
326  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
327  indices.insert(indices.begin(), zero);
328 
329  auto elementType = getTypeConverter()->convertType(
330  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
331  if (!elementType)
332  return rewriter.notifyMatchFailure(op, "type conversion failed");
333  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
334  adaptor.getBasePtr(), indices);
335  return success();
336  }
337 };
338 
339 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
340 public:
342 
343  LogicalResult
344  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
345  ConversionPatternRewriter &rewriter) const override {
346  auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
347  if (!dstType)
348  return rewriter.notifyMatchFailure(op, "type conversion failed");
349  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
350  op.getVariable());
351  return success();
352  }
353 };
354 
355 class BitFieldInsertPattern
356  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
357 public:
359 
360  LogicalResult
361  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
362  ConversionPatternRewriter &rewriter) const override {
363  auto srcType = op.getType();
364  auto dstType = getTypeConverter()->convertType(srcType);
365  if (!dstType)
366  return rewriter.notifyMatchFailure(op, "type conversion failed");
367  Location loc = op.getLoc();
368 
369  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
370  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
371  *getTypeConverter(), rewriter);
372  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
373  *getTypeConverter(), rewriter);
374 
375  // Create a mask with bits set outside [Offset, Offset + Count - 1].
376  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
377  Value maskShiftedByCount =
378  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
379  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
380  maskShiftedByCount, minusOne);
381  Value maskShiftedByCountAndOffset =
382  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
383  Value mask = rewriter.create<LLVM::XOrOp>(
384  loc, dstType, maskShiftedByCountAndOffset, minusOne);
385 
386  // Extract unchanged bits from the `Base` that are outside of
387  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
388  Value baseAndMask =
389  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
390  Value insertShiftedByOffset =
391  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
392  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
393  insertShiftedByOffset);
394  return success();
395  }
396 };
397 
398 /// Converts SPIR-V ConstantOp with scalar or vector type.
399 class ConstantScalarAndVectorPattern
400  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
401 public:
403 
404  LogicalResult
405  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
406  ConversionPatternRewriter &rewriter) const override {
407  auto srcType = constOp.getType();
408  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
409  return failure();
410 
411  auto dstType = getTypeConverter()->convertType(srcType);
412  if (!dstType)
413  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
414 
415  // SPIR-V constant can be a signed/unsigned integer, which has to be
416  // casted to signless integer when converting to LLVM dialect. Removing the
417  // sign bit may have unexpected behaviour. However, it is better to handle
418  // it case-by-case, given that the purpose of the conversion is not to
419  // cover all possible corner cases.
420  if (isSignedIntegerOrVector(srcType) ||
421  isUnsignedIntegerOrVector(srcType)) {
422  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
423 
424  if (isa<VectorType>(srcType)) {
425  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
426  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
427  constOp, dstType,
428  dstElementsAttr.mapValues(
429  signlessType, [&](const APInt &value) { return value; }));
430  return success();
431  }
432  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
433  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
434  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
435  return success();
436  }
437  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
438  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
439  return success();
440  }
441 };
442 
443 class BitFieldSExtractPattern
444  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
445 public:
447 
448  LogicalResult
449  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
450  ConversionPatternRewriter &rewriter) const override {
451  auto srcType = op.getType();
452  auto dstType = getTypeConverter()->convertType(srcType);
453  if (!dstType)
454  return rewriter.notifyMatchFailure(op, "type conversion failed");
455  Location loc = op.getLoc();
456 
457  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
458  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
459  *getTypeConverter(), rewriter);
460  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
461  *getTypeConverter(), rewriter);
462 
463  // Create a constant that holds the size of the `Base`.
464  IntegerType integerType;
465  if (auto vecType = dyn_cast<VectorType>(srcType))
466  integerType = cast<IntegerType>(vecType.getElementType());
467  else
468  integerType = cast<IntegerType>(srcType);
469 
470  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
471  Value size =
472  isa<VectorType>(srcType)
473  ? rewriter.create<LLVM::ConstantOp>(
474  loc, dstType,
475  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
476  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
477 
478  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
479  // at Offset + Count - 1 is the most significant bit now.
480  Value countPlusOffset =
481  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
482  Value amountToShiftLeft =
483  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
484  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
485  loc, dstType, op.getBase(), amountToShiftLeft);
486 
487  // Shift the result right, filling the bits with the sign bit.
488  Value amountToShiftRight =
489  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
490  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
491  amountToShiftRight);
492  return success();
493  }
494 };
495 
496 class BitFieldUExtractPattern
497  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
498 public:
500 
501  LogicalResult
502  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
503  ConversionPatternRewriter &rewriter) const override {
504  auto srcType = op.getType();
505  auto dstType = getTypeConverter()->convertType(srcType);
506  if (!dstType)
507  return rewriter.notifyMatchFailure(op, "type conversion failed");
508  Location loc = op.getLoc();
509 
510  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
511  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
512  *getTypeConverter(), rewriter);
513  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
514  *getTypeConverter(), rewriter);
515 
516  // Create a mask with bits set at [0, Count - 1].
517  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
518  Value maskShiftedByCount =
519  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
520  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
521  minusOne);
522 
523  // Shift `Base` by `Offset` and apply the mask on it.
524  Value shiftedBase =
525  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
526  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
527  return success();
528  }
529 };
530 
531 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
532 public:
534 
535  LogicalResult
536  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
537  ConversionPatternRewriter &rewriter) const override {
538  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
539  branchOp.getTarget());
540  return success();
541  }
542 };
543 
544 class BranchConditionalConversionPattern
545  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
546 public:
547  using SPIRVToLLVMConversion<
548  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
549 
550  LogicalResult
551  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
552  ConversionPatternRewriter &rewriter) const override {
553  // If branch weights exist, map them to 32-bit integer vector.
554  DenseI32ArrayAttr branchWeights = nullptr;
555  if (auto weights = op.getBranchWeights()) {
556  SmallVector<int32_t> weightValues;
557  for (auto weight : weights->getAsRange<IntegerAttr>())
558  weightValues.push_back(weight.getInt());
559  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
560  }
561 
562  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
563  op, op.getCondition(), op.getTrueBlockArguments(),
564  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
565  op.getFalseBlock());
566  return success();
567  }
568 };
569 
570 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
571 /// type is an aggregate type (struct or array). Otherwise, converts to
572 /// `llvm.extractelement` that operates on vectors.
573 class CompositeExtractPattern
574  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
575 public:
577 
578  LogicalResult
579  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
580  ConversionPatternRewriter &rewriter) const override {
581  auto dstType = this->getTypeConverter()->convertType(op.getType());
582  if (!dstType)
583  return rewriter.notifyMatchFailure(op, "type conversion failed");
584 
585  Type containerType = op.getComposite().getType();
586  if (isa<VectorType>(containerType)) {
587  Location loc = op.getLoc();
588  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
589  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
590  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
591  op, dstType, adaptor.getComposite(), index);
592  return success();
593  }
594 
595  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
596  op, adaptor.getComposite(),
597  LLVM::convertArrayToIndices(op.getIndices()));
598  return success();
599  }
600 };
601 
602 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
603 /// type is an aggregate type (struct or array). Otherwise, converts to
604 /// `llvm.insertelement` that operates on vectors.
605 class CompositeInsertPattern
606  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
607 public:
609 
610  LogicalResult
611  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
612  ConversionPatternRewriter &rewriter) const override {
613  auto dstType = this->getTypeConverter()->convertType(op.getType());
614  if (!dstType)
615  return rewriter.notifyMatchFailure(op, "type conversion failed");
616 
617  Type containerType = op.getComposite().getType();
618  if (isa<VectorType>(containerType)) {
619  Location loc = op.getLoc();
620  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
621  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
622  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
623  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
624  return success();
625  }
626 
627  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
628  op, adaptor.getComposite(), adaptor.getObject(),
629  LLVM::convertArrayToIndices(op.getIndices()));
630  return success();
631  }
632 };
633 
634 /// Converts SPIR-V operations that have straightforward LLVM equivalent
635 /// into LLVM dialect operations.
636 template <typename SPIRVOp, typename LLVMOp>
637 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
638 public:
640 
641  LogicalResult
642  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
643  ConversionPatternRewriter &rewriter) const override {
644  auto dstType = this->getTypeConverter()->convertType(op.getType());
645  if (!dstType)
646  return rewriter.notifyMatchFailure(op, "type conversion failed");
647  rewriter.template replaceOpWithNewOp<LLVMOp>(
648  op, dstType, adaptor.getOperands(), op->getAttrs());
649  return success();
650  }
651 };
652 
653 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
654 /// execution mode information.
655 class ExecutionModePattern
656  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
657 public:
659 
660  LogicalResult
661  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
662  ConversionPatternRewriter &rewriter) const override {
663  // First, create the global struct's name that would be associated with
664  // this entry point's execution mode. We set it to be:
665  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
666  ModuleOp module = op->getParentOfType<ModuleOp>();
667  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
668  std::string moduleName;
669  if (module.getName().has_value())
670  moduleName = "_" + module.getName()->str();
671  else
672  moduleName = "";
673  std::string executionModeInfoName = llvm::formatv(
674  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
675  static_cast<uint32_t>(executionModeAttr.getValue()));
676 
677  MLIRContext *context = rewriter.getContext();
678  OpBuilder::InsertionGuard guard(rewriter);
679  rewriter.setInsertionPointToStart(module.getBody());
680 
681  // Create a struct type, corresponding to the C struct below.
682  // struct {
683  // int32_t executionMode;
684  // int32_t values[]; // optional values
685  // };
686  auto llvmI32Type = IntegerType::get(context, 32);
687  SmallVector<Type, 2> fields;
688  fields.push_back(llvmI32Type);
689  ArrayAttr values = op.getValues();
690  if (!values.empty()) {
691  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
692  fields.push_back(arrayType);
693  }
694  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
695 
696  // Create `llvm.mlir.global` with initializer region containing one block.
697  auto global = rewriter.create<LLVM::GlobalOp>(
698  UnknownLoc::get(context), structType, /*isConstant=*/true,
699  LLVM::Linkage::External, executionModeInfoName, Attribute(),
700  /*alignment=*/0);
701  Location loc = global.getLoc();
702  Region &region = global.getInitializerRegion();
703  Block *block = rewriter.createBlock(&region);
704 
705  // Initialize the struct and set the execution mode value.
706  rewriter.setInsertionPointToStart(block);
707  Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType);
708  Value executionMode = rewriter.create<LLVM::ConstantOp>(
709  loc, llvmI32Type,
710  rewriter.getI32IntegerAttr(
711  static_cast<uint32_t>(executionModeAttr.getValue())));
712  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
713  executionMode, 0);
714 
715  // Insert extra operands if they exist into execution mode info struct.
716  for (unsigned i = 0, e = values.size(); i < e; ++i) {
717  auto attr = values.getValue()[i];
718  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
719  structValue = rewriter.create<LLVM::InsertValueOp>(
720  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
721  }
722  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
723  rewriter.eraseOp(op);
724  return success();
725  }
726 };
727 
728 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
729 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
730 /// value. This difference is handled by `spirv.mlir.addressof` and
731 /// `llvm.mlir.addressof`ops that both return a pointer.
732 class GlobalVariablePattern
733  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
734 public:
735  template <typename... Args>
736  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
737  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
738  std::forward<Args>(args)...),
739  clientAPI(clientAPI) {}
740 
741  LogicalResult
742  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
743  ConversionPatternRewriter &rewriter) const override {
744  // Currently, there is no support of initialization with a constant value in
745  // SPIR-V dialect. Specialization constants are not considered as well.
746  if (op.getInitializer())
747  return failure();
748 
749  auto srcType = cast<spirv::PointerType>(op.getType());
750  auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
751  if (!dstType)
752  return rewriter.notifyMatchFailure(op, "type conversion failed");
753 
754  // Limit conversion to the current invocation only or `StorageBuffer`
755  // required by SPIR-V runner.
756  // This is okay because multiple invocations are not supported yet.
757  auto storageClass = srcType.getStorageClass();
758  switch (storageClass) {
759  case spirv::StorageClass::Input:
760  case spirv::StorageClass::Private:
761  case spirv::StorageClass::Output:
762  case spirv::StorageClass::StorageBuffer:
763  case spirv::StorageClass::UniformConstant:
764  break;
765  default:
766  return failure();
767  }
768 
769  // LLVM dialect spec: "If the global value is a constant, storing into it is
770  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
771  // storage class that is read-only.
772  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
773  (storageClass == spirv::StorageClass::UniformConstant);
774  // SPIR-V spec: "By default, functions and global variables are private to a
775  // module and cannot be accessed by other modules. However, a module may be
776  // written to export or import functions and global (module scope)
777  // variables.". Therefore, map 'Private' storage class to private linkage,
778  // 'Input' and 'Output' to external linkage.
779  auto linkage = storageClass == spirv::StorageClass::Private
780  ? LLVM::Linkage::Private
781  : LLVM::Linkage::External;
782  StringAttr locationAttrName = op.getLocationAttrName();
783  IntegerAttr locationAttr = op.getLocationAttr();
784  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
785  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
786  /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
787 
788  // Attach location attribute if applicable
789  if (locationAttr)
790  newGlobalOp->setAttr(locationAttrName, locationAttr);
791 
792  return success();
793  }
794 
795 private:
796  spirv::ClientAPI clientAPI;
797 };
798 
799 /// Converts SPIR-V cast ops that do not have straightforward LLVM
800 /// equivalent in LLVM dialect.
801 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
802 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
803 public:
805 
806  LogicalResult
807  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
808  ConversionPatternRewriter &rewriter) const override {
809 
810  Type fromType = op.getOperand().getType();
811  Type toType = op.getType();
812 
813  auto dstType = this->getTypeConverter()->convertType(toType);
814  if (!dstType)
815  return rewriter.notifyMatchFailure(op, "type conversion failed");
816 
817  if (getBitWidth(fromType) < getBitWidth(toType)) {
818  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
819  adaptor.getOperands());
820  return success();
821  }
822  if (getBitWidth(fromType) > getBitWidth(toType)) {
823  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
824  adaptor.getOperands());
825  return success();
826  }
827  return failure();
828  }
829 };
830 
831 class FunctionCallPattern
832  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
833 public:
835 
836  LogicalResult
837  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
838  ConversionPatternRewriter &rewriter) const override {
839  if (callOp.getNumResults() == 0) {
840  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
841  callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
842  newOp.getProperties().operandSegmentSizes = {
843  static_cast<int32_t>(adaptor.getOperands().size()), 0};
844  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
845  return success();
846  }
847 
848  // Function returns a single result.
849  auto dstType = getTypeConverter()->convertType(callOp.getType(0));
850  if (!dstType)
851  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
852  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
853  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
854  newOp.getProperties().operandSegmentSizes = {
855  static_cast<int32_t>(adaptor.getOperands().size()), 0};
856  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
857  return success();
858  }
859 };
860 
861 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
862 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
863 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
864 public:
866 
867  LogicalResult
868  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
869  ConversionPatternRewriter &rewriter) const override {
870 
871  auto dstType = this->getTypeConverter()->convertType(op.getType());
872  if (!dstType)
873  return rewriter.notifyMatchFailure(op, "type conversion failed");
874 
875  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
876  op, dstType, predicate, op.getOperand1(), op.getOperand2());
877  return success();
878  }
879 };
880 
881 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
882 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
883 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
884 public:
886 
887  LogicalResult
888  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
889  ConversionPatternRewriter &rewriter) const override {
890 
891  auto dstType = this->getTypeConverter()->convertType(op.getType());
892  if (!dstType)
893  return rewriter.notifyMatchFailure(op, "type conversion failed");
894 
895  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
896  op, dstType, predicate, op.getOperand1(), op.getOperand2());
897  return success();
898  }
899 };
900 
901 class InverseSqrtPattern
902  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
903 public:
905 
906  LogicalResult
907  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
908  ConversionPatternRewriter &rewriter) const override {
909  auto srcType = op.getType();
910  auto dstType = getTypeConverter()->convertType(srcType);
911  if (!dstType)
912  return rewriter.notifyMatchFailure(op, "type conversion failed");
913 
914  Location loc = op.getLoc();
915  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
916  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
917  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
918  return success();
919  }
920 };
921 
922 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
923 template <typename SPIRVOp>
924 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
925 public:
927 
928  LogicalResult
929  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
930  ConversionPatternRewriter &rewriter) const override {
931  if (!op.getMemoryAccess()) {
932  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
933  *this->getTypeConverter(), /*alignment=*/0,
934  /*isVolatile=*/false,
935  /*isNonTemporal=*/false);
936  }
937  auto memoryAccess = *op.getMemoryAccess();
938  switch (memoryAccess) {
939  case spirv::MemoryAccess::Aligned:
941  case spirv::MemoryAccess::Nontemporal:
942  case spirv::MemoryAccess::Volatile: {
943  unsigned alignment =
944  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
945  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
946  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
947  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
948  *this->getTypeConverter(), alignment,
949  isVolatile, isNonTemporal);
950  }
951  default:
952  // There is no support of other memory access attributes.
953  return failure();
954  }
955  }
956 };
957 
958 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
959 template <typename SPIRVOp>
960 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
961 public:
963 
964  LogicalResult
965  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
966  ConversionPatternRewriter &rewriter) const override {
967  auto srcType = notOp.getType();
968  auto dstType = this->getTypeConverter()->convertType(srcType);
969  if (!dstType)
970  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
971 
972  Location loc = notOp.getLoc();
973  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
974  auto mask =
975  isa<VectorType>(srcType)
976  ? rewriter.create<LLVM::ConstantOp>(
977  loc, dstType,
978  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
979  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
980  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
981  notOp.getOperand(), mask);
982  return success();
983  }
984 };
985 
986 /// A template pattern that erases the given `SPIRVOp`.
987 template <typename SPIRVOp>
988 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
989 public:
991 
992  LogicalResult
993  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
994  ConversionPatternRewriter &rewriter) const override {
995  rewriter.eraseOp(op);
996  return success();
997  }
998 };
999 
1000 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1001 public:
1003 
1004  LogicalResult
1005  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1006  ConversionPatternRewriter &rewriter) const override {
1007  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1008  ArrayRef<Value>());
1009  return success();
1010  }
1011 };
1012 
1013 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1014 public:
1016 
1017  LogicalResult
1018  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1019  ConversionPatternRewriter &rewriter) const override {
1020  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1021  adaptor.getOperands());
1022  return success();
1023  }
1024 };
1025 
1026 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1027  StringRef name,
1028  ArrayRef<Type> paramTypes,
1029  Type resultType,
1030  bool convergent = true) {
1031  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1032  SymbolTable::lookupSymbolIn(symbolTable, name));
1033  if (func)
1034  return func;
1035 
1036  OpBuilder b(symbolTable->getRegion(0));
1037  func = b.create<LLVM::LLVMFuncOp>(
1038  symbolTable->getLoc(), name,
1039  LLVM::LLVMFunctionType::get(resultType, paramTypes));
1040  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1041  func.setConvergent(convergent);
1042  func.setNoUnwind(true);
1043  func.setWillReturn(true);
1044  return func;
1045 }
1046 
1047 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1048  LLVM::LLVMFuncOp func,
1049  ValueRange args) {
1050  auto call = builder.create<LLVM::CallOp>(loc, func, args);
1051  call.setCConv(func.getCConv());
1052  call.setConvergentAttr(func.getConvergentAttr());
1053  call.setNoUnwindAttr(func.getNoUnwindAttr());
1054  call.setWillReturnAttr(func.getWillReturnAttr());
1055  return call;
1056 }
1057 
1058 template <typename BarrierOpTy>
1059 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1060 public:
1061  using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1062 
1064 
1065  static constexpr StringRef getFuncName();
1066 
1067  LogicalResult
1068  matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1069  ConversionPatternRewriter &rewriter) const override {
1070  constexpr StringRef funcName = getFuncName();
1071  Operation *symbolTable =
1072  controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1073 
1074  Type i32 = rewriter.getI32Type();
1075 
1076  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1077  LLVM::LLVMFuncOp func =
1078  lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1079 
1080  Location loc = controlBarrierOp->getLoc();
1081  Value execution = rewriter.create<LLVM::ConstantOp>(
1082  loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1083  Value memory = rewriter.create<LLVM::ConstantOp>(
1084  loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1085  Value semantics = rewriter.create<LLVM::ConstantOp>(
1086  loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1087 
1088  auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1089  {execution, memory, semantics});
1090 
1091  rewriter.replaceOp(controlBarrierOp, call);
1092  return success();
1093  }
1094 };
1095 
1096 namespace {
1097 
1098 StringRef getTypeMangling(Type type, bool isSigned) {
1100  .Case<Float16Type>([](auto) { return "Dh"; })
1101  .Case<Float32Type>([](auto) { return "f"; })
1102  .Case<Float64Type>([](auto) { return "d"; })
1103  .Case<IntegerType>([isSigned](IntegerType intTy) {
1104  switch (intTy.getWidth()) {
1105  case 1:
1106  return "b";
1107  case 8:
1108  return (isSigned) ? "a" : "c";
1109  case 16:
1110  return (isSigned) ? "s" : "t";
1111  case 32:
1112  return (isSigned) ? "i" : "j";
1113  case 64:
1114  return (isSigned) ? "l" : "m";
1115  default:
1116  llvm_unreachable("Unsupported integer width");
1117  }
1118  })
1119  .Default([](auto) {
1120  llvm_unreachable("No mangling defined");
1121  return "";
1122  });
1123 }
1124 
1125 template <typename ReduceOp>
1126 constexpr StringLiteral getGroupFuncName();
1127 
1128 template <>
1129 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1130  return "_Z17__spirv_GroupIAddii";
1131 }
1132 template <>
1133 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1134  return "_Z17__spirv_GroupFAddii";
1135 }
1136 template <>
1137 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1138  return "_Z17__spirv_GroupSMinii";
1139 }
1140 template <>
1141 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1142  return "_Z17__spirv_GroupUMinii";
1143 }
1144 template <>
1145 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1146  return "_Z17__spirv_GroupFMinii";
1147 }
1148 template <>
1149 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1150  return "_Z17__spirv_GroupSMaxii";
1151 }
1152 template <>
1153 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1154  return "_Z17__spirv_GroupUMaxii";
1155 }
1156 template <>
1157 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1158  return "_Z17__spirv_GroupFMaxii";
1159 }
1160 template <>
1161 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1162  return "_Z27__spirv_GroupNonUniformIAddii";
1163 }
1164 template <>
1165 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1166  return "_Z27__spirv_GroupNonUniformFAddii";
1167 }
1168 template <>
1169 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1170  return "_Z27__spirv_GroupNonUniformIMulii";
1171 }
1172 template <>
1173 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1174  return "_Z27__spirv_GroupNonUniformFMulii";
1175 }
1176 template <>
1177 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1178  return "_Z27__spirv_GroupNonUniformSMinii";
1179 }
1180 template <>
1181 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1182  return "_Z27__spirv_GroupNonUniformUMinii";
1183 }
1184 template <>
1185 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1186  return "_Z27__spirv_GroupNonUniformFMinii";
1187 }
1188 template <>
1189 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1190  return "_Z27__spirv_GroupNonUniformSMaxii";
1191 }
1192 template <>
1193 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1194  return "_Z27__spirv_GroupNonUniformUMaxii";
1195 }
1196 template <>
1197 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1198  return "_Z27__spirv_GroupNonUniformFMaxii";
1199 }
1200 template <>
1201 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1202  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1203 }
1204 template <>
1205 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1206  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1207 }
1208 template <>
1209 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1210  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1211 }
1212 template <>
1213 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1214  return "_Z33__spirv_GroupNonUniformLogicalAndii";
1215 }
1216 template <>
1217 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1218  return "_Z32__spirv_GroupNonUniformLogicalOrii";
1219 }
1220 template <>
1221 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1222  return "_Z33__spirv_GroupNonUniformLogicalXorii";
1223 }
1224 } // namespace
1225 
1226 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1227 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1228 public:
1230 
1231  LogicalResult
1232  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1233  ConversionPatternRewriter &rewriter) const override {
1234 
1235  Type retTy = op.getResult().getType();
1236  if (!retTy.isIntOrFloat()) {
1237  return failure();
1238  }
1239  SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1240  funcName += getTypeMangling(retTy, false);
1241 
1242  Type i32Ty = rewriter.getI32Type();
1243  SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1244  if constexpr (NonUniform) {
1245  if (adaptor.getClusterSize()) {
1246  funcName += "j";
1247  paramTypes.push_back(i32Ty);
1248  }
1249  }
1250 
1251  Operation *symbolTable =
1252  op->template getParentWithTrait<OpTrait::SymbolTable>();
1253 
1254  LLVM::LLVMFuncOp func =
1255  lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1256 
1257  Location loc = op.getLoc();
1258  Value scope = rewriter.create<LLVM::ConstantOp>(
1259  loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1260  Value groupOp = rewriter.create<LLVM::ConstantOp>(
1261  loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1262  SmallVector<Value> operands{scope, groupOp};
1263  operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1264 
1265  auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1266  rewriter.replaceOp(op, call);
1267  return success();
1268  }
1269 };
1270 
1271 template <>
1272 constexpr StringRef
1273 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1274  return "_Z22__spirv_ControlBarrieriii";
1275 }
1276 
1277 template <>
1278 constexpr StringRef
1279 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1280  return "_Z33__spirv_ControlBarrierArriveINTELiii";
1281 }
1282 
1283 template <>
1284 constexpr StringRef
1285 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1286  return "_Z31__spirv_ControlBarrierWaitINTELiii";
1287 }
1288 
1289 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1290 /// should be reachable for conversion to succeed. The structure of the loop in
1291 /// LLVM dialect will be the following:
1292 ///
1293 /// +------------------------------------+
1294 /// | <code before spirv.mlir.loop> |
1295 /// | llvm.br ^header |
1296 /// +------------------------------------+
1297 /// |
1298 /// +----------------+ |
1299 /// | | |
1300 /// | V V
1301 /// | +------------------------------------+
1302 /// | | ^header: |
1303 /// | | <header code> |
1304 /// | | llvm.cond_br %cond, ^body, ^exit |
1305 /// | +------------------------------------+
1306 /// | |
1307 /// | |----------------------+
1308 /// | | |
1309 /// | V |
1310 /// | +------------------------------------+ |
1311 /// | | ^body: | |
1312 /// | | <body code> | |
1313 /// | | llvm.br ^continue | |
1314 /// | +------------------------------------+ |
1315 /// | | |
1316 /// | V |
1317 /// | +------------------------------------+ |
1318 /// | | ^continue: | |
1319 /// | | <continue code> | |
1320 /// | | llvm.br ^header | |
1321 /// | +------------------------------------+ |
1322 /// | | |
1323 /// +---------------+ +----------------------+
1324 /// |
1325 /// V
1326 /// +------------------------------------+
1327 /// | ^exit: |
1328 /// | llvm.br ^remaining |
1329 /// +------------------------------------+
1330 /// |
1331 /// V
1332 /// +------------------------------------+
1333 /// | ^remaining: |
1334 /// | <code after spirv.mlir.loop> |
1335 /// +------------------------------------+
1336 ///
1337 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1338 public:
1340 
1341  LogicalResult
1342  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1343  ConversionPatternRewriter &rewriter) const override {
1344  // There is no support of loop control at the moment.
1345  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1346  return failure();
1347 
1348  // `spirv.mlir.loop` with empty region is redundant and should be erased.
1349  if (loopOp.getBody().empty()) {
1350  rewriter.eraseOp(loopOp);
1351  return success();
1352  }
1353 
1354  Location loc = loopOp.getLoc();
1355 
1356  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1357  // be used in `endBlock`.
1358  Block *currentBlock = rewriter.getBlock();
1359  auto position = Block::iterator(loopOp);
1360  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1361 
1362  // Remove entry block and create a branch in the current block going to the
1363  // header block.
1364  Block *entryBlock = loopOp.getEntryBlock();
1365  assert(entryBlock->getOperations().size() == 1);
1366  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1367  if (!brOp)
1368  return failure();
1369  Block *headerBlock = loopOp.getHeaderBlock();
1370  rewriter.setInsertionPointToEnd(currentBlock);
1371  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1372  rewriter.eraseBlock(entryBlock);
1373 
1374  // Branch from merge block to end block.
1375  Block *mergeBlock = loopOp.getMergeBlock();
1376  Operation *terminator = mergeBlock->getTerminator();
1377  ValueRange terminatorOperands = terminator->getOperands();
1378  rewriter.setInsertionPointToEnd(mergeBlock);
1379  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1380 
1381  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1382  rewriter.replaceOp(loopOp, endBlock->getArguments());
1383  return success();
1384  }
1385 };
1386 
1387 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1388 /// block. All blocks within selection should be reachable for conversion to
1389 /// succeed.
1390 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1391 public:
1393 
1394  LogicalResult
1395  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1396  ConversionPatternRewriter &rewriter) const override {
1397  // There is no support for `Flatten` or `DontFlatten` selection control at
1398  // the moment. This are just compiler hints and can be performed during the
1399  // optimization passes.
1400  if (op.getSelectionControl() != spirv::SelectionControl::None)
1401  return failure();
1402 
1403  // `spirv.mlir.selection` should have at least two blocks: one selection
1404  // header block and one merge block. If no blocks are present, or control
1405  // flow branches straight to merge block (two blocks are present), the op is
1406  // redundant and it is erased.
1407  if (op.getBody().getBlocks().size() <= 2) {
1408  rewriter.eraseOp(op);
1409  return success();
1410  }
1411 
1412  Location loc = op.getLoc();
1413 
1414  // Split the current block after `spirv.mlir.selection`. The remaining ops
1415  // will be used in `continueBlock`.
1416  auto *currentBlock = rewriter.getInsertionBlock();
1417  rewriter.setInsertionPointAfter(op);
1418  auto position = rewriter.getInsertionPoint();
1419  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1420 
1421  // Extract conditional branch information from the header block. By SPIR-V
1422  // dialect spec, it should contain `spirv.BranchConditional` or
1423  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1424  // moment in the SPIR-V dialect. Remove this block when finished.
1425  auto *headerBlock = op.getHeaderBlock();
1426  assert(headerBlock->getOperations().size() == 1);
1427  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1428  headerBlock->getOperations().front());
1429  if (!condBrOp)
1430  return failure();
1431 
1432  // Branch from merge block to continue block.
1433  auto *mergeBlock = op.getMergeBlock();
1434  Operation *terminator = mergeBlock->getTerminator();
1435  ValueRange terminatorOperands = terminator->getOperands();
1436  rewriter.setInsertionPointToEnd(mergeBlock);
1437  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1438 
1439  // Link current block to `true` and `false` blocks within the selection.
1440  Block *trueBlock = condBrOp.getTrueBlock();
1441  Block *falseBlock = condBrOp.getFalseBlock();
1442  rewriter.setInsertionPointToEnd(currentBlock);
1443  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1444  condBrOp.getTrueTargetOperands(),
1445  falseBlock,
1446  condBrOp.getFalseTargetOperands());
1447 
1448  rewriter.eraseBlock(headerBlock);
1449  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1450  rewriter.replaceOp(op, continueBlock->getArguments());
1451  return success();
1452  }
1453 };
1454 
1455 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1456 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1457 /// `Shift` is zero or sign extended to match this specification. Cases when
1458 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1459 template <typename SPIRVOp, typename LLVMOp>
1460 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1461 public:
1463 
1464  LogicalResult
1465  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1466  ConversionPatternRewriter &rewriter) const override {
1467 
1468  auto dstType = this->getTypeConverter()->convertType(op.getType());
1469  if (!dstType)
1470  return rewriter.notifyMatchFailure(op, "type conversion failed");
1471 
1472  Type op1Type = op.getOperand1().getType();
1473  Type op2Type = op.getOperand2().getType();
1474 
1475  if (op1Type == op2Type) {
1476  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477  adaptor.getOperands());
1478  return success();
1479  }
1480 
1481  std::optional<uint64_t> dstTypeWidth =
1483  std::optional<uint64_t> op2TypeWidth =
1485 
1486  if (!dstTypeWidth || !op2TypeWidth)
1487  return failure();
1488 
1489  Location loc = op.getLoc();
1490  Value extended;
1491  if (op2TypeWidth < dstTypeWidth) {
1492  if (isUnsignedIntegerOrVector(op2Type)) {
1493  extended = rewriter.template create<LLVM::ZExtOp>(
1494  loc, dstType, adaptor.getOperand2());
1495  } else {
1496  extended = rewriter.template create<LLVM::SExtOp>(
1497  loc, dstType, adaptor.getOperand2());
1498  }
1499  } else if (op2TypeWidth == dstTypeWidth) {
1500  extended = adaptor.getOperand2();
1501  } else {
1502  return failure();
1503  }
1504 
1505  Value result = rewriter.template create<LLVMOp>(
1506  loc, dstType, adaptor.getOperand1(), extended);
1507  rewriter.replaceOp(op, result);
1508  return success();
1509  }
1510 };
1511 
1512 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1513 public:
1515 
1516  LogicalResult
1517  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1518  ConversionPatternRewriter &rewriter) const override {
1519  auto dstType = getTypeConverter()->convertType(tanOp.getType());
1520  if (!dstType)
1521  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1522 
1523  Location loc = tanOp.getLoc();
1524  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1525  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1526  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1527  return success();
1528  }
1529 };
1530 
1531 /// Convert `spirv.Tanh` to
1532 ///
1533 /// exp(2x) - 1
1534 /// -----------
1535 /// exp(2x) + 1
1536 ///
1537 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1538 public:
1540 
1541  LogicalResult
1542  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1543  ConversionPatternRewriter &rewriter) const override {
1544  auto srcType = tanhOp.getType();
1545  auto dstType = getTypeConverter()->convertType(srcType);
1546  if (!dstType)
1547  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1548 
1549  Location loc = tanhOp.getLoc();
1550  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1551  Value multiplied =
1552  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1553  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1554  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1555  Value numerator =
1556  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1557  Value denominator =
1558  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1559  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1560  denominator);
1561  return success();
1562  }
1563 };
1564 
1565 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1566 public:
1568 
1569  LogicalResult
1570  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1571  ConversionPatternRewriter &rewriter) const override {
1572  auto srcType = varOp.getType();
1573  // Initialization is supported for scalars and vectors only.
1574  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1575  auto init = varOp.getInitializer();
1576  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1577  return failure();
1578 
1579  auto dstType = getTypeConverter()->convertType(srcType);
1580  if (!dstType)
1581  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1582 
1583  Location loc = varOp.getLoc();
1584  Value size = createI32ConstantOf(loc, rewriter, 1);
1585  if (!init) {
1586  auto elementType = getTypeConverter()->convertType(pointerTo);
1587  if (!elementType)
1588  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1589  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1590  size);
1591  return success();
1592  }
1593  auto elementType = getTypeConverter()->convertType(pointerTo);
1594  if (!elementType)
1595  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1596  Value allocated =
1597  rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1598  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1599  rewriter.replaceOp(varOp, allocated);
1600  return success();
1601  }
1602 };
1603 
1604 //===----------------------------------------------------------------------===//
1605 // BitcastOp conversion
1606 //===----------------------------------------------------------------------===//
1607 
1608 class BitcastConversionPattern
1609  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1610 public:
1612 
1613  LogicalResult
1614  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1615  ConversionPatternRewriter &rewriter) const override {
1616  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1617  if (!dstType)
1618  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1619 
1620  // LLVM's opaque pointers do not require bitcasts.
1621  if (isa<LLVM::LLVMPointerType>(dstType)) {
1622  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1623  return success();
1624  }
1625 
1626  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1627  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1628  return success();
1629  }
1630 };
1631 
1632 //===----------------------------------------------------------------------===//
1633 // FuncOp conversion
1634 //===----------------------------------------------------------------------===//
1635 
1636 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1637 public:
1639 
1640  LogicalResult
1641  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1642  ConversionPatternRewriter &rewriter) const override {
1643 
1644  // Convert function signature. At the moment LLVMType converter is enough
1645  // for currently supported types.
1646  auto funcType = funcOp.getFunctionType();
1647  TypeConverter::SignatureConversion signatureConverter(
1648  funcType.getNumInputs());
1649  auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1650  ->convertFunctionSignature(
1651  funcType, /*isVariadic=*/false,
1652  /*useBarePtrCallConv=*/false, signatureConverter);
1653  if (!llvmType)
1654  return failure();
1655 
1656  // Create a new `LLVMFuncOp`
1657  Location loc = funcOp.getLoc();
1658  StringRef name = funcOp.getName();
1659  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1660 
1661  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1662  MLIRContext *context = funcOp.getContext();
1663  switch (funcOp.getFunctionControl()) {
1664  case spirv::FunctionControl::Inline:
1665  newFuncOp.setAlwaysInline(true);
1666  break;
1667  case spirv::FunctionControl::DontInline:
1668  newFuncOp.setNoInline(true);
1669  break;
1670 
1671 #define DISPATCH(functionControl, llvmAttr) \
1672  case functionControl: \
1673  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1674  break;
1675 
1676  DISPATCH(spirv::FunctionControl::Pure,
1677  StringAttr::get(context, "readonly"));
1678  DISPATCH(spirv::FunctionControl::Const,
1679  StringAttr::get(context, "readnone"));
1680 
1681 #undef DISPATCH
1682 
1683  // Default: if `spirv::FunctionControl::None`, then no attributes are
1684  // needed.
1685  default:
1686  break;
1687  }
1688 
1689  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1690  newFuncOp.end());
1691  if (failed(rewriter.convertRegionTypes(
1692  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1693  return failure();
1694  }
1695  rewriter.eraseOp(funcOp);
1696  return success();
1697  }
1698 };
1699 
1700 //===----------------------------------------------------------------------===//
1701 // ModuleOp conversion
1702 //===----------------------------------------------------------------------===//
1703 
1704 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1705 public:
1707 
1708  LogicalResult
1709  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1710  ConversionPatternRewriter &rewriter) const override {
1711 
1712  auto newModuleOp =
1713  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1714  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1715 
1716  // Remove the terminator block that was automatically added by builder
1717  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1718  rewriter.eraseOp(spvModuleOp);
1719  return success();
1720  }
1721 };
1722 
1723 //===----------------------------------------------------------------------===//
1724 // VectorShuffleOp conversion
1725 //===----------------------------------------------------------------------===//
1726 
1727 class VectorShufflePattern
1728  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1729 public:
1731  LogicalResult
1732  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1733  ConversionPatternRewriter &rewriter) const override {
1734  Location loc = op.getLoc();
1735  auto components = adaptor.getComponents();
1736  auto vector1 = adaptor.getVector1();
1737  auto vector2 = adaptor.getVector2();
1738  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1739  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1740  if (vector1Size == vector2Size) {
1741  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1742  op, vector1, vector2,
1743  LLVM::convertArrayToIndices<int32_t>(components));
1744  return success();
1745  }
1746 
1747  auto dstType = getTypeConverter()->convertType(op.getType());
1748  if (!dstType)
1749  return rewriter.notifyMatchFailure(op, "type conversion failed");
1750  auto scalarType = cast<VectorType>(dstType).getElementType();
1751  auto componentsArray = components.getValue();
1752  auto *context = rewriter.getContext();
1753  auto llvmI32Type = IntegerType::get(context, 32);
1754  Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType);
1755  for (unsigned i = 0; i < componentsArray.size(); i++) {
1756  if (!isa<IntegerAttr>(componentsArray[i]))
1757  return op.emitError("unable to support non-constant component");
1758 
1759  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1760  if (indexVal == -1)
1761  continue;
1762 
1763  int offsetVal = 0;
1764  Value baseVector = vector1;
1765  if (indexVal >= vector1Size) {
1766  offsetVal = vector1Size;
1767  baseVector = vector2;
1768  }
1769 
1770  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1771  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1772  Value index = rewriter.create<LLVM::ConstantOp>(
1773  loc, llvmI32Type,
1774  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1775 
1776  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1777  loc, scalarType, baseVector, index);
1778  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1779  extractOp, dstIndex);
1780  }
1781  rewriter.replaceOp(op, targetOp);
1782  return success();
1783  }
1784 };
1785 } // namespace
1786 
1787 //===----------------------------------------------------------------------===//
1788 // Pattern population
1789 //===----------------------------------------------------------------------===//
1790 
1792  spirv::ClientAPI clientAPI) {
1793  typeConverter.addConversion([&](spirv::ArrayType type) {
1794  return convertArrayType(type, typeConverter);
1795  });
1796  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1797  return convertPointerType(type, typeConverter, clientAPI);
1798  });
1799  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1800  return convertRuntimeArrayType(type, typeConverter);
1801  });
1802  typeConverter.addConversion([&](spirv::StructType type) {
1803  return convertStructType(type, typeConverter);
1804  });
1805 }
1806 
1808  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1809  spirv::ClientAPI clientAPI) {
1810  patterns.add<
1811  // Arithmetic ops
1812  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1813  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1814  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1815  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1816  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1817  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1818  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1819  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1820  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1821  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1822  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1823  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1824  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1825 
1826  // Bitwise ops
1827  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1828  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1829  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1830  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1831  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1832  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1833  NotPattern<spirv::NotOp>,
1834 
1835  // Cast ops
1836  BitcastConversionPattern,
1837  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1838  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1839  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1840  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1841  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1842  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1843  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1844 
1845  // Comparison ops
1846  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1847  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1848  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1849  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1850  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1851  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1852  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1853  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1854  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1855  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1856  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1857  LLVM::FCmpPredicate::uge>,
1858  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1859  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1860  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1861  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1862  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1863  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1864  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1865  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1866  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1867  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1868  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1869 
1870  // Constant op
1871  ConstantScalarAndVectorPattern,
1872 
1873  // Control Flow ops
1874  BranchConversionPattern, BranchConditionalConversionPattern,
1875  FunctionCallPattern, LoopPattern, SelectionPattern,
1876  ErasePattern<spirv::MergeOp>,
1877 
1878  // Entry points and execution mode are handled separately.
1879  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1880 
1881  // GLSL extended instruction set ops
1882  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1883  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1884  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1885  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1886  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1887  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1888  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1889  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1890  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1891  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1892  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1893  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1894  InverseSqrtPattern, TanPattern, TanhPattern,
1895 
1896  // Logical ops
1897  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1898  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1899  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1900  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1901  NotPattern<spirv::LogicalNotOp>,
1902 
1903  // Memory ops
1904  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1905  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1906 
1907  // Miscellaneous ops
1908  CompositeExtractPattern, CompositeInsertPattern,
1909  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1910  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1911  VectorShufflePattern,
1912 
1913  // Shift ops
1914  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1915  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1916  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1917 
1918  // Return ops
1919  ReturnPattern, ReturnValuePattern,
1920 
1921  // Barrier ops
1922  ControlBarrierPattern<spirv::ControlBarrierOp>,
1923  ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1924  ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1925 
1926  // Group reduction operations
1927  GroupReducePattern<spirv::GroupIAddOp>,
1928  GroupReducePattern<spirv::GroupFAddOp>,
1929  GroupReducePattern<spirv::GroupFMinOp>,
1930  GroupReducePattern<spirv::GroupUMinOp>,
1931  GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1932  GroupReducePattern<spirv::GroupFMaxOp>,
1933  GroupReducePattern<spirv::GroupUMaxOp>,
1934  GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1935  GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1936  /*NonUniform=*/true>,
1937  GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1938  /*NonUniform=*/true>,
1939  GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1940  /*NonUniform=*/true>,
1941  GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1942  /*NonUniform=*/true>,
1943  GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1944  /*NonUniform=*/true>,
1945  GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1946  /*NonUniform=*/true>,
1947  GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1948  /*NonUniform=*/true>,
1949  GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1950  /*NonUniform=*/true>,
1951  GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1952  /*NonUniform=*/true>,
1953  GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1954  /*NonUniform=*/true>,
1955  GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1956  /*NonUniform=*/true>,
1957  GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1958  /*NonUniform=*/true>,
1959  GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1960  /*NonUniform=*/true>,
1961  GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1962  /*NonUniform=*/true>,
1963  GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1964  /*NonUniform=*/true>,
1965  GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1966  /*NonUniform=*/true>>(patterns.getContext(),
1967  typeConverter);
1968 
1969  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1970  typeConverter);
1971 }
1972 
1974  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1975  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1976 }
1977 
1979  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1980  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1981 }
1982 
1983 //===----------------------------------------------------------------------===//
1984 // Pre-conversion hooks
1985 //===----------------------------------------------------------------------===//
1986 
1987 /// Hook for descriptor set and binding number encoding.
1988 static constexpr StringRef kBinding = "binding";
1989 static constexpr StringRef kDescriptorSet = "descriptor_set";
1990 void mlir::encodeBindAttribute(ModuleOp module) {
1991  auto spvModules = module.getOps<spirv::ModuleOp>();
1992  for (auto spvModule : spvModules) {
1993  spvModule.walk([&](spirv::GlobalVariableOp op) {
1994  IntegerAttr descriptorSet =
1995  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1996  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1997  // For every global variable in the module, get the ones with descriptor
1998  // set and binding numbers.
1999  if (descriptorSet && binding) {
2000  // Encode these numbers into the variable's symbolic name. If the
2001  // SPIR-V module has a name, add it at the beginning.
2002  auto moduleAndName =
2003  spvModule.getName().has_value()
2004  ? spvModule.getName()->str() + "_" + op.getSymName().str()
2005  : op.getSymName().str();
2006  std::string name =
2007  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2008  std::to_string(descriptorSet.getInt()),
2009  std::to_string(binding.getInt()));
2010  auto nameAttr = StringAttr::get(op->getContext(), name);
2011 
2012  // Replace all symbol uses and set the new symbol name. Finally, remove
2013  // descriptor set and binding attributes.
2014  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2015  op.emitError("unable to replace all symbol uses for ") << name;
2016  SymbolTable::setSymbolName(op, nameAttr);
2017  op->removeAttr(kDescriptorSet);
2018  op->removeAttr(kBinding);
2019  }
2020  });
2021  }
2022 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:64
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:35
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:44
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:84
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:94
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:77
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:54
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:446
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:767
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:700
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:455
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:516
SPIR-V struct type.
Definition: SPIRVTypes.h:295
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:795
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:232
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.