MLIR  21.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Returns true if the given type is a signed integer or vector type.
35 static bool isSignedIntegerOrVector(Type type) {
36  if (type.isSignedInteger())
37  return true;
38  if (auto vecType = dyn_cast<VectorType>(type))
39  return vecType.getElementType().isSignedInteger();
40  return false;
41 }
42 
43 /// Returns true if the given type is an unsigned integer or vector type
44 static bool isUnsignedIntegerOrVector(Type type) {
45  if (type.isUnsignedInteger())
46  return true;
47  if (auto vecType = dyn_cast<VectorType>(type))
48  return vecType.getElementType().isUnsignedInteger();
49  return false;
50 }
51 
52 /// Returns the width of an integer or of the element type of an integer vector,
53 /// if applicable.
54 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
55  if (auto intType = dyn_cast<IntegerType>(type))
56  return intType.getWidth();
57  if (auto vecType = dyn_cast<VectorType>(type))
58  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59  return intType.getWidth();
60  return std::nullopt;
61 }
62 
63 /// Returns the bit width of integer, float or vector of float or integer values
64 static unsigned getBitWidth(Type type) {
65  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
66  "bitwidth is not supported for this type");
67  if (type.isIntOrFloat())
68  return type.getIntOrFloatBitWidth();
69  auto vecType = dyn_cast<VectorType>(type);
70  auto elementType = vecType.getElementType();
71  assert(elementType.isIntOrFloat() &&
72  "only integers and floats have a bitwidth");
73  return elementType.getIntOrFloatBitWidth();
74 }
75 
76 /// Returns the bit width of LLVMType integer or vector.
77 static unsigned getLLVMTypeBitWidth(Type type) {
78  if (auto vecTy = dyn_cast<VectorType>(type))
79  type = vecTy.getElementType();
80  return cast<IntegerType>(type).getWidth();
81 }
82 
83 /// Creates `IntegerAttribute` with all bits set for given type
84 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
85  if (auto vecType = dyn_cast<VectorType>(type)) {
86  auto integerType = cast<IntegerType>(vecType.getElementType());
87  return builder.getIntegerAttr(integerType, -1);
88  }
89  auto integerType = cast<IntegerType>(type);
90  return builder.getIntegerAttr(integerType, -1);
91 }
92 
93 /// Creates `llvm.mlir.constant` with all bits set for the given type.
94 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
95  PatternRewriter &rewriter) {
96  if (isa<VectorType>(srcType)) {
97  return rewriter.create<LLVM::ConstantOp>(
98  loc, dstType,
99  SplatElementsAttr::get(cast<ShapedType>(srcType),
100  minusOneIntegerAttribute(srcType, rewriter)));
101  }
102  return rewriter.create<LLVM::ConstantOp>(
103  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
104 }
105 
106 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
107 static Value createFPConstant(Location loc, Type srcType, Type dstType,
108  PatternRewriter &rewriter, double value) {
109  if (auto vecType = dyn_cast<VectorType>(srcType)) {
110  auto floatType = cast<FloatType>(vecType.getElementType());
111  return rewriter.create<LLVM::ConstantOp>(
112  loc, dstType,
113  SplatElementsAttr::get(vecType,
114  rewriter.getFloatAttr(floatType, value)));
115  }
116  auto floatType = cast<FloatType>(srcType);
117  return rewriter.create<LLVM::ConstantOp>(
118  loc, dstType, rewriter.getFloatAttr(floatType, value));
119 }
120 
121 /// Utility function for bitfield ops:
122 /// - `BitFieldInsert`
123 /// - `BitFieldSExtract`
124 /// - `BitFieldUExtract`
125 /// Truncates or extends the value. If the bitwidth of the value is the same as
126 /// `llvmType` bitwidth, the value remains unchanged.
128  Type llvmType,
129  PatternRewriter &rewriter) {
130  auto srcType = value.getType();
131  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
132  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
133  ? getLLVMTypeBitWidth(srcType)
134  : getBitWidth(srcType);
135 
136  if (valueBitWidth < targetBitWidth)
137  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
138  // If the bit widths of `Count` and `Offset` are greater than the bit width
139  // of the target type, they are truncated. Truncation is safe since `Count`
140  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
141  // both values can be expressed in 8 bits.
142  if (valueBitWidth > targetBitWidth)
143  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
144  return value;
145 }
146 
147 /// Broadcasts the value to vector with `numElements` number of elements.
148 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
149  const TypeConverter &typeConverter,
150  ConversionPatternRewriter &rewriter) {
151  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
152  auto llvmVectorType = typeConverter.convertType(vectorType);
153  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
154  Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
155  for (unsigned i = 0; i < numElements; ++i) {
156  auto index = rewriter.create<LLVM::ConstantOp>(
157  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
158  broadcasted = rewriter.create<LLVM::InsertElementOp>(
159  loc, llvmVectorType, broadcasted, toBroadcast, index);
160  }
161  return broadcasted;
162 }
163 
164 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
165 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
166  const TypeConverter &typeConverter,
167  ConversionPatternRewriter &rewriter) {
168  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
169  unsigned numElements = vectorType.getNumElements();
170  return broadcast(loc, value, numElements, typeConverter, rewriter);
171  }
172  return value;
173 }
174 
175 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
176 /// `BitFieldUExtract`.
177 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
178 /// a vector type, construct a vector that has:
179 /// - same number of elements as `Base`
180 /// - each element has the type that is the same as the type of `Offset` or
181 /// `Count`
182 /// - each element has the same value as `Offset` or `Count`
183 /// Then cast `Offset` and `Count` if their bit width is different
184 /// from `Base` bit width.
185 static Value processCountOrOffset(Location loc, Value value, Type srcType,
186  Type dstType, const TypeConverter &converter,
187  ConversionPatternRewriter &rewriter) {
188  Value broadcasted =
189  optionallyBroadcast(loc, value, srcType, converter, rewriter);
190  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
191 }
192 
193 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
194 /// offset to LLVM struct. Otherwise, the conversion is not supported.
196  const TypeConverter &converter) {
197  if (type != VulkanLayoutUtils::decorateType(type))
198  return nullptr;
199 
200  SmallVector<Type> elementsVector;
201  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
202  return nullptr;
203  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
204  /*isPacked=*/false);
205 }
206 
207 /// Converts SPIR-V struct with no offset to packed LLVM struct.
209  const TypeConverter &converter) {
210  SmallVector<Type> elementsVector;
211  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
212  return nullptr;
213  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
214  /*isPacked=*/true);
215 }
216 
217 /// Creates LLVM dialect constant with the given value.
219  unsigned value) {
220  return rewriter.create<LLVM::ConstantOp>(
221  loc, IntegerType::get(rewriter.getContext(), 32),
222  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
223 }
224 
225 /// Utility for `spirv.Load` and `spirv.Store` conversion.
226 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
227  ConversionPatternRewriter &rewriter,
228  const TypeConverter &typeConverter,
229  unsigned alignment, bool isVolatile,
230  bool isNonTemporal) {
231  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232  auto dstType = typeConverter.convertType(loadOp.getType());
233  if (!dstType)
234  return rewriter.notifyMatchFailure(op, "type conversion failed");
235  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237  isVolatile, isNonTemporal);
238  return success();
239  }
240  auto storeOp = cast<spirv::StoreOp>(op);
241  spirv::StoreOpAdaptor adaptor(operands);
242  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
243  adaptor.getPtr(), alignment,
244  isVolatile, isNonTemporal);
245  return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // Type conversion
250 //===----------------------------------------------------------------------===//
251 
252 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
253 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
254 /// when converting ops that manipulate array types.
255 static std::optional<Type> convertArrayType(spirv::ArrayType type,
256  TypeConverter &converter) {
257  unsigned stride = type.getArrayStride();
258  Type elementType = type.getElementType();
259  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
261  return std::nullopt;
262 
263  auto llvmElementType = converter.convertType(elementType);
264  unsigned numElements = type.getNumElements();
265  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
266 }
267 
268 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
269 /// modelled at the moment.
271  const TypeConverter &converter,
272  spirv::ClientAPI clientAPI) {
273  unsigned addressSpace =
274  storageClassToAddressSpace(clientAPI, type.getStorageClass());
275  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
276 }
277 
278 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
279 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
280 /// no modelling of array stride at the moment.
281 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
282  TypeConverter &converter) {
283  if (type.getArrayStride() != 0)
284  return std::nullopt;
285  auto elementType = converter.convertType(type.getElementType());
286  return LLVM::LLVMArrayType::get(elementType, 0);
287 }
288 
289 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
290 /// member decorations. Also, only natural offset is supported.
292  const TypeConverter &converter) {
294  type.getMemberDecorations(memberDecorations);
295  if (!memberDecorations.empty())
296  return nullptr;
297  if (type.hasOffset())
298  return convertStructTypeWithOffset(type, converter);
299  return convertStructTypePacked(type, converter);
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // Operation conversion
304 //===----------------------------------------------------------------------===//
305 
306 namespace {
307 
308 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
309 public:
311 
312  LogicalResult
313  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314  ConversionPatternRewriter &rewriter) const override {
315  auto dstType =
316  getTypeConverter()->convertType(op.getComponentPtr().getType());
317  if (!dstType)
318  return rewriter.notifyMatchFailure(op, "type conversion failed");
319  // To use GEP we need to add a first 0 index to go through the pointer.
320  auto indices = llvm::to_vector<4>(adaptor.getIndices());
321  Type indexType = op.getIndices().front().getType();
322  auto llvmIndexType = getTypeConverter()->convertType(indexType);
323  if (!llvmIndexType)
324  return rewriter.notifyMatchFailure(op, "type conversion failed");
325  Value zero = rewriter.create<LLVM::ConstantOp>(
326  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
327  indices.insert(indices.begin(), zero);
328 
329  auto elementType = getTypeConverter()->convertType(
330  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
331  if (!elementType)
332  return rewriter.notifyMatchFailure(op, "type conversion failed");
333  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
334  adaptor.getBasePtr(), indices);
335  return success();
336  }
337 };
338 
339 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
340 public:
342 
343  LogicalResult
344  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
345  ConversionPatternRewriter &rewriter) const override {
346  auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
347  if (!dstType)
348  return rewriter.notifyMatchFailure(op, "type conversion failed");
349  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
350  op.getVariable());
351  return success();
352  }
353 };
354 
355 class BitFieldInsertPattern
356  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
357 public:
359 
360  LogicalResult
361  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
362  ConversionPatternRewriter &rewriter) const override {
363  auto srcType = op.getType();
364  auto dstType = getTypeConverter()->convertType(srcType);
365  if (!dstType)
366  return rewriter.notifyMatchFailure(op, "type conversion failed");
367  Location loc = op.getLoc();
368 
369  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
370  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
371  *getTypeConverter(), rewriter);
372  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
373  *getTypeConverter(), rewriter);
374 
375  // Create a mask with bits set outside [Offset, Offset + Count - 1].
376  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
377  Value maskShiftedByCount =
378  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
379  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
380  maskShiftedByCount, minusOne);
381  Value maskShiftedByCountAndOffset =
382  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
383  Value mask = rewriter.create<LLVM::XOrOp>(
384  loc, dstType, maskShiftedByCountAndOffset, minusOne);
385 
386  // Extract unchanged bits from the `Base` that are outside of
387  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
388  Value baseAndMask =
389  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
390  Value insertShiftedByOffset =
391  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
392  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
393  insertShiftedByOffset);
394  return success();
395  }
396 };
397 
398 /// Converts SPIR-V ConstantOp with scalar or vector type.
399 class ConstantScalarAndVectorPattern
400  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
401 public:
403 
404  LogicalResult
405  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
406  ConversionPatternRewriter &rewriter) const override {
407  auto srcType = constOp.getType();
408  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
409  return failure();
410 
411  auto dstType = getTypeConverter()->convertType(srcType);
412  if (!dstType)
413  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
414 
415  // SPIR-V constant can be a signed/unsigned integer, which has to be
416  // casted to signless integer when converting to LLVM dialect. Removing the
417  // sign bit may have unexpected behaviour. However, it is better to handle
418  // it case-by-case, given that the purpose of the conversion is not to
419  // cover all possible corner cases.
420  if (isSignedIntegerOrVector(srcType) ||
421  isUnsignedIntegerOrVector(srcType)) {
422  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
423 
424  if (isa<VectorType>(srcType)) {
425  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
426  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
427  constOp, dstType,
428  dstElementsAttr.mapValues(
429  signlessType, [&](const APInt &value) { return value; }));
430  return success();
431  }
432  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
433  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
434  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
435  return success();
436  }
437  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
438  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
439  return success();
440  }
441 };
442 
443 class BitFieldSExtractPattern
444  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
445 public:
447 
448  LogicalResult
449  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
450  ConversionPatternRewriter &rewriter) const override {
451  auto srcType = op.getType();
452  auto dstType = getTypeConverter()->convertType(srcType);
453  if (!dstType)
454  return rewriter.notifyMatchFailure(op, "type conversion failed");
455  Location loc = op.getLoc();
456 
457  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
458  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
459  *getTypeConverter(), rewriter);
460  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
461  *getTypeConverter(), rewriter);
462 
463  // Create a constant that holds the size of the `Base`.
464  IntegerType integerType;
465  if (auto vecType = dyn_cast<VectorType>(srcType))
466  integerType = cast<IntegerType>(vecType.getElementType());
467  else
468  integerType = cast<IntegerType>(srcType);
469 
470  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
471  Value size =
472  isa<VectorType>(srcType)
473  ? rewriter.create<LLVM::ConstantOp>(
474  loc, dstType,
475  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
476  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
477 
478  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
479  // at Offset + Count - 1 is the most significant bit now.
480  Value countPlusOffset =
481  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
482  Value amountToShiftLeft =
483  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
484  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
485  loc, dstType, op.getBase(), amountToShiftLeft);
486 
487  // Shift the result right, filling the bits with the sign bit.
488  Value amountToShiftRight =
489  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
490  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
491  amountToShiftRight);
492  return success();
493  }
494 };
495 
496 class BitFieldUExtractPattern
497  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
498 public:
500 
501  LogicalResult
502  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
503  ConversionPatternRewriter &rewriter) const override {
504  auto srcType = op.getType();
505  auto dstType = getTypeConverter()->convertType(srcType);
506  if (!dstType)
507  return rewriter.notifyMatchFailure(op, "type conversion failed");
508  Location loc = op.getLoc();
509 
510  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
511  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
512  *getTypeConverter(), rewriter);
513  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
514  *getTypeConverter(), rewriter);
515 
516  // Create a mask with bits set at [0, Count - 1].
517  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
518  Value maskShiftedByCount =
519  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
520  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
521  minusOne);
522 
523  // Shift `Base` by `Offset` and apply the mask on it.
524  Value shiftedBase =
525  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
526  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
527  return success();
528  }
529 };
530 
531 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
532 public:
534 
535  LogicalResult
536  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
537  ConversionPatternRewriter &rewriter) const override {
538  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
539  branchOp.getTarget());
540  return success();
541  }
542 };
543 
544 class BranchConditionalConversionPattern
545  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
546 public:
547  using SPIRVToLLVMConversion<
548  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
549 
550  LogicalResult
551  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
552  ConversionPatternRewriter &rewriter) const override {
553  // If branch weights exist, map them to 32-bit integer vector.
554  DenseI32ArrayAttr branchWeights = nullptr;
555  if (auto weights = op.getBranchWeights()) {
556  SmallVector<int32_t> weightValues;
557  for (auto weight : weights->getAsRange<IntegerAttr>())
558  weightValues.push_back(weight.getInt());
559  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
560  }
561 
562  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
563  op, op.getCondition(), op.getTrueBlockArguments(),
564  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
565  op.getFalseBlock());
566  return success();
567  }
568 };
569 
570 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
571 /// type is an aggregate type (struct or array). Otherwise, converts to
572 /// `llvm.extractelement` that operates on vectors.
573 class CompositeExtractPattern
574  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
575 public:
577 
578  LogicalResult
579  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
580  ConversionPatternRewriter &rewriter) const override {
581  auto dstType = this->getTypeConverter()->convertType(op.getType());
582  if (!dstType)
583  return rewriter.notifyMatchFailure(op, "type conversion failed");
584 
585  Type containerType = op.getComposite().getType();
586  if (isa<VectorType>(containerType)) {
587  Location loc = op.getLoc();
588  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
589  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
590  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
591  op, dstType, adaptor.getComposite(), index);
592  return success();
593  }
594 
595  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
596  op, adaptor.getComposite(),
597  LLVM::convertArrayToIndices(op.getIndices()));
598  return success();
599  }
600 };
601 
602 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
603 /// type is an aggregate type (struct or array). Otherwise, converts to
604 /// `llvm.insertelement` that operates on vectors.
605 class CompositeInsertPattern
606  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
607 public:
609 
610  LogicalResult
611  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
612  ConversionPatternRewriter &rewriter) const override {
613  auto dstType = this->getTypeConverter()->convertType(op.getType());
614  if (!dstType)
615  return rewriter.notifyMatchFailure(op, "type conversion failed");
616 
617  Type containerType = op.getComposite().getType();
618  if (isa<VectorType>(containerType)) {
619  Location loc = op.getLoc();
620  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
621  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
622  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
623  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
624  return success();
625  }
626 
627  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
628  op, adaptor.getComposite(), adaptor.getObject(),
629  LLVM::convertArrayToIndices(op.getIndices()));
630  return success();
631  }
632 };
633 
634 /// Converts SPIR-V operations that have straightforward LLVM equivalent
635 /// into LLVM dialect operations.
636 template <typename SPIRVOp, typename LLVMOp>
637 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
638 public:
640 
641  LogicalResult
642  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
643  ConversionPatternRewriter &rewriter) const override {
644  auto dstType = this->getTypeConverter()->convertType(op.getType());
645  if (!dstType)
646  return rewriter.notifyMatchFailure(op, "type conversion failed");
647  rewriter.template replaceOpWithNewOp<LLVMOp>(
648  op, dstType, adaptor.getOperands(), op->getAttrs());
649  return success();
650  }
651 };
652 
653 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
654 /// execution mode information.
655 class ExecutionModePattern
656  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
657 public:
659 
660  LogicalResult
661  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
662  ConversionPatternRewriter &rewriter) const override {
663  // First, create the global struct's name that would be associated with
664  // this entry point's execution mode. We set it to be:
665  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
666  ModuleOp module = op->getParentOfType<ModuleOp>();
667  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
668  std::string moduleName;
669  if (module.getName().has_value())
670  moduleName = "_" + module.getName()->str();
671  else
672  moduleName = "";
673  std::string executionModeInfoName = llvm::formatv(
674  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
675  static_cast<uint32_t>(executionModeAttr.getValue()));
676 
677  MLIRContext *context = rewriter.getContext();
678  OpBuilder::InsertionGuard guard(rewriter);
679  rewriter.setInsertionPointToStart(module.getBody());
680 
681  // Create a struct type, corresponding to the C struct below.
682  // struct {
683  // int32_t executionMode;
684  // int32_t values[]; // optional values
685  // };
686  auto llvmI32Type = IntegerType::get(context, 32);
687  SmallVector<Type, 2> fields;
688  fields.push_back(llvmI32Type);
689  ArrayAttr values = op.getValues();
690  if (!values.empty()) {
691  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
692  fields.push_back(arrayType);
693  }
694  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
695 
696  // Create `llvm.mlir.global` with initializer region containing one block.
697  auto global = rewriter.create<LLVM::GlobalOp>(
698  UnknownLoc::get(context), structType, /*isConstant=*/true,
699  LLVM::Linkage::External, executionModeInfoName, Attribute(),
700  /*alignment=*/0);
701  Location loc = global.getLoc();
702  Region &region = global.getInitializerRegion();
703  Block *block = rewriter.createBlock(&region);
704 
705  // Initialize the struct and set the execution mode value.
706  rewriter.setInsertionPointToStart(block);
707  Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType);
708  Value executionMode = rewriter.create<LLVM::ConstantOp>(
709  loc, llvmI32Type,
710  rewriter.getI32IntegerAttr(
711  static_cast<uint32_t>(executionModeAttr.getValue())));
712  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
713  executionMode, 0);
714 
715  // Insert extra operands if they exist into execution mode info struct.
716  for (unsigned i = 0, e = values.size(); i < e; ++i) {
717  auto attr = values.getValue()[i];
718  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
719  structValue = rewriter.create<LLVM::InsertValueOp>(
720  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
721  }
722  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
723  rewriter.eraseOp(op);
724  return success();
725  }
726 };
727 
728 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
729 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
730 /// value. This difference is handled by `spirv.mlir.addressof` and
731 /// `llvm.mlir.addressof`ops that both return a pointer.
732 class GlobalVariablePattern
733  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
734 public:
735  template <typename... Args>
736  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
737  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
738  std::forward<Args>(args)...),
739  clientAPI(clientAPI) {}
740 
741  LogicalResult
742  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
743  ConversionPatternRewriter &rewriter) const override {
744  // Currently, there is no support of initialization with a constant value in
745  // SPIR-V dialect. Specialization constants are not considered as well.
746  if (op.getInitializer())
747  return failure();
748 
749  auto srcType = cast<spirv::PointerType>(op.getType());
750  auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
751  if (!dstType)
752  return rewriter.notifyMatchFailure(op, "type conversion failed");
753 
754  // Limit conversion to the current invocation only or `StorageBuffer`
755  // required by SPIR-V runner.
756  // This is okay because multiple invocations are not supported yet.
757  auto storageClass = srcType.getStorageClass();
758  switch (storageClass) {
759  case spirv::StorageClass::Input:
760  case spirv::StorageClass::Private:
761  case spirv::StorageClass::Output:
762  case spirv::StorageClass::StorageBuffer:
763  case spirv::StorageClass::UniformConstant:
764  break;
765  default:
766  return failure();
767  }
768 
769  // LLVM dialect spec: "If the global value is a constant, storing into it is
770  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
771  // storage class that is read-only.
772  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
773  (storageClass == spirv::StorageClass::UniformConstant);
774  // SPIR-V spec: "By default, functions and global variables are private to a
775  // module and cannot be accessed by other modules. However, a module may be
776  // written to export or import functions and global (module scope)
777  // variables.". Therefore, map 'Private' storage class to private linkage,
778  // 'Input' and 'Output' to external linkage.
779  auto linkage = storageClass == spirv::StorageClass::Private
780  ? LLVM::Linkage::Private
781  : LLVM::Linkage::External;
782  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
783  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
784  /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
785 
786  // Attach location attribute if applicable
787  if (op.getLocationAttr())
788  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
789 
790  return success();
791  }
792 
793 private:
794  spirv::ClientAPI clientAPI;
795 };
796 
797 /// Converts SPIR-V cast ops that do not have straightforward LLVM
798 /// equivalent in LLVM dialect.
799 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
800 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
801 public:
803 
804  LogicalResult
805  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
806  ConversionPatternRewriter &rewriter) const override {
807 
808  Type fromType = op.getOperand().getType();
809  Type toType = op.getType();
810 
811  auto dstType = this->getTypeConverter()->convertType(toType);
812  if (!dstType)
813  return rewriter.notifyMatchFailure(op, "type conversion failed");
814 
815  if (getBitWidth(fromType) < getBitWidth(toType)) {
816  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
817  adaptor.getOperands());
818  return success();
819  }
820  if (getBitWidth(fromType) > getBitWidth(toType)) {
821  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
822  adaptor.getOperands());
823  return success();
824  }
825  return failure();
826  }
827 };
828 
829 class FunctionCallPattern
830  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
831 public:
833 
834  LogicalResult
835  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
836  ConversionPatternRewriter &rewriter) const override {
837  if (callOp.getNumResults() == 0) {
838  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
839  callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
840  newOp.getProperties().operandSegmentSizes = {
841  static_cast<int32_t>(adaptor.getOperands().size()), 0};
842  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
843  return success();
844  }
845 
846  // Function returns a single result.
847  auto dstType = getTypeConverter()->convertType(callOp.getType(0));
848  if (!dstType)
849  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
850  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
851  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
852  newOp.getProperties().operandSegmentSizes = {
853  static_cast<int32_t>(adaptor.getOperands().size()), 0};
854  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
855  return success();
856  }
857 };
858 
859 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
860 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
861 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
862 public:
864 
865  LogicalResult
866  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
867  ConversionPatternRewriter &rewriter) const override {
868 
869  auto dstType = this->getTypeConverter()->convertType(op.getType());
870  if (!dstType)
871  return rewriter.notifyMatchFailure(op, "type conversion failed");
872 
873  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
874  op, dstType, predicate, op.getOperand1(), op.getOperand2());
875  return success();
876  }
877 };
878 
879 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
880 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
881 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
882 public:
884 
885  LogicalResult
886  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
887  ConversionPatternRewriter &rewriter) const override {
888 
889  auto dstType = this->getTypeConverter()->convertType(op.getType());
890  if (!dstType)
891  return rewriter.notifyMatchFailure(op, "type conversion failed");
892 
893  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
894  op, dstType, predicate, op.getOperand1(), op.getOperand2());
895  return success();
896  }
897 };
898 
899 class InverseSqrtPattern
900  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
901 public:
903 
904  LogicalResult
905  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
906  ConversionPatternRewriter &rewriter) const override {
907  auto srcType = op.getType();
908  auto dstType = getTypeConverter()->convertType(srcType);
909  if (!dstType)
910  return rewriter.notifyMatchFailure(op, "type conversion failed");
911 
912  Location loc = op.getLoc();
913  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
914  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
915  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
916  return success();
917  }
918 };
919 
920 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
921 template <typename SPIRVOp>
922 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
923 public:
925 
926  LogicalResult
927  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
928  ConversionPatternRewriter &rewriter) const override {
929  if (!op.getMemoryAccess()) {
930  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
931  *this->getTypeConverter(), /*alignment=*/0,
932  /*isVolatile=*/false,
933  /*isNonTemporal=*/false);
934  }
935  auto memoryAccess = *op.getMemoryAccess();
936  switch (memoryAccess) {
937  case spirv::MemoryAccess::Aligned:
939  case spirv::MemoryAccess::Nontemporal:
940  case spirv::MemoryAccess::Volatile: {
941  unsigned alignment =
942  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
943  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
944  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
945  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
946  *this->getTypeConverter(), alignment,
947  isVolatile, isNonTemporal);
948  }
949  default:
950  // There is no support of other memory access attributes.
951  return failure();
952  }
953  }
954 };
955 
956 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
957 template <typename SPIRVOp>
958 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
959 public:
961 
962  LogicalResult
963  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
964  ConversionPatternRewriter &rewriter) const override {
965  auto srcType = notOp.getType();
966  auto dstType = this->getTypeConverter()->convertType(srcType);
967  if (!dstType)
968  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
969 
970  Location loc = notOp.getLoc();
971  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
972  auto mask =
973  isa<VectorType>(srcType)
974  ? rewriter.create<LLVM::ConstantOp>(
975  loc, dstType,
976  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
977  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
978  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
979  notOp.getOperand(), mask);
980  return success();
981  }
982 };
983 
984 /// A template pattern that erases the given `SPIRVOp`.
985 template <typename SPIRVOp>
986 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
987 public:
989 
990  LogicalResult
991  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
992  ConversionPatternRewriter &rewriter) const override {
993  rewriter.eraseOp(op);
994  return success();
995  }
996 };
997 
998 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
999 public:
1001 
1002  LogicalResult
1003  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1004  ConversionPatternRewriter &rewriter) const override {
1005  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1006  ArrayRef<Value>());
1007  return success();
1008  }
1009 };
1010 
1011 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1012 public:
1014 
1015  LogicalResult
1016  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1017  ConversionPatternRewriter &rewriter) const override {
1018  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1019  adaptor.getOperands());
1020  return success();
1021  }
1022 };
1023 
1024 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1025  StringRef name,
1026  ArrayRef<Type> paramTypes,
1027  Type resultType,
1028  bool convergent = true) {
1029  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1030  SymbolTable::lookupSymbolIn(symbolTable, name));
1031  if (func)
1032  return func;
1033 
1034  OpBuilder b(symbolTable->getRegion(0));
1035  func = b.create<LLVM::LLVMFuncOp>(
1036  symbolTable->getLoc(), name,
1037  LLVM::LLVMFunctionType::get(resultType, paramTypes));
1038  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1039  func.setConvergent(convergent);
1040  func.setNoUnwind(true);
1041  func.setWillReturn(true);
1042  return func;
1043 }
1044 
1045 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1046  LLVM::LLVMFuncOp func,
1047  ValueRange args) {
1048  auto call = builder.create<LLVM::CallOp>(loc, func, args);
1049  call.setCConv(func.getCConv());
1050  call.setConvergentAttr(func.getConvergentAttr());
1051  call.setNoUnwindAttr(func.getNoUnwindAttr());
1052  call.setWillReturnAttr(func.getWillReturnAttr());
1053  return call;
1054 }
1055 
1056 template <typename BarrierOpTy>
1057 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1058 public:
1059  using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1060 
1062 
1063  static constexpr StringRef getFuncName();
1064 
1065  LogicalResult
1066  matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1067  ConversionPatternRewriter &rewriter) const override {
1068  constexpr StringRef funcName = getFuncName();
1069  Operation *symbolTable =
1070  controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1071 
1072  Type i32 = rewriter.getI32Type();
1073 
1074  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1075  LLVM::LLVMFuncOp func =
1076  lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1077 
1078  Location loc = controlBarrierOp->getLoc();
1079  Value execution = rewriter.create<LLVM::ConstantOp>(
1080  loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1081  Value memory = rewriter.create<LLVM::ConstantOp>(
1082  loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1083  Value semantics = rewriter.create<LLVM::ConstantOp>(
1084  loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1085 
1086  auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1087  {execution, memory, semantics});
1088 
1089  rewriter.replaceOp(controlBarrierOp, call);
1090  return success();
1091  }
1092 };
1093 
1094 namespace {
1095 
1096 StringRef getTypeMangling(Type type, bool isSigned) {
1098  .Case<Float16Type>([](auto) { return "Dh"; })
1099  .Case<Float32Type>([](auto) { return "f"; })
1100  .Case<Float64Type>([](auto) { return "d"; })
1101  .Case<IntegerType>([isSigned](IntegerType intTy) {
1102  switch (intTy.getWidth()) {
1103  case 1:
1104  return "b";
1105  case 8:
1106  return (isSigned) ? "a" : "c";
1107  case 16:
1108  return (isSigned) ? "s" : "t";
1109  case 32:
1110  return (isSigned) ? "i" : "j";
1111  case 64:
1112  return (isSigned) ? "l" : "m";
1113  default:
1114  llvm_unreachable("Unsupported integer width");
1115  }
1116  })
1117  .Default([](auto) {
1118  llvm_unreachable("No mangling defined");
1119  return "";
1120  });
1121 }
1122 
1123 template <typename ReduceOp>
1124 constexpr StringLiteral getGroupFuncName();
1125 
1126 template <>
1127 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1128  return "_Z17__spirv_GroupIAddii";
1129 }
1130 template <>
1131 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1132  return "_Z17__spirv_GroupFAddii";
1133 }
1134 template <>
1135 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1136  return "_Z17__spirv_GroupSMinii";
1137 }
1138 template <>
1139 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1140  return "_Z17__spirv_GroupUMinii";
1141 }
1142 template <>
1143 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1144  return "_Z17__spirv_GroupFMinii";
1145 }
1146 template <>
1147 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1148  return "_Z17__spirv_GroupSMaxii";
1149 }
1150 template <>
1151 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1152  return "_Z17__spirv_GroupUMaxii";
1153 }
1154 template <>
1155 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1156  return "_Z17__spirv_GroupFMaxii";
1157 }
1158 template <>
1159 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1160  return "_Z27__spirv_GroupNonUniformIAddii";
1161 }
1162 template <>
1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1164  return "_Z27__spirv_GroupNonUniformFAddii";
1165 }
1166 template <>
1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1168  return "_Z27__spirv_GroupNonUniformIMulii";
1169 }
1170 template <>
1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1172  return "_Z27__spirv_GroupNonUniformFMulii";
1173 }
1174 template <>
1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1176  return "_Z27__spirv_GroupNonUniformSMinii";
1177 }
1178 template <>
1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1180  return "_Z27__spirv_GroupNonUniformUMinii";
1181 }
1182 template <>
1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1184  return "_Z27__spirv_GroupNonUniformFMinii";
1185 }
1186 template <>
1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1188  return "_Z27__spirv_GroupNonUniformSMaxii";
1189 }
1190 template <>
1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1192  return "_Z27__spirv_GroupNonUniformUMaxii";
1193 }
1194 template <>
1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1196  return "_Z27__spirv_GroupNonUniformFMaxii";
1197 }
1198 template <>
1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1200  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1201 }
1202 template <>
1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1204  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1205 }
1206 template <>
1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1208  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1209 }
1210 template <>
1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1212  return "_Z33__spirv_GroupNonUniformLogicalAndii";
1213 }
1214 template <>
1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1216  return "_Z32__spirv_GroupNonUniformLogicalOrii";
1217 }
1218 template <>
1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1220  return "_Z33__spirv_GroupNonUniformLogicalXorii";
1221 }
1222 } // namespace
1223 
1224 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1225 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1226 public:
1228 
1229  LogicalResult
1230  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1231  ConversionPatternRewriter &rewriter) const override {
1232 
1233  Type retTy = op.getResult().getType();
1234  if (!retTy.isIntOrFloat()) {
1235  return failure();
1236  }
1237  SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1238  funcName += getTypeMangling(retTy, false);
1239 
1240  Type i32Ty = rewriter.getI32Type();
1241  SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1242  if constexpr (NonUniform) {
1243  if (adaptor.getClusterSize()) {
1244  funcName += "j";
1245  paramTypes.push_back(i32Ty);
1246  }
1247  }
1248 
1249  Operation *symbolTable =
1250  op->template getParentWithTrait<OpTrait::SymbolTable>();
1251 
1252  LLVM::LLVMFuncOp func =
1253  lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1254 
1255  Location loc = op.getLoc();
1256  Value scope = rewriter.create<LLVM::ConstantOp>(
1257  loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1258  Value groupOp = rewriter.create<LLVM::ConstantOp>(
1259  loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1260  SmallVector<Value> operands{scope, groupOp};
1261  operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1262 
1263  auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1264  rewriter.replaceOp(op, call);
1265  return success();
1266  }
1267 };
1268 
1269 template <>
1270 constexpr StringRef
1271 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1272  return "_Z22__spirv_ControlBarrieriii";
1273 }
1274 
1275 template <>
1276 constexpr StringRef
1277 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1278  return "_Z33__spirv_ControlBarrierArriveINTELiii";
1279 }
1280 
1281 template <>
1282 constexpr StringRef
1283 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1284  return "_Z31__spirv_ControlBarrierWaitINTELiii";
1285 }
1286 
1287 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1288 /// should be reachable for conversion to succeed. The structure of the loop in
1289 /// LLVM dialect will be the following:
1290 ///
1291 /// +------------------------------------+
1292 /// | <code before spirv.mlir.loop> |
1293 /// | llvm.br ^header |
1294 /// +------------------------------------+
1295 /// |
1296 /// +----------------+ |
1297 /// | | |
1298 /// | V V
1299 /// | +------------------------------------+
1300 /// | | ^header: |
1301 /// | | <header code> |
1302 /// | | llvm.cond_br %cond, ^body, ^exit |
1303 /// | +------------------------------------+
1304 /// | |
1305 /// | |----------------------+
1306 /// | | |
1307 /// | V |
1308 /// | +------------------------------------+ |
1309 /// | | ^body: | |
1310 /// | | <body code> | |
1311 /// | | llvm.br ^continue | |
1312 /// | +------------------------------------+ |
1313 /// | | |
1314 /// | V |
1315 /// | +------------------------------------+ |
1316 /// | | ^continue: | |
1317 /// | | <continue code> | |
1318 /// | | llvm.br ^header | |
1319 /// | +------------------------------------+ |
1320 /// | | |
1321 /// +---------------+ +----------------------+
1322 /// |
1323 /// V
1324 /// +------------------------------------+
1325 /// | ^exit: |
1326 /// | llvm.br ^remaining |
1327 /// +------------------------------------+
1328 /// |
1329 /// V
1330 /// +------------------------------------+
1331 /// | ^remaining: |
1332 /// | <code after spirv.mlir.loop> |
1333 /// +------------------------------------+
1334 ///
1335 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1336 public:
1338 
1339  LogicalResult
1340  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1341  ConversionPatternRewriter &rewriter) const override {
1342  // There is no support of loop control at the moment.
1343  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1344  return failure();
1345 
1346  // `spirv.mlir.loop` with empty region is redundant and should be erased.
1347  if (loopOp.getBody().empty()) {
1348  rewriter.eraseOp(loopOp);
1349  return success();
1350  }
1351 
1352  Location loc = loopOp.getLoc();
1353 
1354  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1355  // be used in `endBlock`.
1356  Block *currentBlock = rewriter.getBlock();
1357  auto position = Block::iterator(loopOp);
1358  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1359 
1360  // Remove entry block and create a branch in the current block going to the
1361  // header block.
1362  Block *entryBlock = loopOp.getEntryBlock();
1363  assert(entryBlock->getOperations().size() == 1);
1364  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1365  if (!brOp)
1366  return failure();
1367  Block *headerBlock = loopOp.getHeaderBlock();
1368  rewriter.setInsertionPointToEnd(currentBlock);
1369  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1370  rewriter.eraseBlock(entryBlock);
1371 
1372  // Branch from merge block to end block.
1373  Block *mergeBlock = loopOp.getMergeBlock();
1374  Operation *terminator = mergeBlock->getTerminator();
1375  ValueRange terminatorOperands = terminator->getOperands();
1376  rewriter.setInsertionPointToEnd(mergeBlock);
1377  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1378 
1379  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1380  rewriter.replaceOp(loopOp, endBlock->getArguments());
1381  return success();
1382  }
1383 };
1384 
1385 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1386 /// block. All blocks within selection should be reachable for conversion to
1387 /// succeed.
1388 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1389 public:
1391 
1392  LogicalResult
1393  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1394  ConversionPatternRewriter &rewriter) const override {
1395  // There is no support for `Flatten` or `DontFlatten` selection control at
1396  // the moment. This are just compiler hints and can be performed during the
1397  // optimization passes.
1398  if (op.getSelectionControl() != spirv::SelectionControl::None)
1399  return failure();
1400 
1401  // `spirv.mlir.selection` should have at least two blocks: one selection
1402  // header block and one merge block. If no blocks are present, or control
1403  // flow branches straight to merge block (two blocks are present), the op is
1404  // redundant and it is erased.
1405  if (op.getBody().getBlocks().size() <= 2) {
1406  rewriter.eraseOp(op);
1407  return success();
1408  }
1409 
1410  Location loc = op.getLoc();
1411 
1412  // Split the current block after `spirv.mlir.selection`. The remaining ops
1413  // will be used in `continueBlock`.
1414  auto *currentBlock = rewriter.getInsertionBlock();
1415  rewriter.setInsertionPointAfter(op);
1416  auto position = rewriter.getInsertionPoint();
1417  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1418 
1419  // Extract conditional branch information from the header block. By SPIR-V
1420  // dialect spec, it should contain `spirv.BranchConditional` or
1421  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1422  // moment in the SPIR-V dialect. Remove this block when finished.
1423  auto *headerBlock = op.getHeaderBlock();
1424  assert(headerBlock->getOperations().size() == 1);
1425  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1426  headerBlock->getOperations().front());
1427  if (!condBrOp)
1428  return failure();
1429  rewriter.eraseBlock(headerBlock);
1430 
1431  // Branch from merge block to continue block.
1432  auto *mergeBlock = op.getMergeBlock();
1433  Operation *terminator = mergeBlock->getTerminator();
1434  ValueRange terminatorOperands = terminator->getOperands();
1435  rewriter.setInsertionPointToEnd(mergeBlock);
1436  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1437 
1438  // Link current block to `true` and `false` blocks within the selection.
1439  Block *trueBlock = condBrOp.getTrueBlock();
1440  Block *falseBlock = condBrOp.getFalseBlock();
1441  rewriter.setInsertionPointToEnd(currentBlock);
1442  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1443  condBrOp.getTrueTargetOperands(),
1444  falseBlock,
1445  condBrOp.getFalseTargetOperands());
1446 
1447  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1448  rewriter.replaceOp(op, continueBlock->getArguments());
1449  return success();
1450  }
1451 };
1452 
1453 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1454 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1455 /// `Shift` is zero or sign extended to match this specification. Cases when
1456 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1457 template <typename SPIRVOp, typename LLVMOp>
1458 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1459 public:
1461 
1462  LogicalResult
1463  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1464  ConversionPatternRewriter &rewriter) const override {
1465 
1466  auto dstType = this->getTypeConverter()->convertType(op.getType());
1467  if (!dstType)
1468  return rewriter.notifyMatchFailure(op, "type conversion failed");
1469 
1470  Type op1Type = op.getOperand1().getType();
1471  Type op2Type = op.getOperand2().getType();
1472 
1473  if (op1Type == op2Type) {
1474  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1475  adaptor.getOperands());
1476  return success();
1477  }
1478 
1479  std::optional<uint64_t> dstTypeWidth =
1481  std::optional<uint64_t> op2TypeWidth =
1483 
1484  if (!dstTypeWidth || !op2TypeWidth)
1485  return failure();
1486 
1487  Location loc = op.getLoc();
1488  Value extended;
1489  if (op2TypeWidth < dstTypeWidth) {
1490  if (isUnsignedIntegerOrVector(op2Type)) {
1491  extended = rewriter.template create<LLVM::ZExtOp>(
1492  loc, dstType, adaptor.getOperand2());
1493  } else {
1494  extended = rewriter.template create<LLVM::SExtOp>(
1495  loc, dstType, adaptor.getOperand2());
1496  }
1497  } else if (op2TypeWidth == dstTypeWidth) {
1498  extended = adaptor.getOperand2();
1499  } else {
1500  return failure();
1501  }
1502 
1503  Value result = rewriter.template create<LLVMOp>(
1504  loc, dstType, adaptor.getOperand1(), extended);
1505  rewriter.replaceOp(op, result);
1506  return success();
1507  }
1508 };
1509 
1510 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1511 public:
1513 
1514  LogicalResult
1515  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1516  ConversionPatternRewriter &rewriter) const override {
1517  auto dstType = getTypeConverter()->convertType(tanOp.getType());
1518  if (!dstType)
1519  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1520 
1521  Location loc = tanOp.getLoc();
1522  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1523  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1524  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1525  return success();
1526  }
1527 };
1528 
1529 /// Convert `spirv.Tanh` to
1530 ///
1531 /// exp(2x) - 1
1532 /// -----------
1533 /// exp(2x) + 1
1534 ///
1535 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1536 public:
1538 
1539  LogicalResult
1540  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1541  ConversionPatternRewriter &rewriter) const override {
1542  auto srcType = tanhOp.getType();
1543  auto dstType = getTypeConverter()->convertType(srcType);
1544  if (!dstType)
1545  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1546 
1547  Location loc = tanhOp.getLoc();
1548  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1549  Value multiplied =
1550  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1551  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1552  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1553  Value numerator =
1554  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1555  Value denominator =
1556  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1557  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1558  denominator);
1559  return success();
1560  }
1561 };
1562 
1563 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1564 public:
1566 
1567  LogicalResult
1568  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1569  ConversionPatternRewriter &rewriter) const override {
1570  auto srcType = varOp.getType();
1571  // Initialization is supported for scalars and vectors only.
1572  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1573  auto init = varOp.getInitializer();
1574  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1575  return failure();
1576 
1577  auto dstType = getTypeConverter()->convertType(srcType);
1578  if (!dstType)
1579  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1580 
1581  Location loc = varOp.getLoc();
1582  Value size = createI32ConstantOf(loc, rewriter, 1);
1583  if (!init) {
1584  auto elementType = getTypeConverter()->convertType(pointerTo);
1585  if (!elementType)
1586  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1587  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1588  size);
1589  return success();
1590  }
1591  auto elementType = getTypeConverter()->convertType(pointerTo);
1592  if (!elementType)
1593  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1594  Value allocated =
1595  rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1596  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1597  rewriter.replaceOp(varOp, allocated);
1598  return success();
1599  }
1600 };
1601 
1602 //===----------------------------------------------------------------------===//
1603 // BitcastOp conversion
1604 //===----------------------------------------------------------------------===//
1605 
1606 class BitcastConversionPattern
1607  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1608 public:
1610 
1611  LogicalResult
1612  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1613  ConversionPatternRewriter &rewriter) const override {
1614  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1615  if (!dstType)
1616  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1617 
1618  // LLVM's opaque pointers do not require bitcasts.
1619  if (isa<LLVM::LLVMPointerType>(dstType)) {
1620  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1621  return success();
1622  }
1623 
1624  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1625  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1626  return success();
1627  }
1628 };
1629 
1630 //===----------------------------------------------------------------------===//
1631 // FuncOp conversion
1632 //===----------------------------------------------------------------------===//
1633 
1634 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1635 public:
1637 
1638  LogicalResult
1639  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1640  ConversionPatternRewriter &rewriter) const override {
1641 
1642  // Convert function signature. At the moment LLVMType converter is enough
1643  // for currently supported types.
1644  auto funcType = funcOp.getFunctionType();
1645  TypeConverter::SignatureConversion signatureConverter(
1646  funcType.getNumInputs());
1647  auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1648  ->convertFunctionSignature(
1649  funcType, /*isVariadic=*/false,
1650  /*useBarePtrCallConv=*/false, signatureConverter);
1651  if (!llvmType)
1652  return failure();
1653 
1654  // Create a new `LLVMFuncOp`
1655  Location loc = funcOp.getLoc();
1656  StringRef name = funcOp.getName();
1657  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1658 
1659  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1660  MLIRContext *context = funcOp.getContext();
1661  switch (funcOp.getFunctionControl()) {
1662  case spirv::FunctionControl::Inline:
1663  newFuncOp.setAlwaysInline(true);
1664  break;
1665  case spirv::FunctionControl::DontInline:
1666  newFuncOp.setNoInline(true);
1667  break;
1668 
1669 #define DISPATCH(functionControl, llvmAttr) \
1670  case functionControl: \
1671  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1672  break;
1673 
1674  DISPATCH(spirv::FunctionControl::Pure,
1675  StringAttr::get(context, "readonly"));
1676  DISPATCH(spirv::FunctionControl::Const,
1677  StringAttr::get(context, "readnone"));
1678 
1679 #undef DISPATCH
1680 
1681  // Default: if `spirv::FunctionControl::None`, then no attributes are
1682  // needed.
1683  default:
1684  break;
1685  }
1686 
1687  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1688  newFuncOp.end());
1689  if (failed(rewriter.convertRegionTypes(
1690  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1691  return failure();
1692  }
1693  rewriter.eraseOp(funcOp);
1694  return success();
1695  }
1696 };
1697 
1698 //===----------------------------------------------------------------------===//
1699 // ModuleOp conversion
1700 //===----------------------------------------------------------------------===//
1701 
1702 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1703 public:
1705 
1706  LogicalResult
1707  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1708  ConversionPatternRewriter &rewriter) const override {
1709 
1710  auto newModuleOp =
1711  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1712  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1713 
1714  // Remove the terminator block that was automatically added by builder
1715  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1716  rewriter.eraseOp(spvModuleOp);
1717  return success();
1718  }
1719 };
1720 
1721 //===----------------------------------------------------------------------===//
1722 // VectorShuffleOp conversion
1723 //===----------------------------------------------------------------------===//
1724 
1725 class VectorShufflePattern
1726  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1727 public:
1729  LogicalResult
1730  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1731  ConversionPatternRewriter &rewriter) const override {
1732  Location loc = op.getLoc();
1733  auto components = adaptor.getComponents();
1734  auto vector1 = adaptor.getVector1();
1735  auto vector2 = adaptor.getVector2();
1736  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1737  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1738  if (vector1Size == vector2Size) {
1739  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1740  op, vector1, vector2,
1741  LLVM::convertArrayToIndices<int32_t>(components));
1742  return success();
1743  }
1744 
1745  auto dstType = getTypeConverter()->convertType(op.getType());
1746  if (!dstType)
1747  return rewriter.notifyMatchFailure(op, "type conversion failed");
1748  auto scalarType = cast<VectorType>(dstType).getElementType();
1749  auto componentsArray = components.getValue();
1750  auto *context = rewriter.getContext();
1751  auto llvmI32Type = IntegerType::get(context, 32);
1752  Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType);
1753  for (unsigned i = 0; i < componentsArray.size(); i++) {
1754  if (!isa<IntegerAttr>(componentsArray[i]))
1755  return op.emitError("unable to support non-constant component");
1756 
1757  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1758  if (indexVal == -1)
1759  continue;
1760 
1761  int offsetVal = 0;
1762  Value baseVector = vector1;
1763  if (indexVal >= vector1Size) {
1764  offsetVal = vector1Size;
1765  baseVector = vector2;
1766  }
1767 
1768  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1769  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1770  Value index = rewriter.create<LLVM::ConstantOp>(
1771  loc, llvmI32Type,
1772  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1773 
1774  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1775  loc, scalarType, baseVector, index);
1776  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1777  extractOp, dstIndex);
1778  }
1779  rewriter.replaceOp(op, targetOp);
1780  return success();
1781  }
1782 };
1783 } // namespace
1784 
1785 //===----------------------------------------------------------------------===//
1786 // Pattern population
1787 //===----------------------------------------------------------------------===//
1788 
1790  spirv::ClientAPI clientAPI) {
1791  typeConverter.addConversion([&](spirv::ArrayType type) {
1792  return convertArrayType(type, typeConverter);
1793  });
1794  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1795  return convertPointerType(type, typeConverter, clientAPI);
1796  });
1797  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1798  return convertRuntimeArrayType(type, typeConverter);
1799  });
1800  typeConverter.addConversion([&](spirv::StructType type) {
1801  return convertStructType(type, typeConverter);
1802  });
1803 }
1804 
1806  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1807  spirv::ClientAPI clientAPI) {
1808  patterns.add<
1809  // Arithmetic ops
1810  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1811  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1812  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1813  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1814  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1815  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1816  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1817  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1818  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1819  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1820  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1821  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1822  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1823 
1824  // Bitwise ops
1825  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1826  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1827  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1828  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1829  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1830  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1831  NotPattern<spirv::NotOp>,
1832 
1833  // Cast ops
1834  BitcastConversionPattern,
1835  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1836  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1837  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1838  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1839  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1840  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1841  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1842 
1843  // Comparison ops
1844  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1845  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1846  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1847  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1848  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1849  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1850  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1851  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1852  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1853  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1854  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1855  LLVM::FCmpPredicate::uge>,
1856  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1857  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1858  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1859  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1860  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1861  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1862  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1863  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1864  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1865  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1866  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1867 
1868  // Constant op
1869  ConstantScalarAndVectorPattern,
1870 
1871  // Control Flow ops
1872  BranchConversionPattern, BranchConditionalConversionPattern,
1873  FunctionCallPattern, LoopPattern, SelectionPattern,
1874  ErasePattern<spirv::MergeOp>,
1875 
1876  // Entry points and execution mode are handled separately.
1877  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1878 
1879  // GLSL extended instruction set ops
1880  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1881  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1882  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1883  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1884  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1885  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1886  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1887  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1888  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1889  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1890  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1891  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1892  InverseSqrtPattern, TanPattern, TanhPattern,
1893 
1894  // Logical ops
1895  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1896  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1897  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1898  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1899  NotPattern<spirv::LogicalNotOp>,
1900 
1901  // Memory ops
1902  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1903  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1904 
1905  // Miscellaneous ops
1906  CompositeExtractPattern, CompositeInsertPattern,
1907  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1908  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1909  VectorShufflePattern,
1910 
1911  // Shift ops
1912  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1913  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1914  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1915 
1916  // Return ops
1917  ReturnPattern, ReturnValuePattern,
1918 
1919  // Barrier ops
1920  ControlBarrierPattern<spirv::ControlBarrierOp>,
1921  ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1922  ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1923 
1924  // Group reduction operations
1925  GroupReducePattern<spirv::GroupIAddOp>,
1926  GroupReducePattern<spirv::GroupFAddOp>,
1927  GroupReducePattern<spirv::GroupFMinOp>,
1928  GroupReducePattern<spirv::GroupUMinOp>,
1929  GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1930  GroupReducePattern<spirv::GroupFMaxOp>,
1931  GroupReducePattern<spirv::GroupUMaxOp>,
1932  GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1933  GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1934  /*NonUniform=*/true>,
1935  GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1936  /*NonUniform=*/true>,
1937  GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1938  /*NonUniform=*/true>,
1939  GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1940  /*NonUniform=*/true>,
1941  GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1942  /*NonUniform=*/true>,
1943  GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1944  /*NonUniform=*/true>,
1945  GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1946  /*NonUniform=*/true>,
1947  GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1948  /*NonUniform=*/true>,
1949  GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1950  /*NonUniform=*/true>,
1951  GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1952  /*NonUniform=*/true>,
1953  GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1954  /*NonUniform=*/true>,
1955  GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1956  /*NonUniform=*/true>,
1957  GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1958  /*NonUniform=*/true>,
1959  GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1960  /*NonUniform=*/true>,
1961  GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1962  /*NonUniform=*/true>,
1963  GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1964  /*NonUniform=*/true>>(patterns.getContext(),
1965  typeConverter);
1966 
1967  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1968  typeConverter);
1969 }
1970 
1972  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1973  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1974 }
1975 
1977  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1978  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1979 }
1980 
1981 //===----------------------------------------------------------------------===//
1982 // Pre-conversion hooks
1983 //===----------------------------------------------------------------------===//
1984 
1985 /// Hook for descriptor set and binding number encoding.
1986 static constexpr StringRef kBinding = "binding";
1987 static constexpr StringRef kDescriptorSet = "descriptor_set";
1988 void mlir::encodeBindAttribute(ModuleOp module) {
1989  auto spvModules = module.getOps<spirv::ModuleOp>();
1990  for (auto spvModule : spvModules) {
1991  spvModule.walk([&](spirv::GlobalVariableOp op) {
1992  IntegerAttr descriptorSet =
1993  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1994  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1995  // For every global variable in the module, get the ones with descriptor
1996  // set and binding numbers.
1997  if (descriptorSet && binding) {
1998  // Encode these numbers into the variable's symbolic name. If the
1999  // SPIR-V module has a name, add it at the beginning.
2000  auto moduleAndName =
2001  spvModule.getName().has_value()
2002  ? spvModule.getName()->str() + "_" + op.getSymName().str()
2003  : op.getSymName().str();
2004  std::string name =
2005  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2006  std::to_string(descriptorSet.getInt()),
2007  std::to_string(binding.getInt()));
2008  auto nameAttr = StringAttr::get(op->getContext(), name);
2009 
2010  // Replace all symbol uses and set the new symbol name. Finally, remove
2011  // descriptor set and binding attributes.
2012  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2013  op.emitError("unable to replace all symbol uses for ") << name;
2014  SymbolTable::setSymbolName(op, nameAttr);
2015  op->removeAttr(kDescriptorSet);
2016  op->removeAttr(kBinding);
2017  }
2018  });
2019  }
2020 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:64
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:35
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:44
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:84
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:94
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:77
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:54
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:446
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:67
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:69
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:65
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:456
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:517
SPIR-V struct type.
Definition: SPIRVTypes.h:295
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:229
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.