MLIR  22.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Returns true if the given type is a signed integer or vector type.
35 static bool isSignedIntegerOrVector(Type type) {
36  if (type.isSignedInteger())
37  return true;
38  if (auto vecType = dyn_cast<VectorType>(type))
39  return vecType.getElementType().isSignedInteger();
40  return false;
41 }
42 
43 /// Returns true if the given type is an unsigned integer or vector type
44 static bool isUnsignedIntegerOrVector(Type type) {
45  if (type.isUnsignedInteger())
46  return true;
47  if (auto vecType = dyn_cast<VectorType>(type))
48  return vecType.getElementType().isUnsignedInteger();
49  return false;
50 }
51 
52 /// Returns the width of an integer or of the element type of an integer vector,
53 /// if applicable.
54 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
55  if (auto intType = dyn_cast<IntegerType>(type))
56  return intType.getWidth();
57  if (auto vecType = dyn_cast<VectorType>(type))
58  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59  return intType.getWidth();
60  return std::nullopt;
61 }
62 
63 /// Returns the bit width of integer, float or vector of float or integer values
64 static unsigned getBitWidth(Type type) {
65  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
66  "bitwidth is not supported for this type");
67  if (type.isIntOrFloat())
68  return type.getIntOrFloatBitWidth();
69  auto vecType = dyn_cast<VectorType>(type);
70  auto elementType = vecType.getElementType();
71  assert(elementType.isIntOrFloat() &&
72  "only integers and floats have a bitwidth");
73  return elementType.getIntOrFloatBitWidth();
74 }
75 
76 /// Returns the bit width of LLVMType integer or vector.
77 static unsigned getLLVMTypeBitWidth(Type type) {
78  if (auto vecTy = dyn_cast<VectorType>(type))
79  type = vecTy.getElementType();
80  return cast<IntegerType>(type).getWidth();
81 }
82 
83 /// Creates `IntegerAttribute` with all bits set for given type
84 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
85  if (auto vecType = dyn_cast<VectorType>(type)) {
86  auto integerType = cast<IntegerType>(vecType.getElementType());
87  return builder.getIntegerAttr(integerType, -1);
88  }
89  auto integerType = cast<IntegerType>(type);
90  return builder.getIntegerAttr(integerType, -1);
91 }
92 
93 /// Creates `llvm.mlir.constant` with all bits set for the given type.
94 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
95  PatternRewriter &rewriter) {
96  if (isa<VectorType>(srcType)) {
97  return LLVM::ConstantOp::create(
98  rewriter, loc, dstType,
99  SplatElementsAttr::get(cast<ShapedType>(srcType),
100  minusOneIntegerAttribute(srcType, rewriter)));
101  }
102  return LLVM::ConstantOp::create(rewriter, loc, dstType,
103  minusOneIntegerAttribute(srcType, rewriter));
104 }
105 
106 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
107 static Value createFPConstant(Location loc, Type srcType, Type dstType,
108  PatternRewriter &rewriter, double value) {
109  if (auto vecType = dyn_cast<VectorType>(srcType)) {
110  auto floatType = cast<FloatType>(vecType.getElementType());
111  return LLVM::ConstantOp::create(
112  rewriter, loc, dstType,
113  SplatElementsAttr::get(vecType,
114  rewriter.getFloatAttr(floatType, value)));
115  }
116  auto floatType = cast<FloatType>(srcType);
117  return LLVM::ConstantOp::create(rewriter, loc, dstType,
118  rewriter.getFloatAttr(floatType, value));
119 }
120 
121 /// Utility function for bitfield ops:
122 /// - `BitFieldInsert`
123 /// - `BitFieldSExtract`
124 /// - `BitFieldUExtract`
125 /// Truncates or extends the value. If the bitwidth of the value is the same as
126 /// `llvmType` bitwidth, the value remains unchanged.
/// The width of `value` is measured with `getLLVMTypeBitWidth` when its type
/// is LLVM-compatible, and with `getBitWidth` otherwise. A narrower value is
/// zero-extended with `llvm.zext`; a wider value is truncated with
/// `llvm.trunc`.
128  Type llvmType,
129  PatternRewriter &rewriter) {
130  auto srcType = value.getType();
131  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
132  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
133  ? getLLVMTypeBitWidth(srcType)
134  : getBitWidth(srcType);
135 
// Narrower values are zero-extended (not sign-extended).
136  if (valueBitWidth < targetBitWidth)
137  return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
138  // If the bit widths of `Count` and `Offset` are greater than the bit width
139  // of the target type, they are truncated. Truncation is safe since `Count`
140  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
141  // both values can be expressed in 8 bits.
142  if (valueBitWidth > targetBitWidth)
143  return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
144  return value;
145 }
146 
147 /// Broadcasts the value to vector with `numElements` number of elements.
148 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
149  const TypeConverter &typeConverter,
150  ConversionPatternRewriter &rewriter) {
151  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
152  auto llvmVectorType = typeConverter.convertType(vectorType);
153  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
154  Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
155  for (unsigned i = 0; i < numElements; ++i) {
156  auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
157  rewriter.getI32IntegerAttr(i));
158  broadcasted = LLVM::InsertElementOp::create(
159  rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index);
160  }
161  return broadcasted;
162 }
163 
164 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
165 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
166  const TypeConverter &typeConverter,
167  ConversionPatternRewriter &rewriter) {
168  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
169  unsigned numElements = vectorType.getNumElements();
170  return broadcast(loc, value, numElements, typeConverter, rewriter);
171  }
172  return value;
173 }
174 
175 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
176 /// `BitFieldUExtract`.
177 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
178 /// a vector type, construct a vector that has:
179 /// - same number of elements as `Base`
180 /// - each element has the type that is the same as the type of `Offset` or
181 /// `Count`
182 /// - each element has the same value as `Offset` or `Count`
183 /// Then cast `Offset` and `Count` if their bit width is different
184 /// from `Base` bit width.
185 static Value processCountOrOffset(Location loc, Value value, Type srcType,
186  Type dstType, const TypeConverter &converter,
187  ConversionPatternRewriter &rewriter) {
188  Value broadcasted =
189  optionallyBroadcast(loc, value, srcType, converter, rewriter);
190  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
191 }
192 
193 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
194 /// offset to LLVM struct. Otherwise, the conversion is not supported.
/// The struct is accepted only if `VulkanLayoutUtils::decorateType` returns it
/// unchanged (presumably meaning its offsets already follow that layout). The
/// result is a non-packed literal LLVM struct. A null type is returned if the
/// layout check fails or any element type fails to convert.
196  const TypeConverter &converter) {
197  if (type != VulkanLayoutUtils::decorateType(type))
198  return nullptr;
199 
200  SmallVector<Type> elementsVector;
201  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
202  return nullptr;
203  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
204  /*isPacked=*/false);
205 }
206 
207 /// Converts SPIR-V struct with no offset to packed LLVM struct.
/// Element types are converted with `converter`; a null type is returned if
/// any of them fails to convert.
209  const TypeConverter &converter) {
210  SmallVector<Type> elementsVector;
211  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
212  return nullptr;
213  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
214  /*isPacked=*/true);
215 }
216 
217 /// Creates an `llvm.mlir.constant` of type i32 holding `value`.
/// The value is stored in an i32 `IntegerAttr`.
219  unsigned value) {
220  return LLVM::ConstantOp::create(
221  rewriter, loc, IntegerType::get(rewriter.getContext(), 32),
222  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
223 }
224 
225 /// Utility for `spirv.Load` and `spirv.Store` conversion.
226 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
227  ConversionPatternRewriter &rewriter,
228  const TypeConverter &typeConverter,
229  unsigned alignment, bool isVolatile,
230  bool isNonTemporal) {
231  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232  auto dstType = typeConverter.convertType(loadOp.getType());
233  if (!dstType)
234  return rewriter.notifyMatchFailure(op, "type conversion failed");
235  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237  isVolatile, isNonTemporal);
238  return success();
239  }
240  auto storeOp = cast<spirv::StoreOp>(op);
241  spirv::StoreOpAdaptor adaptor(operands);
242  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
243  adaptor.getPtr(), alignment,
244  isVolatile, isNonTemporal);
245  return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // Type conversion
250 //===----------------------------------------------------------------------===//
251 
252 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
253 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
254 /// when converting ops that manipulate array types.
255 static std::optional<Type> convertArrayType(spirv::ArrayType type,
256  TypeConverter &converter) {
257  unsigned stride = type.getArrayStride();
258  Type elementType = type.getElementType();
259  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
261  return std::nullopt;
262 
263  auto llvmElementType = converter.convertType(elementType);
264  unsigned numElements = type.getNumElements();
265  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
266 }
267 
268 /// Converts SPIR-V pointer type to an LLVM pointer. No pointee type is passed
269 /// to `LLVMPointerType::get`; the storage class is mapped to an LLVM address
/// space via `storageClassToAddressSpace` for the given `clientAPI`. The
/// `converter` parameter is not used by this function.
271  const TypeConverter &converter,
272  spirv::ClientAPI clientAPI) {
273  unsigned addressSpace =
274  storageClassToAddressSpace(clientAPI, type.getStorageClass());
275  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
276 }
277 
278 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
279 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
280 /// no modelling of array stride at the moment.
281 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
282  TypeConverter &converter) {
283  if (type.getArrayStride() != 0)
284  return std::nullopt;
285  auto elementType = converter.convertType(type.getElementType());
286  return LLVM::LLVMArrayType::get(elementType, 0);
287 }
288 
289 /// Converts SPIR-V struct to LLVM struct. Structs with member decorations are
290 /// not supported (a null type is returned for them). Structs that carry
/// offsets are handled by `convertStructTypeWithOffset`, which only accepts
/// the natural (Vulkan layout) offsets. Structs without offsets become packed
/// LLVM structs via `convertStructTypePacked`.
292  const TypeConverter &converter) {
294  type.getMemberDecorations(memberDecorations);
295  if (!memberDecorations.empty())
296  return nullptr;
297  if (type.hasOffset())
298  return convertStructTypeWithOffset(type, converter);
299  return convertStructTypePacked(type, converter);
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // Operation conversion
304 //===----------------------------------------------------------------------===//
305 
306 namespace {
307 
308 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
309 public:
311 
312  LogicalResult
313  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314  ConversionPatternRewriter &rewriter) const override {
315  auto dstType =
316  getTypeConverter()->convertType(op.getComponentPtr().getType());
317  if (!dstType)
318  return rewriter.notifyMatchFailure(op, "type conversion failed");
319  // To use GEP we need to add a first 0 index to go through the pointer.
320  auto indices = llvm::to_vector<4>(adaptor.getIndices());
321  Type indexType = op.getIndices().front().getType();
322  auto llvmIndexType = getTypeConverter()->convertType(indexType);
323  if (!llvmIndexType)
324  return rewriter.notifyMatchFailure(op, "type conversion failed");
325  Value zero =
326  LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
327  rewriter.getIntegerAttr(indexType, 0));
328  indices.insert(indices.begin(), zero);
329 
330  auto elementType = getTypeConverter()->convertType(
331  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
332  if (!elementType)
333  return rewriter.notifyMatchFailure(op, "type conversion failed");
334  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
335  adaptor.getBasePtr(), indices);
336  return success();
337  }
338 };
339 
340 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
341 public:
343 
344  LogicalResult
345  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
346  ConversionPatternRewriter &rewriter) const override {
347  auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
348  if (!dstType)
349  return rewriter.notifyMatchFailure(op, "type conversion failed");
350  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
351  op.getVariable());
352  return success();
353  }
354 };
355 
356 class BitFieldInsertPattern
357  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
358 public:
360 
361  LogicalResult
362  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
363  ConversionPatternRewriter &rewriter) const override {
364  auto srcType = op.getType();
365  auto dstType = getTypeConverter()->convertType(srcType);
366  if (!dstType)
367  return rewriter.notifyMatchFailure(op, "type conversion failed");
368  Location loc = op.getLoc();
369 
370  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
371  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
372  *getTypeConverter(), rewriter);
373  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
374  *getTypeConverter(), rewriter);
375 
376  // Create a mask with bits set outside [Offset, Offset + Count - 1].
377  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
378  Value maskShiftedByCount =
379  LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
380  Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
381  maskShiftedByCount, minusOne);
382  Value maskShiftedByCountAndOffset =
383  LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
384  Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
385  maskShiftedByCountAndOffset, minusOne);
386 
387  // Extract unchanged bits from the `Base` that are outside of
388  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
389  Value baseAndMask =
390  LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
391  Value insertShiftedByOffset =
392  LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
393  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
394  insertShiftedByOffset);
395  return success();
396  }
397 };
398 
399 /// Converts SPIR-V ConstantOp with scalar or vector type.
400 class ConstantScalarAndVectorPattern
401  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
402 public:
404 
405  LogicalResult
406  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
407  ConversionPatternRewriter &rewriter) const override {
408  auto srcType = constOp.getType();
409  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
410  return failure();
411 
412  auto dstType = getTypeConverter()->convertType(srcType);
413  if (!dstType)
414  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
415 
416  // SPIR-V constant can be a signed/unsigned integer, which has to be
417  // casted to signless integer when converting to LLVM dialect. Removing the
418  // sign bit may have unexpected behaviour. However, it is better to handle
419  // it case-by-case, given that the purpose of the conversion is not to
420  // cover all possible corner cases.
421  if (isSignedIntegerOrVector(srcType) ||
422  isUnsignedIntegerOrVector(srcType)) {
423  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
424 
425  if (isa<VectorType>(srcType)) {
426  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
427  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
428  constOp, dstType,
429  dstElementsAttr.mapValues(
430  signlessType, [&](const APInt &value) { return value; }));
431  return success();
432  }
433  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
434  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
435  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
436  return success();
437  }
438  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
439  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
440  return success();
441  }
442 };
443 
444 class BitFieldSExtractPattern
445  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
446 public:
448 
449  LogicalResult
450  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
451  ConversionPatternRewriter &rewriter) const override {
452  auto srcType = op.getType();
453  auto dstType = getTypeConverter()->convertType(srcType);
454  if (!dstType)
455  return rewriter.notifyMatchFailure(op, "type conversion failed");
456  Location loc = op.getLoc();
457 
458  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
459  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
460  *getTypeConverter(), rewriter);
461  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
462  *getTypeConverter(), rewriter);
463 
464  // Create a constant that holds the size of the `Base`.
465  IntegerType integerType;
466  if (auto vecType = dyn_cast<VectorType>(srcType))
467  integerType = cast<IntegerType>(vecType.getElementType());
468  else
469  integerType = cast<IntegerType>(srcType);
470 
471  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
472  Value size =
473  isa<VectorType>(srcType)
474  ? LLVM::ConstantOp::create(
475  rewriter, loc, dstType,
476  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
477  : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
478 
479  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
480  // at Offset + Count - 1 is the most significant bit now.
481  Value countPlusOffset =
482  LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
483  Value amountToShiftLeft =
484  LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
485  Value baseShiftedLeft = LLVM::ShlOp::create(
486  rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
487 
488  // Shift the result right, filling the bits with the sign bit.
489  Value amountToShiftRight =
490  LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
491  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
492  amountToShiftRight);
493  return success();
494  }
495 };
496 
497 class BitFieldUExtractPattern
498  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
499 public:
501 
502  LogicalResult
503  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
504  ConversionPatternRewriter &rewriter) const override {
505  auto srcType = op.getType();
506  auto dstType = getTypeConverter()->convertType(srcType);
507  if (!dstType)
508  return rewriter.notifyMatchFailure(op, "type conversion failed");
509  Location loc = op.getLoc();
510 
511  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
512  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
513  *getTypeConverter(), rewriter);
514  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
515  *getTypeConverter(), rewriter);
516 
517  // Create a mask with bits set at [0, Count - 1].
518  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
519  Value maskShiftedByCount =
520  LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
521  Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
522  minusOne);
523 
524  // Shift `Base` by `Offset` and apply the mask on it.
525  Value shiftedBase =
526  LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
527  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
528  return success();
529  }
530 };
531 
532 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
533 public:
535 
536  LogicalResult
537  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
538  ConversionPatternRewriter &rewriter) const override {
539  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
540  branchOp.getTarget());
541  return success();
542  }
543 };
544 
545 class BranchConditionalConversionPattern
546  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
547 public:
548  using SPIRVToLLVMConversion<
549  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
550 
551  LogicalResult
552  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
553  ConversionPatternRewriter &rewriter) const override {
554  // If branch weights exist, map them to 32-bit integer vector.
555  DenseI32ArrayAttr branchWeights = nullptr;
556  if (auto weights = op.getBranchWeights()) {
557  SmallVector<int32_t> weightValues;
558  for (auto weight : weights->getAsRange<IntegerAttr>())
559  weightValues.push_back(weight.getInt());
560  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
561  }
562 
563  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
564  op, op.getCondition(), op.getTrueBlockArguments(),
565  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
566  op.getFalseBlock());
567  return success();
568  }
569 };
570 
571 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
572 /// type is an aggregate type (struct or array). Otherwise, converts to
573 /// `llvm.extractelement` that operates on vectors.
574 class CompositeExtractPattern
575  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
576 public:
578 
579  LogicalResult
580  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
581  ConversionPatternRewriter &rewriter) const override {
582  auto dstType = this->getTypeConverter()->convertType(op.getType());
583  if (!dstType)
584  return rewriter.notifyMatchFailure(op, "type conversion failed");
585 
586  Type containerType = op.getComposite().getType();
587  if (isa<VectorType>(containerType)) {
588  Location loc = op.getLoc();
589  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
590  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
591  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
592  op, dstType, adaptor.getComposite(), index);
593  return success();
594  }
595 
596  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
597  op, adaptor.getComposite(),
598  LLVM::convertArrayToIndices(op.getIndices()));
599  return success();
600  }
601 };
602 
603 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
604 /// type is an aggregate type (struct or array). Otherwise, converts to
605 /// `llvm.insertelement` that operates on vectors.
606 class CompositeInsertPattern
607  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
608 public:
610 
611  LogicalResult
612  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
613  ConversionPatternRewriter &rewriter) const override {
614  auto dstType = this->getTypeConverter()->convertType(op.getType());
615  if (!dstType)
616  return rewriter.notifyMatchFailure(op, "type conversion failed");
617 
618  Type containerType = op.getComposite().getType();
619  if (isa<VectorType>(containerType)) {
620  Location loc = op.getLoc();
621  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
622  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
623  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
624  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
625  return success();
626  }
627 
628  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
629  op, adaptor.getComposite(), adaptor.getObject(),
630  LLVM::convertArrayToIndices(op.getIndices()));
631  return success();
632  }
633 };
634 
635 /// Converts SPIR-V operations that have straightforward LLVM equivalent
636 /// into LLVM dialect operations.
637 template <typename SPIRVOp, typename LLVMOp>
638 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
639 public:
641 
642  LogicalResult
643  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
644  ConversionPatternRewriter &rewriter) const override {
645  auto dstType = this->getTypeConverter()->convertType(op.getType());
646  if (!dstType)
647  return rewriter.notifyMatchFailure(op, "type conversion failed");
648  rewriter.template replaceOpWithNewOp<LLVMOp>(
649  op, dstType, adaptor.getOperands(), op->getAttrs());
650  return success();
651  }
652 };
653 
654 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
655 /// execution mode information.
656 class ExecutionModePattern
657  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
658 public:
660 
661  LogicalResult
662  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
663  ConversionPatternRewriter &rewriter) const override {
664  // First, create the global struct's name that would be associated with
665  // this entry point's execution mode. We set it to be:
666  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
667  ModuleOp module = op->getParentOfType<ModuleOp>();
668  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
669  std::string moduleName;
670  if (module.getName().has_value())
671  moduleName = "_" + module.getName()->str();
672  else
673  moduleName = "";
674  std::string executionModeInfoName = llvm::formatv(
675  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
676  static_cast<uint32_t>(executionModeAttr.getValue()));
677 
678  MLIRContext *context = rewriter.getContext();
679  OpBuilder::InsertionGuard guard(rewriter);
680  rewriter.setInsertionPointToStart(module.getBody());
681 
682  // Create a struct type, corresponding to the C struct below.
683  // struct {
684  // int32_t executionMode;
685  // int32_t values[]; // optional values
686  // };
687  auto llvmI32Type = IntegerType::get(context, 32);
688  SmallVector<Type, 2> fields;
689  fields.push_back(llvmI32Type);
690  ArrayAttr values = op.getValues();
691  if (!values.empty()) {
692  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
693  fields.push_back(arrayType);
694  }
695  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
696 
697  // Create `llvm.mlir.global` with initializer region containing one block.
698  auto global = LLVM::GlobalOp::create(
699  rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true,
700  LLVM::Linkage::External, executionModeInfoName, Attribute(),
701  /*alignment=*/0);
702  Location loc = global.getLoc();
703  Region &region = global.getInitializerRegion();
704  Block *block = rewriter.createBlock(&region);
705 
706  // Initialize the struct and set the execution mode value.
707  rewriter.setInsertionPointToStart(block);
708  Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
709  Value executionMode = LLVM::ConstantOp::create(
710  rewriter, loc, llvmI32Type,
711  rewriter.getI32IntegerAttr(
712  static_cast<uint32_t>(executionModeAttr.getValue())));
713  SmallVector<int64_t> position{0};
714  structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
715  executionMode, position);
716 
717  // Insert extra operands if they exist into execution mode info struct.
718  for (unsigned i = 0, e = values.size(); i < e; ++i) {
719  auto attr = values.getValue()[i];
720  Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
721  structValue = LLVM::InsertValueOp::create(
722  rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i}));
723  }
724  LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue}));
725  rewriter.eraseOp(op);
726  return success();
727  }
728 };
729 
730 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
731 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
732 /// value. This difference is handled by `spirv.mlir.addressof` and
733 /// `llvm.mlir.addressof`ops that both return a pointer.
734 class GlobalVariablePattern
735  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
736 public:
737  template <typename... Args>
738  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
739  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
740  std::forward<Args>(args)...),
741  clientAPI(clientAPI) {}
742 
743  LogicalResult
744  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
745  ConversionPatternRewriter &rewriter) const override {
746  // Currently, there is no support of initialization with a constant value in
747  // SPIR-V dialect. Specialization constants are not considered as well.
748  if (op.getInitializer())
749  return failure();
750 
751  auto srcType = cast<spirv::PointerType>(op.getType());
752  auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
753  if (!dstType)
754  return rewriter.notifyMatchFailure(op, "type conversion failed");
755 
756  // Limit conversion to the current invocation only or `StorageBuffer`
757  // required by SPIR-V runner.
758  // This is okay because multiple invocations are not supported yet.
759  auto storageClass = srcType.getStorageClass();
760  switch (storageClass) {
761  case spirv::StorageClass::Input:
762  case spirv::StorageClass::Private:
763  case spirv::StorageClass::Output:
764  case spirv::StorageClass::StorageBuffer:
765  case spirv::StorageClass::UniformConstant:
766  break;
767  default:
768  return failure();
769  }
770 
771  // LLVM dialect spec: "If the global value is a constant, storing into it is
772  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
773  // storage class that is read-only.
774  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
775  (storageClass == spirv::StorageClass::UniformConstant);
776  // SPIR-V spec: "By default, functions and global variables are private to a
777  // module and cannot be accessed by other modules. However, a module may be
778  // written to export or import functions and global (module scope)
779  // variables.". Therefore, map 'Private' storage class to private linkage,
780  // 'Input' and 'Output' to external linkage.
781  auto linkage = storageClass == spirv::StorageClass::Private
782  ? LLVM::Linkage::Private
783  : LLVM::Linkage::External;
784  StringAttr locationAttrName = op.getLocationAttrName();
785  IntegerAttr locationAttr = op.getLocationAttr();
786  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788  /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
789 
790  // Attach location attribute if applicable
791  if (locationAttr)
792  newGlobalOp->setAttr(locationAttrName, locationAttr);
793 
794  return success();
795  }
796 
797 private:
798  spirv::ClientAPI clientAPI;
799 };
800 
801 /// Converts SPIR-V cast ops that do not have straightforward LLVM
802 /// equivalent in LLVM dialect.
803 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
805 public:
807 
808  LogicalResult
809  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810  ConversionPatternRewriter &rewriter) const override {
811 
812  Type fromType = op.getOperand().getType();
813  Type toType = op.getType();
814 
815  auto dstType = this->getTypeConverter()->convertType(toType);
816  if (!dstType)
817  return rewriter.notifyMatchFailure(op, "type conversion failed");
818 
819  if (getBitWidth(fromType) < getBitWidth(toType)) {
820  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821  adaptor.getOperands());
822  return success();
823  }
824  if (getBitWidth(fromType) > getBitWidth(toType)) {
825  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826  adaptor.getOperands());
827  return success();
828  }
829  return failure();
830  }
831 };
832 
833 class FunctionCallPattern
834  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
835 public:
837 
838  LogicalResult
839  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840  ConversionPatternRewriter &rewriter) const override {
841  if (callOp.getNumResults() == 0) {
842  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843  callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
844  newOp.getProperties().operandSegmentSizes = {
845  static_cast<int32_t>(adaptor.getOperands().size()), 0};
846  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
847  return success();
848  }
849 
850  // Function returns a single result.
851  auto dstType = getTypeConverter()->convertType(callOp.getType(0));
852  if (!dstType)
853  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856  newOp.getProperties().operandSegmentSizes = {
857  static_cast<int32_t>(adaptor.getOperands().size()), 0};
858  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
859  return success();
860  }
861 };
862 
863 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
866 public:
868 
869  LogicalResult
870  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871  ConversionPatternRewriter &rewriter) const override {
872 
873  auto dstType = this->getTypeConverter()->convertType(op.getType());
874  if (!dstType)
875  return rewriter.notifyMatchFailure(op, "type conversion failed");
876 
877  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878  op, dstType, predicate, op.getOperand1(), op.getOperand2());
879  return success();
880  }
881 };
882 
883 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
886 public:
888 
889  LogicalResult
890  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891  ConversionPatternRewriter &rewriter) const override {
892 
893  auto dstType = this->getTypeConverter()->convertType(op.getType());
894  if (!dstType)
895  return rewriter.notifyMatchFailure(op, "type conversion failed");
896 
897  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898  op, dstType, predicate, op.getOperand1(), op.getOperand2());
899  return success();
900  }
901 };
902 
903 class InverseSqrtPattern
904  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
905 public:
907 
908  LogicalResult
909  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910  ConversionPatternRewriter &rewriter) const override {
911  auto srcType = op.getType();
912  auto dstType = getTypeConverter()->convertType(srcType);
913  if (!dstType)
914  return rewriter.notifyMatchFailure(op, "type conversion failed");
915 
916  Location loc = op.getLoc();
917  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
918  Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
919  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
920  return success();
921  }
922 };
923 
924 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925 template <typename SPIRVOp>
926 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
927 public:
929 
930  LogicalResult
931  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932  ConversionPatternRewriter &rewriter) const override {
933  if (!op.getMemoryAccess()) {
934  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
935  *this->getTypeConverter(), /*alignment=*/0,
936  /*isVolatile=*/false,
937  /*isNonTemporal=*/false);
938  }
939  auto memoryAccess = *op.getMemoryAccess();
940  switch (memoryAccess) {
941  case spirv::MemoryAccess::Aligned:
943  case spirv::MemoryAccess::Nontemporal:
944  case spirv::MemoryAccess::Volatile: {
945  unsigned alignment =
946  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
950  *this->getTypeConverter(), alignment,
951  isVolatile, isNonTemporal);
952  }
953  default:
954  // There is no support of other memory access attributes.
955  return failure();
956  }
957  }
958 };
959 
960 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961 template <typename SPIRVOp>
962 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
963 public:
965 
966  LogicalResult
967  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968  ConversionPatternRewriter &rewriter) const override {
969  auto srcType = notOp.getType();
970  auto dstType = this->getTypeConverter()->convertType(srcType);
971  if (!dstType)
972  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
973 
974  Location loc = notOp.getLoc();
975  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
976  auto mask =
977  isa<VectorType>(srcType)
978  ? LLVM::ConstantOp::create(
979  rewriter, loc, dstType,
980  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981  : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
982  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983  notOp.getOperand(), mask);
984  return success();
985  }
986 };
987 
988 /// A template pattern that erases the given `SPIRVOp`.
989 template <typename SPIRVOp>
990 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
991 public:
993 
994  LogicalResult
995  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
996  ConversionPatternRewriter &rewriter) const override {
997  rewriter.eraseOp(op);
998  return success();
999  }
1000 };
1001 
1002 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1003 public:
1005 
1006  LogicalResult
1007  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008  ConversionPatternRewriter &rewriter) const override {
1009  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1010  ArrayRef<Value>());
1011  return success();
1012  }
1013 };
1014 
1015 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1016 public:
1018 
1019  LogicalResult
1020  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021  ConversionPatternRewriter &rewriter) const override {
1022  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023  adaptor.getOperands());
1024  return success();
1025  }
1026 };
1027 
1028 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1029  StringRef name,
1030  ArrayRef<Type> paramTypes,
1031  Type resultType,
1032  bool convergent = true) {
1033  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1034  SymbolTable::lookupSymbolIn(symbolTable, name));
1035  if (func)
1036  return func;
1037 
1038  OpBuilder b(symbolTable->getRegion(0));
1039  func = LLVM::LLVMFuncOp::create(
1040  b, symbolTable->getLoc(), name,
1041  LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043  func.setConvergent(convergent);
1044  func.setNoUnwind(true);
1045  func.setWillReturn(true);
1046  return func;
1047 }
1048 
1049 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1050  LLVM::LLVMFuncOp func,
1051  ValueRange args) {
1052  auto call = LLVM::CallOp::create(builder, loc, func, args);
1053  call.setCConv(func.getCConv());
1054  call.setConvergentAttr(func.getConvergentAttr());
1055  call.setNoUnwindAttr(func.getNoUnwindAttr());
1056  call.setWillReturnAttr(func.getWillReturnAttr());
1057  return call;
1058 }
1059 
1060 template <typename BarrierOpTy>
1061 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1062 public:
1063  using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1064 
1066 
1067  static constexpr StringRef getFuncName();
1068 
1069  LogicalResult
1070  matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071  ConversionPatternRewriter &rewriter) const override {
1072  constexpr StringRef funcName = getFuncName();
1073  Operation *symbolTable =
1074  controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1075 
1076  Type i32 = rewriter.getI32Type();
1077 
1078  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079  LLVM::LLVMFuncOp func =
1080  lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1081 
1082  Location loc = controlBarrierOp->getLoc();
1083  Value execution = LLVM::ConstantOp::create(
1084  rewriter, loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1085  Value memory = LLVM::ConstantOp::create(
1086  rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1087  Value semantics = LLVM::ConstantOp::create(
1088  rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1089 
1090  auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1091  {execution, memory, semantics});
1092 
1093  rewriter.replaceOp(controlBarrierOp, call);
1094  return success();
1095  }
1096 };
1097 
1098 namespace {
1099 
1100 StringRef getTypeMangling(Type type, bool isSigned) {
1102  .Case<Float16Type>([](auto) { return "Dh"; })
1103  .Case<Float32Type>([](auto) { return "f"; })
1104  .Case<Float64Type>([](auto) { return "d"; })
1105  .Case<IntegerType>([isSigned](IntegerType intTy) {
1106  switch (intTy.getWidth()) {
1107  case 1:
1108  return "b";
1109  case 8:
1110  return (isSigned) ? "a" : "c";
1111  case 16:
1112  return (isSigned) ? "s" : "t";
1113  case 32:
1114  return (isSigned) ? "i" : "j";
1115  case 64:
1116  return (isSigned) ? "l" : "m";
1117  default:
1118  llvm_unreachable("Unsupported integer width");
1119  }
1120  })
1121  .DefaultUnreachable("No mangling defined");
1122 }
1123 
1124 template <typename ReduceOp>
1125 constexpr StringLiteral getGroupFuncName();
1126 
1127 template <>
1128 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1129  return "_Z17__spirv_GroupIAddii";
1130 }
1131 template <>
1132 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1133  return "_Z17__spirv_GroupFAddii";
1134 }
1135 template <>
1136 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1137  return "_Z17__spirv_GroupSMinii";
1138 }
1139 template <>
1140 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1141  return "_Z17__spirv_GroupUMinii";
1142 }
1143 template <>
1144 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1145  return "_Z17__spirv_GroupFMinii";
1146 }
1147 template <>
1148 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1149  return "_Z17__spirv_GroupSMaxii";
1150 }
1151 template <>
1152 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1153  return "_Z17__spirv_GroupUMaxii";
1154 }
1155 template <>
1156 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1157  return "_Z17__spirv_GroupFMaxii";
1158 }
1159 template <>
1160 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1161  return "_Z27__spirv_GroupNonUniformIAddii";
1162 }
1163 template <>
1164 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1165  return "_Z27__spirv_GroupNonUniformFAddii";
1166 }
1167 template <>
1168 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1169  return "_Z27__spirv_GroupNonUniformIMulii";
1170 }
1171 template <>
1172 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1173  return "_Z27__spirv_GroupNonUniformFMulii";
1174 }
1175 template <>
1176 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1177  return "_Z27__spirv_GroupNonUniformSMinii";
1178 }
1179 template <>
1180 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1181  return "_Z27__spirv_GroupNonUniformUMinii";
1182 }
1183 template <>
1184 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1185  return "_Z27__spirv_GroupNonUniformFMinii";
1186 }
1187 template <>
1188 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1189  return "_Z27__spirv_GroupNonUniformSMaxii";
1190 }
1191 template <>
1192 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1193  return "_Z27__spirv_GroupNonUniformUMaxii";
1194 }
1195 template <>
1196 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1197  return "_Z27__spirv_GroupNonUniformFMaxii";
1198 }
1199 template <>
1200 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1201  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1202 }
1203 template <>
1204 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1205  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1206 }
1207 template <>
1208 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1209  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1210 }
1211 template <>
1212 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1213  return "_Z33__spirv_GroupNonUniformLogicalAndii";
1214 }
1215 template <>
1216 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1217  return "_Z32__spirv_GroupNonUniformLogicalOrii";
1218 }
1219 template <>
1220 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1221  return "_Z33__spirv_GroupNonUniformLogicalXorii";
1222 }
1223 } // namespace
1224 
1225 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1226 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1227 public:
1229 
1230  LogicalResult
1231  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1232  ConversionPatternRewriter &rewriter) const override {
1233 
1234  Type retTy = op.getResult().getType();
1235  if (!retTy.isIntOrFloat()) {
1236  return failure();
1237  }
1238  SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1239  funcName += getTypeMangling(retTy, false);
1240 
1241  Type i32Ty = rewriter.getI32Type();
1242  SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1243  if constexpr (NonUniform) {
1244  if (adaptor.getClusterSize()) {
1245  funcName += "j";
1246  paramTypes.push_back(i32Ty);
1247  }
1248  }
1249 
1250  Operation *symbolTable =
1251  op->template getParentWithTrait<OpTrait::SymbolTable>();
1252 
1253  LLVM::LLVMFuncOp func =
1254  lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1255 
1256  Location loc = op.getLoc();
1257  Value scope = LLVM::ConstantOp::create(
1258  rewriter, loc, i32Ty,
1259  static_cast<int32_t>(adaptor.getExecutionScope()));
1260  Value groupOp = LLVM::ConstantOp::create(
1261  rewriter, loc, i32Ty,
1262  static_cast<int32_t>(adaptor.getGroupOperation()));
1263  SmallVector<Value> operands{scope, groupOp};
1264  operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1265 
1266  auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1267  rewriter.replaceOp(op, call);
1268  return success();
1269  }
1270 };
1271 
/// Mangled name of the SPIR-V builtin that `spirv.ControlBarrier` is lowered
/// to (three i32 arguments: execution scope, memory scope, memory semantics).
template <>
constexpr StringRef
ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
  return "_Z22__spirv_ControlBarrieriii";
}
1277 
/// Mangled name of the SPIR-V builtin that `spirv.INTEL.ControlBarrierArrive`
/// is lowered to (same three i32 arguments as the plain control barrier).
template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
  return "_Z33__spirv_ControlBarrierArriveINTELiii";
}
1283 
/// Mangled name of the SPIR-V builtin that `spirv.INTEL.ControlBarrierWait`
/// is lowered to (same three i32 arguments as the plain control barrier).
template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
  return "_Z31__spirv_ControlBarrierWaitINTELiii";
}
1289 
1290 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1291 /// should be reachable for conversion to succeed. The structure of the loop in
1292 /// LLVM dialect will be the following:
1293 ///
1294 /// +------------------------------------+
1295 /// | <code before spirv.mlir.loop> |
1296 /// | llvm.br ^header |
1297 /// +------------------------------------+
1298 /// |
1299 /// +----------------+ |
1300 /// | | |
1301 /// | V V
1302 /// | +------------------------------------+
1303 /// | | ^header: |
1304 /// | | <header code> |
1305 /// | | llvm.cond_br %cond, ^body, ^exit |
1306 /// | +------------------------------------+
1307 /// | |
1308 /// | |----------------------+
1309 /// | | |
1310 /// | V |
1311 /// | +------------------------------------+ |
1312 /// | | ^body: | |
1313 /// | | <body code> | |
1314 /// | | llvm.br ^continue | |
1315 /// | +------------------------------------+ |
1316 /// | | |
1317 /// | V |
1318 /// | +------------------------------------+ |
1319 /// | | ^continue: | |
1320 /// | | <continue code> | |
1321 /// | | llvm.br ^header | |
1322 /// | +------------------------------------+ |
1323 /// | | |
1324 /// +---------------+ +----------------------+
1325 /// |
1326 /// V
1327 /// +------------------------------------+
1328 /// | ^exit: |
1329 /// | llvm.br ^remaining |
1330 /// +------------------------------------+
1331 /// |
1332 /// V
1333 /// +------------------------------------+
1334 /// | ^remaining: |
1335 /// | <code after spirv.mlir.loop> |
1336 /// +------------------------------------+
1337 ///
1338 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1339 public:
1341 
1342  LogicalResult
1343  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1344  ConversionPatternRewriter &rewriter) const override {
1345  // There is no support of loop control at the moment.
1346  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1347  return failure();
1348 
1349  // `spirv.mlir.loop` with empty region is redundant and should be erased.
1350  if (loopOp.getBody().empty()) {
1351  rewriter.eraseOp(loopOp);
1352  return success();
1353  }
1354 
1355  Location loc = loopOp.getLoc();
1356 
1357  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1358  // be used in `endBlock`.
1359  Block *currentBlock = rewriter.getBlock();
1360  auto position = Block::iterator(loopOp);
1361  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1362 
1363  // Remove entry block and create a branch in the current block going to the
1364  // header block.
1365  Block *entryBlock = loopOp.getEntryBlock();
1366  assert(entryBlock->getOperations().size() == 1);
1367  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1368  if (!brOp)
1369  return failure();
1370  Block *headerBlock = loopOp.getHeaderBlock();
1371  rewriter.setInsertionPointToEnd(currentBlock);
1372  LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
1373  rewriter.eraseBlock(entryBlock);
1374 
1375  // Branch from merge block to end block.
1376  Block *mergeBlock = loopOp.getMergeBlock();
1377  Operation *terminator = mergeBlock->getTerminator();
1378  ValueRange terminatorOperands = terminator->getOperands();
1379  rewriter.setInsertionPointToEnd(mergeBlock);
1380  LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
1381 
1382  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1383  rewriter.replaceOp(loopOp, endBlock->getArguments());
1384  return success();
1385  }
1386 };
1387 
1388 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1389 /// block. All blocks within selection should be reachable for conversion to
1390 /// succeed.
1391 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1392 public:
1394 
1395  LogicalResult
1396  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1397  ConversionPatternRewriter &rewriter) const override {
1398  // There is no support for `Flatten` or `DontFlatten` selection control at
1399  // the moment. This are just compiler hints and can be performed during the
1400  // optimization passes.
1401  if (op.getSelectionControl() != spirv::SelectionControl::None)
1402  return failure();
1403 
1404  // `spirv.mlir.selection` should have at least two blocks: one selection
1405  // header block and one merge block. If no blocks are present, or control
1406  // flow branches straight to merge block (two blocks are present), the op is
1407  // redundant and it is erased.
1408  if (op.getBody().getBlocks().size() <= 2) {
1409  rewriter.eraseOp(op);
1410  return success();
1411  }
1412 
1413  Location loc = op.getLoc();
1414 
1415  // Split the current block after `spirv.mlir.selection`. The remaining ops
1416  // will be used in `continueBlock`.
1417  auto *currentBlock = rewriter.getInsertionBlock();
1418  rewriter.setInsertionPointAfter(op);
1419  auto position = rewriter.getInsertionPoint();
1420  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1421 
1422  // Extract conditional branch information from the header block. By SPIR-V
1423  // dialect spec, it should contain `spirv.BranchConditional` or
1424  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1425  // moment in the SPIR-V dialect. Remove this block when finished.
1426  auto *headerBlock = op.getHeaderBlock();
1427  assert(headerBlock->getOperations().size() == 1);
1428  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1429  headerBlock->getOperations().front());
1430  if (!condBrOp)
1431  return failure();
1432 
1433  // Branch from merge block to continue block.
1434  auto *mergeBlock = op.getMergeBlock();
1435  Operation *terminator = mergeBlock->getTerminator();
1436  ValueRange terminatorOperands = terminator->getOperands();
1437  rewriter.setInsertionPointToEnd(mergeBlock);
1438  LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
1439 
1440  // Link current block to `true` and `false` blocks within the selection.
1441  Block *trueBlock = condBrOp.getTrueBlock();
1442  Block *falseBlock = condBrOp.getFalseBlock();
1443  rewriter.setInsertionPointToEnd(currentBlock);
1444  LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
1445  condBrOp.getTrueTargetOperands(), falseBlock,
1446  condBrOp.getFalseTargetOperands());
1447 
1448  rewriter.eraseBlock(headerBlock);
1449  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1450  rewriter.replaceOp(op, continueBlock->getArguments());
1451  return success();
1452  }
1453 };
1454 
1455 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1456 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1457 /// `Shift` is zero or sign extended to match this specification. Cases when
1458 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1459 template <typename SPIRVOp, typename LLVMOp>
1460 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1461 public:
1463 
1464  LogicalResult
1465  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1466  ConversionPatternRewriter &rewriter) const override {
1467 
1468  auto dstType = this->getTypeConverter()->convertType(op.getType());
1469  if (!dstType)
1470  return rewriter.notifyMatchFailure(op, "type conversion failed");
1471 
1472  Type op1Type = op.getOperand1().getType();
1473  Type op2Type = op.getOperand2().getType();
1474 
1475  if (op1Type == op2Type) {
1476  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477  adaptor.getOperands());
1478  return success();
1479  }
1480 
1481  std::optional<uint64_t> dstTypeWidth =
1483  std::optional<uint64_t> op2TypeWidth =
1485 
1486  if (!dstTypeWidth || !op2TypeWidth)
1487  return failure();
1488 
1489  Location loc = op.getLoc();
1490  Value extended;
1491  if (op2TypeWidth < dstTypeWidth) {
1492  if (isUnsignedIntegerOrVector(op2Type)) {
1493  extended =
1494  LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1495  } else {
1496  extended =
1497  LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1498  }
1499  } else if (op2TypeWidth == dstTypeWidth) {
1500  extended = adaptor.getOperand2();
1501  } else {
1502  return failure();
1503  }
1504 
1505  Value result =
1506  LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
1507  rewriter.replaceOp(op, result);
1508  return success();
1509  }
1510 };
1511 
1512 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1513 public:
1515 
1516  LogicalResult
1517  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1518  ConversionPatternRewriter &rewriter) const override {
1519  auto dstType = getTypeConverter()->convertType(tanOp.getType());
1520  if (!dstType)
1521  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1522 
1523  Location loc = tanOp.getLoc();
1524  Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
1525  Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
1526  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1527  return success();
1528  }
1529 };
1530 
1531 /// Convert `spirv.Tanh` to
1532 ///
1533 /// exp(2x) - 1
1534 /// -----------
1535 /// exp(2x) + 1
1536 ///
1537 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1538 public:
1540 
1541  LogicalResult
1542  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1543  ConversionPatternRewriter &rewriter) const override {
1544  auto srcType = tanhOp.getType();
1545  auto dstType = getTypeConverter()->convertType(srcType);
1546  if (!dstType)
1547  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1548 
1549  Location loc = tanhOp.getLoc();
1550  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1551  Value multiplied =
1552  LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
1553  Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
1554  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1555  Value numerator =
1556  LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
1557  Value denominator =
1558  LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
1559  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1560  denominator);
1561  return success();
1562  }
1563 };
1564 
1565 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1566 public:
1568 
1569  LogicalResult
1570  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1571  ConversionPatternRewriter &rewriter) const override {
1572  auto srcType = varOp.getType();
1573  // Initialization is supported for scalars and vectors only.
1574  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1575  auto init = varOp.getInitializer();
1576  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1577  return failure();
1578 
1579  auto dstType = getTypeConverter()->convertType(srcType);
1580  if (!dstType)
1581  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1582 
1583  Location loc = varOp.getLoc();
1584  Value size = createI32ConstantOf(loc, rewriter, 1);
1585  if (!init) {
1586  auto elementType = getTypeConverter()->convertType(pointerTo);
1587  if (!elementType)
1588  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1589  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1590  size);
1591  return success();
1592  }
1593  auto elementType = getTypeConverter()->convertType(pointerTo);
1594  if (!elementType)
1595  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1596  Value allocated =
1597  LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
1598  LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
1599  rewriter.replaceOp(varOp, allocated);
1600  return success();
1601  }
1602 };
1603 
1604 //===----------------------------------------------------------------------===//
1605 // BitcastOp conversion
1606 //===----------------------------------------------------------------------===//
1607 
1608 class BitcastConversionPattern
1609  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1610 public:
1612 
1613  LogicalResult
1614  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1615  ConversionPatternRewriter &rewriter) const override {
1616  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1617  if (!dstType)
1618  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1619 
1620  // LLVM's opaque pointers do not require bitcasts.
1621  if (isa<LLVM::LLVMPointerType>(dstType)) {
1622  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1623  return success();
1624  }
1625 
1626  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1627  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1628  return success();
1629  }
1630 };
1631 
1632 //===----------------------------------------------------------------------===//
1633 // FuncOp conversion
1634 //===----------------------------------------------------------------------===//
1635 
1636 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1637 public:
1639 
1640  LogicalResult
1641  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1642  ConversionPatternRewriter &rewriter) const override {
1643 
1644  // Convert function signature. At the moment LLVMType converter is enough
1645  // for currently supported types.
1646  auto funcType = funcOp.getFunctionType();
1647  TypeConverter::SignatureConversion signatureConverter(
1648  funcType.getNumInputs());
1649  auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1650  ->convertFunctionSignature(
1651  funcType, /*isVariadic=*/false,
1652  /*useBarePtrCallConv=*/false, signatureConverter);
1653  if (!llvmType)
1654  return failure();
1655 
1656  // Create a new `LLVMFuncOp`
1657  Location loc = funcOp.getLoc();
1658  StringRef name = funcOp.getName();
1659  auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
1660 
1661  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1662  MLIRContext *context = funcOp.getContext();
1663  switch (funcOp.getFunctionControl()) {
1664  case spirv::FunctionControl::Inline:
1665  newFuncOp.setAlwaysInline(true);
1666  break;
1667  case spirv::FunctionControl::DontInline:
1668  newFuncOp.setNoInline(true);
1669  break;
1670 
1671 #define DISPATCH(functionControl, llvmAttr) \
1672  case functionControl: \
1673  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1674  break;
1675 
1676  DISPATCH(spirv::FunctionControl::Pure,
1677  StringAttr::get(context, "readonly"));
1678  DISPATCH(spirv::FunctionControl::Const,
1679  StringAttr::get(context, "readnone"));
1680 
1681 #undef DISPATCH
1682 
1683  // Default: if `spirv::FunctionControl::None`, then no attributes are
1684  // needed.
1685  default:
1686  break;
1687  }
1688 
1689  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1690  newFuncOp.end());
1691  if (failed(rewriter.convertRegionTypes(
1692  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1693  return failure();
1694  }
1695  rewriter.eraseOp(funcOp);
1696  return success();
1697  }
1698 };
1699 
1700 //===----------------------------------------------------------------------===//
1701 // ModuleOp conversion
1702 //===----------------------------------------------------------------------===//
1703 
1704 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1705 public:
1707 
1708  LogicalResult
1709  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1710  ConversionPatternRewriter &rewriter) const override {
1711 
1712  auto newModuleOp =
1713  ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
1714  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1715 
1716  // Remove the terminator block that was automatically added by builder
1717  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1718  rewriter.eraseOp(spvModuleOp);
1719  return success();
1720  }
1721 };
1722 
1723 //===----------------------------------------------------------------------===//
1724 // VectorShuffleOp conversion
1725 //===----------------------------------------------------------------------===//
1726 
1727 class VectorShufflePattern
1728  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1729 public:
1731  LogicalResult
1732  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1733  ConversionPatternRewriter &rewriter) const override {
1734  Location loc = op.getLoc();
1735  auto components = adaptor.getComponents();
1736  auto vector1 = adaptor.getVector1();
1737  auto vector2 = adaptor.getVector2();
1738  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1739  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1740  if (vector1Size == vector2Size) {
1741  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1742  op, vector1, vector2,
1743  LLVM::convertArrayToIndices<int32_t>(components));
1744  return success();
1745  }
1746 
1747  auto dstType = getTypeConverter()->convertType(op.getType());
1748  if (!dstType)
1749  return rewriter.notifyMatchFailure(op, "type conversion failed");
1750  auto scalarType = cast<VectorType>(dstType).getElementType();
1751  auto componentsArray = components.getValue();
1752  auto *context = rewriter.getContext();
1753  auto llvmI32Type = IntegerType::get(context, 32);
1754  Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
1755  for (unsigned i = 0; i < componentsArray.size(); i++) {
1756  if (!isa<IntegerAttr>(componentsArray[i]))
1757  return op.emitError("unable to support non-constant component");
1758 
1759  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1760  if (indexVal == -1)
1761  continue;
1762 
1763  int offsetVal = 0;
1764  Value baseVector = vector1;
1765  if (indexVal >= vector1Size) {
1766  offsetVal = vector1Size;
1767  baseVector = vector2;
1768  }
1769 
1770  Value dstIndex = LLVM::ConstantOp::create(
1771  rewriter, loc, llvmI32Type,
1772  rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1773  Value index = LLVM::ConstantOp::create(
1774  rewriter, loc, llvmI32Type,
1775  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1776 
1777  auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
1778  baseVector, index);
1779  targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
1780  extractOp, dstIndex);
1781  }
1782  rewriter.replaceOp(op, targetOp);
1783  return success();
1784  }
1785 };
1786 } // namespace
1787 
1788 //===----------------------------------------------------------------------===//
1789 // Pattern population
1790 //===----------------------------------------------------------------------===//
1791 
1793  spirv::ClientAPI clientAPI) {
1794  typeConverter.addConversion([&](spirv::ArrayType type) {
1795  return convertArrayType(type, typeConverter);
1796  });
1797  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1798  return convertPointerType(type, typeConverter, clientAPI);
1799  });
1800  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1801  return convertRuntimeArrayType(type, typeConverter);
1802  });
1803  typeConverter.addConversion([&](spirv::StructType type) {
1804  return convertStructType(type, typeConverter);
1805  });
1806 }
1807 
1809  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1810  spirv::ClientAPI clientAPI) {
1811  patterns.add<
1812  // Arithmetic ops
1813  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1814  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1815  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1816  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1817  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1818  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1819  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1820  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1821  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1822  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1823  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1824  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1825  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1826 
1827  // Bitwise ops
1828  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1829  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1830  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1831  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1832  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1833  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1834  NotPattern<spirv::NotOp>,
1835 
1836  // Cast ops
1837  BitcastConversionPattern,
1838  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1839  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1840  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1841  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1842  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1843  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1844  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1845 
1846  // Comparison ops
1847  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1848  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1849  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1850  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1851  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1852  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1853  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1854  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1855  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1856  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1857  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1858  LLVM::FCmpPredicate::uge>,
1859  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1860  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1861  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1862  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1863  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1864  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1865  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1866  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1867  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1868  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1869  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1870 
1871  // Constant op
1872  ConstantScalarAndVectorPattern,
1873 
1874  // Control Flow ops
1875  BranchConversionPattern, BranchConditionalConversionPattern,
1876  FunctionCallPattern, LoopPattern, SelectionPattern,
1877  ErasePattern<spirv::MergeOp>,
1878 
1879  // Entry points and execution mode are handled separately.
1880  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1881 
1882  // GLSL extended instruction set ops
1883  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1884  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1885  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1886  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1887  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1888  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1889  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1890  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1891  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1892  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1893  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1894  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1895  InverseSqrtPattern, TanPattern, TanhPattern,
1896 
1897  // Logical ops
1898  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1899  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1900  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1901  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1902  NotPattern<spirv::LogicalNotOp>,
1903 
1904  // Memory ops
1905  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1906  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1907 
1908  // Miscellaneous ops
1909  CompositeExtractPattern, CompositeInsertPattern,
1910  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1911  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1912  VectorShufflePattern,
1913 
1914  // Shift ops
1915  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1916  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1917  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1918 
1919  // Return ops
1920  ReturnPattern, ReturnValuePattern,
1921 
1922  // Barrier ops
1923  ControlBarrierPattern<spirv::ControlBarrierOp>,
1924  ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1925  ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1926 
1927  // Group reduction operations
1928  GroupReducePattern<spirv::GroupIAddOp>,
1929  GroupReducePattern<spirv::GroupFAddOp>,
1930  GroupReducePattern<spirv::GroupFMinOp>,
1931  GroupReducePattern<spirv::GroupUMinOp>,
1932  GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1933  GroupReducePattern<spirv::GroupFMaxOp>,
1934  GroupReducePattern<spirv::GroupUMaxOp>,
1935  GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1936  GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1937  /*NonUniform=*/true>,
1938  GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1939  /*NonUniform=*/true>,
1940  GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1941  /*NonUniform=*/true>,
1942  GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1943  /*NonUniform=*/true>,
1944  GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1945  /*NonUniform=*/true>,
1946  GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1947  /*NonUniform=*/true>,
1948  GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1949  /*NonUniform=*/true>,
1950  GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1951  /*NonUniform=*/true>,
1952  GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1953  /*NonUniform=*/true>,
1954  GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1955  /*NonUniform=*/true>,
1956  GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1957  /*NonUniform=*/true>,
1958  GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1959  /*NonUniform=*/true>,
1960  GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1961  /*NonUniform=*/true>,
1962  GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1963  /*NonUniform=*/true>,
1964  GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1965  /*NonUniform=*/true>,
1966  GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1967  /*NonUniform=*/true>>(patterns.getContext(),
1968  typeConverter);
1969 
1970  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1971  typeConverter);
1972 }
1973 
1975  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1976  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1977 }
1978 
1980  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1981  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1982 }
1983 
1984 //===----------------------------------------------------------------------===//
1985 // Pre-conversion hooks
1986 //===----------------------------------------------------------------------===//
1987 
1988 /// Hook for descriptor set and binding number encoding.
1989 static constexpr StringRef kBinding = "binding";
1990 static constexpr StringRef kDescriptorSet = "descriptor_set";
1991 void mlir::encodeBindAttribute(ModuleOp module) {
1992  auto spvModules = module.getOps<spirv::ModuleOp>();
1993  for (auto spvModule : spvModules) {
1994  spvModule.walk([&](spirv::GlobalVariableOp op) {
1995  IntegerAttr descriptorSet =
1996  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1997  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1998  // For every global variable in the module, get the ones with descriptor
1999  // set and binding numbers.
2000  if (descriptorSet && binding) {
2001  // Encode these numbers into the variable's symbolic name. If the
2002  // SPIR-V module has a name, add it at the beginning.
2003  auto moduleAndName =
2004  spvModule.getName().has_value()
2005  ? spvModule.getName()->str() + "_" + op.getSymName().str()
2006  : op.getSymName().str();
2007  std::string name =
2008  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2009  std::to_string(descriptorSet.getInt()),
2010  std::to_string(binding.getInt()));
2011  auto nameAttr = StringAttr::get(op->getContext(), name);
2012 
2013  // Replace all symbol uses and set the new symbol name. Finally, remove
2014  // descriptor set and binding attributes.
2015  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2016  op.emitError("unable to replace all symbol uses for ") << name;
2017  SymbolTable::setSymbolName(op, nameAttr);
2018  op->removeAttr(kDescriptorSet);
2019  op->removeAttr(kBinding);
2020  }
2021  });
2022  }
2023 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:64
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:35
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:44
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:84
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:94
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:77
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:54
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:200
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:448
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:442
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:172
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:174
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:170
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:457
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:519
SPIR-V struct type.
Definition: SPIRVTypes.h:251
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:809
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:227
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.