MLIR  22.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Returns true if the given type is a signed integer or vector type.
35 static bool isSignedIntegerOrVector(Type type) {
36  if (type.isSignedInteger())
37  return true;
38  if (auto vecType = dyn_cast<VectorType>(type))
39  return vecType.getElementType().isSignedInteger();
40  return false;
41 }
42 
43 /// Returns true if the given type is an unsigned integer or vector type
44 static bool isUnsignedIntegerOrVector(Type type) {
45  if (type.isUnsignedInteger())
46  return true;
47  if (auto vecType = dyn_cast<VectorType>(type))
48  return vecType.getElementType().isUnsignedInteger();
49  return false;
50 }
51 
52 /// Returns the width of an integer or of the element type of an integer vector,
53 /// if applicable.
54 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
55  if (auto intType = dyn_cast<IntegerType>(type))
56  return intType.getWidth();
57  if (auto vecType = dyn_cast<VectorType>(type))
58  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59  return intType.getWidth();
60  return std::nullopt;
61 }
62 
63 /// Returns the bit width of integer, float or vector of float or integer values
64 static unsigned getBitWidth(Type type) {
65  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
66  "bitwidth is not supported for this type");
67  if (type.isIntOrFloat())
68  return type.getIntOrFloatBitWidth();
69  auto vecType = dyn_cast<VectorType>(type);
70  auto elementType = vecType.getElementType();
71  assert(elementType.isIntOrFloat() &&
72  "only integers and floats have a bitwidth");
73  return elementType.getIntOrFloatBitWidth();
74 }
75 
76 /// Returns the bit width of LLVMType integer or vector.
77 static unsigned getLLVMTypeBitWidth(Type type) {
78  if (auto vecTy = dyn_cast<VectorType>(type))
79  type = vecTy.getElementType();
80  return cast<IntegerType>(type).getWidth();
81 }
82 
83 /// Creates `IntegerAttribute` with all bits set for given type
84 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
85  if (auto vecType = dyn_cast<VectorType>(type)) {
86  auto integerType = cast<IntegerType>(vecType.getElementType());
87  return builder.getIntegerAttr(integerType, -1);
88  }
89  auto integerType = cast<IntegerType>(type);
90  return builder.getIntegerAttr(integerType, -1);
91 }
92 
93 /// Creates `llvm.mlir.constant` with all bits set for the given type.
94 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
95  PatternRewriter &rewriter) {
96  if (isa<VectorType>(srcType)) {
97  return LLVM::ConstantOp::create(
98  rewriter, loc, dstType,
99  SplatElementsAttr::get(cast<ShapedType>(srcType),
100  minusOneIntegerAttribute(srcType, rewriter)));
101  }
102  return LLVM::ConstantOp::create(rewriter, loc, dstType,
103  minusOneIntegerAttribute(srcType, rewriter));
104 }
105 
106 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
107 static Value createFPConstant(Location loc, Type srcType, Type dstType,
108  PatternRewriter &rewriter, double value) {
109  if (auto vecType = dyn_cast<VectorType>(srcType)) {
110  auto floatType = cast<FloatType>(vecType.getElementType());
111  return LLVM::ConstantOp::create(
112  rewriter, loc, dstType,
113  SplatElementsAttr::get(vecType,
114  rewriter.getFloatAttr(floatType, value)));
115  }
116  auto floatType = cast<FloatType>(srcType);
117  return LLVM::ConstantOp::create(rewriter, loc, dstType,
118  rewriter.getFloatAttr(floatType, value));
119 }
120 
121 /// Utility function for bitfield ops:
122 /// - `BitFieldInsert`
123 /// - `BitFieldSExtract`
124 /// - `BitFieldUExtract`
125 /// Truncates or extends the value. If the bitwidth of the value is the same as
126 /// `llvmType` bitwidth, the value remains unchanged.
128  Type llvmType,
129  PatternRewriter &rewriter) {
130  auto srcType = value.getType();
131  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
132  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
133  ? getLLVMTypeBitWidth(srcType)
134  : getBitWidth(srcType);
135 
136  if (valueBitWidth < targetBitWidth)
137  return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
138  // If the bit widths of `Count` and `Offset` are greater than the bit width
139  // of the target type, they are truncated. Truncation is safe since `Count`
140  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
141  // both values can be expressed in 8 bits.
142  if (valueBitWidth > targetBitWidth)
143  return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
144  return value;
145 }
146 
147 /// Broadcasts the value to vector with `numElements` number of elements.
148 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
149  const TypeConverter &typeConverter,
150  ConversionPatternRewriter &rewriter) {
151  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
152  auto llvmVectorType = typeConverter.convertType(vectorType);
153  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
154  Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
155  for (unsigned i = 0; i < numElements; ++i) {
156  auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
157  rewriter.getI32IntegerAttr(i));
158  broadcasted = LLVM::InsertElementOp::create(
159  rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index);
160  }
161  return broadcasted;
162 }
163 
164 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
165 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
166  const TypeConverter &typeConverter,
167  ConversionPatternRewriter &rewriter) {
168  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
169  unsigned numElements = vectorType.getNumElements();
170  return broadcast(loc, value, numElements, typeConverter, rewriter);
171  }
172  return value;
173 }
174 
175 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
176 /// `BitFieldUExtract`.
177 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
178 /// a vector type, construct a vector that has:
179 /// - same number of elements as `Base`
180 /// - each element has the type that is the same as the type of `Offset` or
181 /// `Count`
182 /// - each element has the same value as `Offset` or `Count`
183 /// Then cast `Offset` and `Count` if their bit width is different
184 /// from `Base` bit width.
185 static Value processCountOrOffset(Location loc, Value value, Type srcType,
186  Type dstType, const TypeConverter &converter,
187  ConversionPatternRewriter &rewriter) {
188  Value broadcasted =
189  optionallyBroadcast(loc, value, srcType, converter, rewriter);
190  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
191 }
192 
193 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
194 /// offset to LLVM struct. Otherwise, the conversion is not supported.
196  const TypeConverter &converter) {
197  if (type != VulkanLayoutUtils::decorateType(type))
198  return nullptr;
199 
200  SmallVector<Type> elementsVector;
201  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
202  return nullptr;
203  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
204  /*isPacked=*/false);
205 }
206 
207 /// Converts SPIR-V struct with no offset to packed LLVM struct.
209  const TypeConverter &converter) {
210  SmallVector<Type> elementsVector;
211  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
212  return nullptr;
213  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
214  /*isPacked=*/true);
215 }
216 
217 /// Creates LLVM dialect constant with the given value.
219  unsigned value) {
220  return LLVM::ConstantOp::create(
221  rewriter, loc, IntegerType::get(rewriter.getContext(), 32),
222  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
223 }
224 
225 /// Utility for `spirv.Load` and `spirv.Store` conversion.
226 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
227  ConversionPatternRewriter &rewriter,
228  const TypeConverter &typeConverter,
229  unsigned alignment, bool isVolatile,
230  bool isNonTemporal) {
231  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232  auto dstType = typeConverter.convertType(loadOp.getType());
233  if (!dstType)
234  return rewriter.notifyMatchFailure(op, "type conversion failed");
235  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237  isVolatile, isNonTemporal);
238  return success();
239  }
240  auto storeOp = cast<spirv::StoreOp>(op);
241  spirv::StoreOpAdaptor adaptor(operands);
242  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
243  adaptor.getPtr(), alignment,
244  isVolatile, isNonTemporal);
245  return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // Type conversion
250 //===----------------------------------------------------------------------===//
251 
252 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
253 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
254 /// when converting ops that manipulate array types.
255 static std::optional<Type> convertArrayType(spirv::ArrayType type,
256  TypeConverter &converter) {
257  unsigned stride = type.getArrayStride();
258  Type elementType = type.getElementType();
259  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
261  return std::nullopt;
262 
263  auto llvmElementType = converter.convertType(elementType);
264  unsigned numElements = type.getNumElements();
265  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
266 }
267 
268 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
269 /// modelled at the moment.
271  const TypeConverter &converter,
272  spirv::ClientAPI clientAPI) {
273  unsigned addressSpace =
274  storageClassToAddressSpace(clientAPI, type.getStorageClass());
275  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
276 }
277 
278 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
279 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
280 /// no modelling of array stride at the moment.
281 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
282  TypeConverter &converter) {
283  if (type.getArrayStride() != 0)
284  return std::nullopt;
285  auto elementType = converter.convertType(type.getElementType());
286  return LLVM::LLVMArrayType::get(elementType, 0);
287 }
288 
289 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
290 /// member decorations. Also, only natural offset is supported.
292  const TypeConverter &converter) {
294  type.getMemberDecorations(memberDecorations);
295  if (!memberDecorations.empty())
296  return nullptr;
297  if (type.hasOffset())
298  return convertStructTypeWithOffset(type, converter);
299  return convertStructTypePacked(type, converter);
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // Operation conversion
304 //===----------------------------------------------------------------------===//
305 
306 namespace {
307 
308 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
309 public:
311 
312  LogicalResult
313  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314  ConversionPatternRewriter &rewriter) const override {
315  auto dstType =
316  getTypeConverter()->convertType(op.getComponentPtr().getType());
317  if (!dstType)
318  return rewriter.notifyMatchFailure(op, "type conversion failed");
319  // To use GEP we need to add a first 0 index to go through the pointer.
320  auto indices = llvm::to_vector<4>(adaptor.getIndices());
321  Type indexType = op.getIndices().front().getType();
322  auto llvmIndexType = getTypeConverter()->convertType(indexType);
323  if (!llvmIndexType)
324  return rewriter.notifyMatchFailure(op, "type conversion failed");
325  Value zero =
326  LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
327  rewriter.getIntegerAttr(indexType, 0));
328  indices.insert(indices.begin(), zero);
329 
330  auto elementType = getTypeConverter()->convertType(
331  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
332  if (!elementType)
333  return rewriter.notifyMatchFailure(op, "type conversion failed");
334  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
335  adaptor.getBasePtr(), indices);
336  return success();
337  }
338 };
339 
340 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
341 public:
343 
344  LogicalResult
345  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
346  ConversionPatternRewriter &rewriter) const override {
347  auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
348  if (!dstType)
349  return rewriter.notifyMatchFailure(op, "type conversion failed");
350  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
351  op.getVariable());
352  return success();
353  }
354 };
355 
356 class BitFieldInsertPattern
357  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
358 public:
360 
361  LogicalResult
362  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
363  ConversionPatternRewriter &rewriter) const override {
364  auto srcType = op.getType();
365  auto dstType = getTypeConverter()->convertType(srcType);
366  if (!dstType)
367  return rewriter.notifyMatchFailure(op, "type conversion failed");
368  Location loc = op.getLoc();
369 
370  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
371  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
372  *getTypeConverter(), rewriter);
373  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
374  *getTypeConverter(), rewriter);
375 
376  // Create a mask with bits set outside [Offset, Offset + Count - 1].
377  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
378  Value maskShiftedByCount =
379  LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
380  Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
381  maskShiftedByCount, minusOne);
382  Value maskShiftedByCountAndOffset =
383  LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
384  Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
385  maskShiftedByCountAndOffset, minusOne);
386 
387  // Extract unchanged bits from the `Base` that are outside of
388  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
389  Value baseAndMask =
390  LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
391  Value insertShiftedByOffset =
392  LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
393  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
394  insertShiftedByOffset);
395  return success();
396  }
397 };
398 
399 /// Converts SPIR-V ConstantOp with scalar or vector type.
400 class ConstantScalarAndVectorPattern
401  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
402 public:
404 
405  LogicalResult
406  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
407  ConversionPatternRewriter &rewriter) const override {
408  auto srcType = constOp.getType();
409  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
410  return failure();
411 
412  auto dstType = getTypeConverter()->convertType(srcType);
413  if (!dstType)
414  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
415 
416  // SPIR-V constant can be a signed/unsigned integer, which has to be
417  // casted to signless integer when converting to LLVM dialect. Removing the
418  // sign bit may have unexpected behaviour. However, it is better to handle
419  // it case-by-case, given that the purpose of the conversion is not to
420  // cover all possible corner cases.
421  if (isSignedIntegerOrVector(srcType) ||
422  isUnsignedIntegerOrVector(srcType)) {
423  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
424 
425  if (isa<VectorType>(srcType)) {
426  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
427  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
428  constOp, dstType,
429  dstElementsAttr.mapValues(
430  signlessType, [&](const APInt &value) { return value; }));
431  return success();
432  }
433  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
434  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
435  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
436  return success();
437  }
438  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
439  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
440  return success();
441  }
442 };
443 
444 class BitFieldSExtractPattern
445  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
446 public:
448 
449  LogicalResult
450  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
451  ConversionPatternRewriter &rewriter) const override {
452  auto srcType = op.getType();
453  auto dstType = getTypeConverter()->convertType(srcType);
454  if (!dstType)
455  return rewriter.notifyMatchFailure(op, "type conversion failed");
456  Location loc = op.getLoc();
457 
458  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
459  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
460  *getTypeConverter(), rewriter);
461  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
462  *getTypeConverter(), rewriter);
463 
464  // Create a constant that holds the size of the `Base`.
465  IntegerType integerType;
466  if (auto vecType = dyn_cast<VectorType>(srcType))
467  integerType = cast<IntegerType>(vecType.getElementType());
468  else
469  integerType = cast<IntegerType>(srcType);
470 
471  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
472  Value size =
473  isa<VectorType>(srcType)
474  ? LLVM::ConstantOp::create(
475  rewriter, loc, dstType,
476  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
477  : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
478 
479  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
480  // at Offset + Count - 1 is the most significant bit now.
481  Value countPlusOffset =
482  LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
483  Value amountToShiftLeft =
484  LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
485  Value baseShiftedLeft = LLVM::ShlOp::create(
486  rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
487 
488  // Shift the result right, filling the bits with the sign bit.
489  Value amountToShiftRight =
490  LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
491  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
492  amountToShiftRight);
493  return success();
494  }
495 };
496 
497 class BitFieldUExtractPattern
498  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
499 public:
501 
502  LogicalResult
503  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
504  ConversionPatternRewriter &rewriter) const override {
505  auto srcType = op.getType();
506  auto dstType = getTypeConverter()->convertType(srcType);
507  if (!dstType)
508  return rewriter.notifyMatchFailure(op, "type conversion failed");
509  Location loc = op.getLoc();
510 
511  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
512  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
513  *getTypeConverter(), rewriter);
514  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
515  *getTypeConverter(), rewriter);
516 
517  // Create a mask with bits set at [0, Count - 1].
518  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
519  Value maskShiftedByCount =
520  LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
521  Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
522  minusOne);
523 
524  // Shift `Base` by `Offset` and apply the mask on it.
525  Value shiftedBase =
526  LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
527  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
528  return success();
529  }
530 };
531 
532 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
533 public:
535 
536  LogicalResult
537  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
538  ConversionPatternRewriter &rewriter) const override {
539  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
540  branchOp.getTarget());
541  return success();
542  }
543 };
544 
545 class BranchConditionalConversionPattern
546  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
547 public:
548  using SPIRVToLLVMConversion<
549  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
550 
551  LogicalResult
552  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
553  ConversionPatternRewriter &rewriter) const override {
554  // If branch weights exist, map them to 32-bit integer vector.
555  DenseI32ArrayAttr branchWeights = nullptr;
556  if (auto weights = op.getBranchWeights()) {
557  SmallVector<int32_t> weightValues;
558  for (auto weight : weights->getAsRange<IntegerAttr>())
559  weightValues.push_back(weight.getInt());
560  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
561  }
562 
563  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
564  op, op.getCondition(), op.getTrueBlockArguments(),
565  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
566  op.getFalseBlock());
567  return success();
568  }
569 };
570 
571 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
572 /// type is an aggregate type (struct or array). Otherwise, converts to
573 /// `llvm.extractelement` that operates on vectors.
574 class CompositeExtractPattern
575  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
576 public:
578 
579  LogicalResult
580  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
581  ConversionPatternRewriter &rewriter) const override {
582  auto dstType = this->getTypeConverter()->convertType(op.getType());
583  if (!dstType)
584  return rewriter.notifyMatchFailure(op, "type conversion failed");
585 
586  Type containerType = op.getComposite().getType();
587  if (isa<VectorType>(containerType)) {
588  Location loc = op.getLoc();
589  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
590  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
591  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
592  op, dstType, adaptor.getComposite(), index);
593  return success();
594  }
595 
596  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
597  op, adaptor.getComposite(),
598  LLVM::convertArrayToIndices(op.getIndices()));
599  return success();
600  }
601 };
602 
603 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
604 /// type is an aggregate type (struct or array). Otherwise, converts to
605 /// `llvm.insertelement` that operates on vectors.
606 class CompositeInsertPattern
607  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
608 public:
610 
611  LogicalResult
612  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
613  ConversionPatternRewriter &rewriter) const override {
614  auto dstType = this->getTypeConverter()->convertType(op.getType());
615  if (!dstType)
616  return rewriter.notifyMatchFailure(op, "type conversion failed");
617 
618  Type containerType = op.getComposite().getType();
619  if (isa<VectorType>(containerType)) {
620  Location loc = op.getLoc();
621  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
622  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
623  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
624  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
625  return success();
626  }
627 
628  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
629  op, adaptor.getComposite(), adaptor.getObject(),
630  LLVM::convertArrayToIndices(op.getIndices()));
631  return success();
632  }
633 };
634 
635 /// Converts SPIR-V operations that have straightforward LLVM equivalent
636 /// into LLVM dialect operations.
637 template <typename SPIRVOp, typename LLVMOp>
638 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
639 public:
641 
642  LogicalResult
643  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
644  ConversionPatternRewriter &rewriter) const override {
645  auto dstType = this->getTypeConverter()->convertType(op.getType());
646  if (!dstType)
647  return rewriter.notifyMatchFailure(op, "type conversion failed");
648  rewriter.template replaceOpWithNewOp<LLVMOp>(
649  op, dstType, adaptor.getOperands(), op->getAttrs());
650  return success();
651  }
652 };
653 
654 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
655 /// execution mode information.
656 class ExecutionModePattern
657  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
658 public:
660 
661  LogicalResult
662  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
663  ConversionPatternRewriter &rewriter) const override {
664  // First, create the global struct's name that would be associated with
665  // this entry point's execution mode. We set it to be:
666  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
667  ModuleOp module = op->getParentOfType<ModuleOp>();
668  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
669  std::string moduleName;
670  if (module.getName().has_value())
671  moduleName = "_" + module.getName()->str();
672  else
673  moduleName = "";
674  std::string executionModeInfoName = llvm::formatv(
675  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
676  static_cast<uint32_t>(executionModeAttr.getValue()));
677 
678  MLIRContext *context = rewriter.getContext();
679  OpBuilder::InsertionGuard guard(rewriter);
680  rewriter.setInsertionPointToStart(module.getBody());
681 
682  // Create a struct type, corresponding to the C struct below.
683  // struct {
684  // int32_t executionMode;
685  // int32_t values[]; // optional values
686  // };
687  auto llvmI32Type = IntegerType::get(context, 32);
688  SmallVector<Type, 2> fields;
689  fields.push_back(llvmI32Type);
690  ArrayAttr values = op.getValues();
691  if (!values.empty()) {
692  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
693  fields.push_back(arrayType);
694  }
695  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
696 
697  // Create `llvm.mlir.global` with initializer region containing one block.
698  auto global = LLVM::GlobalOp::create(
699  rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true,
700  LLVM::Linkage::External, executionModeInfoName, Attribute(),
701  /*alignment=*/0);
702  Location loc = global.getLoc();
703  Region &region = global.getInitializerRegion();
704  Block *block = rewriter.createBlock(&region);
705 
706  // Initialize the struct and set the execution mode value.
707  rewriter.setInsertionPointToStart(block);
708  Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
709  Value executionMode = LLVM::ConstantOp::create(
710  rewriter, loc, llvmI32Type,
711  rewriter.getI32IntegerAttr(
712  static_cast<uint32_t>(executionModeAttr.getValue())));
713  SmallVector<int64_t> position{0};
714  structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
715  executionMode, position);
716 
717  // Insert extra operands if they exist into execution mode info struct.
718  for (unsigned i = 0, e = values.size(); i < e; ++i) {
719  auto attr = values.getValue()[i];
720  Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
721  structValue = LLVM::InsertValueOp::create(
722  rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i}));
723  }
724  LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue}));
725  rewriter.eraseOp(op);
726  return success();
727  }
728 };
729 
730 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
731 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
732 /// value. This difference is handled by `spirv.mlir.addressof` and
733 /// `llvm.mlir.addressof`ops that both return a pointer.
734 class GlobalVariablePattern
735  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
736 public:
737  template <typename... Args>
738  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
739  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
740  std::forward<Args>(args)...),
741  clientAPI(clientAPI) {}
742 
743  LogicalResult
744  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
745  ConversionPatternRewriter &rewriter) const override {
746  // Currently, there is no support of initialization with a constant value in
747  // SPIR-V dialect. Specialization constants are not considered as well.
748  if (op.getInitializer())
749  return failure();
750 
751  auto srcType = cast<spirv::PointerType>(op.getType());
752  auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
753  if (!dstType)
754  return rewriter.notifyMatchFailure(op, "type conversion failed");
755 
756  // Limit conversion to the current invocation only or `StorageBuffer`
757  // required by SPIR-V runner.
758  // This is okay because multiple invocations are not supported yet.
759  auto storageClass = srcType.getStorageClass();
760  switch (storageClass) {
761  case spirv::StorageClass::Input:
762  case spirv::StorageClass::Private:
763  case spirv::StorageClass::Output:
764  case spirv::StorageClass::StorageBuffer:
765  case spirv::StorageClass::UniformConstant:
766  break;
767  default:
768  return failure();
769  }
770 
771  // LLVM dialect spec: "If the global value is a constant, storing into it is
772  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
773  // storage class that is read-only.
774  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
775  (storageClass == spirv::StorageClass::UniformConstant);
776  // SPIR-V spec: "By default, functions and global variables are private to a
777  // module and cannot be accessed by other modules. However, a module may be
778  // written to export or import functions and global (module scope)
779  // variables.". Therefore, map 'Private' storage class to private linkage,
780  // 'Input' and 'Output' to external linkage.
781  auto linkage = storageClass == spirv::StorageClass::Private
782  ? LLVM::Linkage::Private
783  : LLVM::Linkage::External;
784  StringAttr locationAttrName = op.getLocationAttrName();
785  IntegerAttr locationAttr = op.getLocationAttr();
786  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788  /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
789 
790  // Attach location attribute if applicable
791  if (locationAttr)
792  newGlobalOp->setAttr(locationAttrName, locationAttr);
793 
794  return success();
795  }
796 
797 private:
798  spirv::ClientAPI clientAPI;
799 };
800 
801 /// Converts SPIR-V cast ops that do not have straightforward LLVM
802 /// equivalent in LLVM dialect.
803 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
805 public:
807 
808  LogicalResult
809  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810  ConversionPatternRewriter &rewriter) const override {
811 
812  Type fromType = op.getOperand().getType();
813  Type toType = op.getType();
814 
815  auto dstType = this->getTypeConverter()->convertType(toType);
816  if (!dstType)
817  return rewriter.notifyMatchFailure(op, "type conversion failed");
818 
819  if (getBitWidth(fromType) < getBitWidth(toType)) {
820  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821  adaptor.getOperands());
822  return success();
823  }
824  if (getBitWidth(fromType) > getBitWidth(toType)) {
825  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826  adaptor.getOperands());
827  return success();
828  }
829  return failure();
830  }
831 };
832 
833 class FunctionCallPattern
834  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
835 public:
837 
838  LogicalResult
839  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840  ConversionPatternRewriter &rewriter) const override {
841  if (callOp.getNumResults() == 0) {
842  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843  callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
844  newOp.getProperties().operandSegmentSizes = {
845  static_cast<int32_t>(adaptor.getOperands().size()), 0};
846  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
847  return success();
848  }
849 
850  // Function returns a single result.
851  auto dstType = getTypeConverter()->convertType(callOp.getType(0));
852  if (!dstType)
853  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856  newOp.getProperties().operandSegmentSizes = {
857  static_cast<int32_t>(adaptor.getOperands().size()), 0};
858  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
859  return success();
860  }
861 };
862 
863 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
866 public:
868 
869  LogicalResult
870  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871  ConversionPatternRewriter &rewriter) const override {
872 
873  auto dstType = this->getTypeConverter()->convertType(op.getType());
874  if (!dstType)
875  return rewriter.notifyMatchFailure(op, "type conversion failed");
876 
877  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878  op, dstType, predicate, op.getOperand1(), op.getOperand2());
879  return success();
880  }
881 };
882 
883 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
886 public:
888 
889  LogicalResult
890  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891  ConversionPatternRewriter &rewriter) const override {
892 
893  auto dstType = this->getTypeConverter()->convertType(op.getType());
894  if (!dstType)
895  return rewriter.notifyMatchFailure(op, "type conversion failed");
896 
897  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898  op, dstType, predicate, op.getOperand1(), op.getOperand2());
899  return success();
900  }
901 };
902 
903 class InverseSqrtPattern
904  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
905 public:
907 
908  LogicalResult
909  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910  ConversionPatternRewriter &rewriter) const override {
911  auto srcType = op.getType();
912  auto dstType = getTypeConverter()->convertType(srcType);
913  if (!dstType)
914  return rewriter.notifyMatchFailure(op, "type conversion failed");
915 
916  Location loc = op.getLoc();
917  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
918  Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
919  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
920  return success();
921  }
922 };
923 
924 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925 template <typename SPIRVOp>
926 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
927 public:
929 
930  LogicalResult
931  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932  ConversionPatternRewriter &rewriter) const override {
933  if (!op.getMemoryAccess()) {
934  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
935  *this->getTypeConverter(), /*alignment=*/0,
936  /*isVolatile=*/false,
937  /*isNonTemporal=*/false);
938  }
939  auto memoryAccess = *op.getMemoryAccess();
940  switch (memoryAccess) {
941  case spirv::MemoryAccess::Aligned:
943  case spirv::MemoryAccess::Nontemporal:
944  case spirv::MemoryAccess::Volatile: {
945  unsigned alignment =
946  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
950  *this->getTypeConverter(), alignment,
951  isVolatile, isNonTemporal);
952  }
953  default:
954  // There is no support of other memory access attributes.
955  return failure();
956  }
957  }
958 };
959 
960 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961 template <typename SPIRVOp>
962 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
963 public:
965 
966  LogicalResult
967  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968  ConversionPatternRewriter &rewriter) const override {
969  auto srcType = notOp.getType();
970  auto dstType = this->getTypeConverter()->convertType(srcType);
971  if (!dstType)
972  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
973 
974  Location loc = notOp.getLoc();
975  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
976  auto mask =
977  isa<VectorType>(srcType)
978  ? LLVM::ConstantOp::create(
979  rewriter, loc, dstType,
980  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981  : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
982  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983  notOp.getOperand(), mask);
984  return success();
985  }
986 };
987 
988 /// A template pattern that erases the given `SPIRVOp`.
989 template <typename SPIRVOp>
990 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
991 public:
993 
994  LogicalResult
995  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
996  ConversionPatternRewriter &rewriter) const override {
997  rewriter.eraseOp(op);
998  return success();
999  }
1000 };
1001 
1002 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1003 public:
1005 
1006  LogicalResult
1007  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008  ConversionPatternRewriter &rewriter) const override {
1009  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1010  ArrayRef<Value>());
1011  return success();
1012  }
1013 };
1014 
1015 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1016 public:
1018 
1019  LogicalResult
1020  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021  ConversionPatternRewriter &rewriter) const override {
1022  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023  adaptor.getOperands());
1024  return success();
1025  }
1026 };
1027 
1028 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1029  StringRef name,
1030  ArrayRef<Type> paramTypes,
1031  Type resultType,
1032  bool convergent = true) {
1033  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1034  SymbolTable::lookupSymbolIn(symbolTable, name));
1035  if (func)
1036  return func;
1037 
1038  OpBuilder b(symbolTable->getRegion(0));
1039  func = LLVM::LLVMFuncOp::create(
1040  b, symbolTable->getLoc(), name,
1041  LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043  func.setConvergent(convergent);
1044  func.setNoUnwind(true);
1045  func.setWillReturn(true);
1046  return func;
1047 }
1048 
1049 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1050  LLVM::LLVMFuncOp func,
1051  ValueRange args) {
1052  auto call = LLVM::CallOp::create(builder, loc, func, args);
1053  call.setCConv(func.getCConv());
1054  call.setConvergentAttr(func.getConvergentAttr());
1055  call.setNoUnwindAttr(func.getNoUnwindAttr());
1056  call.setWillReturnAttr(func.getWillReturnAttr());
1057  return call;
1058 }
1059 
1060 template <typename BarrierOpTy>
1061 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1062 public:
1063  using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1064 
1066 
1067  static constexpr StringRef getFuncName();
1068 
1069  LogicalResult
1070  matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071  ConversionPatternRewriter &rewriter) const override {
1072  constexpr StringRef funcName = getFuncName();
1073  Operation *symbolTable =
1074  controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1075 
1076  Type i32 = rewriter.getI32Type();
1077 
1078  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079  LLVM::LLVMFuncOp func =
1080  lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1081 
1082  Location loc = controlBarrierOp->getLoc();
1083  Value execution = LLVM::ConstantOp::create(
1084  rewriter, loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1085  Value memory = LLVM::ConstantOp::create(
1086  rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1087  Value semantics = LLVM::ConstantOp::create(
1088  rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1089 
1090  auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1091  {execution, memory, semantics});
1092 
1093  rewriter.replaceOp(controlBarrierOp, call);
1094  return success();
1095  }
1096 };
1097 
1098 namespace {
1099 
1100 StringRef getTypeMangling(Type type, bool isSigned) {
1102  .Case<Float16Type>([](auto) { return "Dh"; })
1103  .Case<Float32Type>([](auto) { return "f"; })
1104  .Case<Float64Type>([](auto) { return "d"; })
1105  .Case<IntegerType>([isSigned](IntegerType intTy) {
1106  switch (intTy.getWidth()) {
1107  case 1:
1108  return "b";
1109  case 8:
1110  return (isSigned) ? "a" : "c";
1111  case 16:
1112  return (isSigned) ? "s" : "t";
1113  case 32:
1114  return (isSigned) ? "i" : "j";
1115  case 64:
1116  return (isSigned) ? "l" : "m";
1117  default:
1118  llvm_unreachable("Unsupported integer width");
1119  }
1120  })
1121  .Default([](auto) {
1122  llvm_unreachable("No mangling defined");
1123  return "";
1124  });
1125 }
1126 
1127 template <typename ReduceOp>
1128 constexpr StringLiteral getGroupFuncName();
1129 
1130 template <>
1131 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1132  return "_Z17__spirv_GroupIAddii";
1133 }
1134 template <>
1135 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1136  return "_Z17__spirv_GroupFAddii";
1137 }
1138 template <>
1139 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1140  return "_Z17__spirv_GroupSMinii";
1141 }
1142 template <>
1143 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1144  return "_Z17__spirv_GroupUMinii";
1145 }
1146 template <>
1147 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1148  return "_Z17__spirv_GroupFMinii";
1149 }
1150 template <>
1151 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1152  return "_Z17__spirv_GroupSMaxii";
1153 }
1154 template <>
1155 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1156  return "_Z17__spirv_GroupUMaxii";
1157 }
1158 template <>
1159 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1160  return "_Z17__spirv_GroupFMaxii";
1161 }
1162 template <>
1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1164  return "_Z27__spirv_GroupNonUniformIAddii";
1165 }
1166 template <>
1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1168  return "_Z27__spirv_GroupNonUniformFAddii";
1169 }
1170 template <>
1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1172  return "_Z27__spirv_GroupNonUniformIMulii";
1173 }
1174 template <>
1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1176  return "_Z27__spirv_GroupNonUniformFMulii";
1177 }
1178 template <>
1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1180  return "_Z27__spirv_GroupNonUniformSMinii";
1181 }
1182 template <>
1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1184  return "_Z27__spirv_GroupNonUniformUMinii";
1185 }
1186 template <>
1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1188  return "_Z27__spirv_GroupNonUniformFMinii";
1189 }
1190 template <>
1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1192  return "_Z27__spirv_GroupNonUniformSMaxii";
1193 }
1194 template <>
1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1196  return "_Z27__spirv_GroupNonUniformUMaxii";
1197 }
1198 template <>
1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1200  return "_Z27__spirv_GroupNonUniformFMaxii";
1201 }
1202 template <>
1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1204  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1205 }
1206 template <>
1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1208  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1209 }
1210 template <>
1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1212  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1213 }
1214 template <>
1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1216  return "_Z33__spirv_GroupNonUniformLogicalAndii";
1217 }
1218 template <>
1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1220  return "_Z32__spirv_GroupNonUniformLogicalOrii";
1221 }
1222 template <>
1223 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1224  return "_Z33__spirv_GroupNonUniformLogicalXorii";
1225 }
1226 } // namespace
1227 
1228 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1229 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1230 public:
1232 
1233  LogicalResult
1234  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1235  ConversionPatternRewriter &rewriter) const override {
1236 
1237  Type retTy = op.getResult().getType();
1238  if (!retTy.isIntOrFloat()) {
1239  return failure();
1240  }
1241  SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1242  funcName += getTypeMangling(retTy, false);
1243 
1244  Type i32Ty = rewriter.getI32Type();
1245  SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1246  if constexpr (NonUniform) {
1247  if (adaptor.getClusterSize()) {
1248  funcName += "j";
1249  paramTypes.push_back(i32Ty);
1250  }
1251  }
1252 
1253  Operation *symbolTable =
1254  op->template getParentWithTrait<OpTrait::SymbolTable>();
1255 
1256  LLVM::LLVMFuncOp func =
1257  lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1258 
1259  Location loc = op.getLoc();
1260  Value scope = LLVM::ConstantOp::create(
1261  rewriter, loc, i32Ty,
1262  static_cast<int32_t>(adaptor.getExecutionScope()));
1263  Value groupOp = LLVM::ConstantOp::create(
1264  rewriter, loc, i32Ty,
1265  static_cast<int32_t>(adaptor.getGroupOperation()));
1266  SmallVector<Value> operands{scope, groupOp};
1267  operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1268 
1269  auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1270  rewriter.replaceOp(op, call);
1271  return success();
1272  }
1273 };
1274 
1275 template <>
1276 constexpr StringRef
1277 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1278  return "_Z22__spirv_ControlBarrieriii";
1279 }
1280 
1281 template <>
1282 constexpr StringRef
1283 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1284  return "_Z33__spirv_ControlBarrierArriveINTELiii";
1285 }
1286 
1287 template <>
1288 constexpr StringRef
1289 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1290  return "_Z31__spirv_ControlBarrierWaitINTELiii";
1291 }
1292 
1293 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1294 /// should be reachable for conversion to succeed. The structure of the loop in
1295 /// LLVM dialect will be the following:
1296 ///
1297 /// +------------------------------------+
1298 /// | <code before spirv.mlir.loop> |
1299 /// | llvm.br ^header |
1300 /// +------------------------------------+
1301 /// |
1302 /// +----------------+ |
1303 /// | | |
1304 /// | V V
1305 /// | +------------------------------------+
1306 /// | | ^header: |
1307 /// | | <header code> |
1308 /// | | llvm.cond_br %cond, ^body, ^exit |
1309 /// | +------------------------------------+
1310 /// | |
1311 /// | |----------------------+
1312 /// | | |
1313 /// | V |
1314 /// | +------------------------------------+ |
1315 /// | | ^body: | |
1316 /// | | <body code> | |
1317 /// | | llvm.br ^continue | |
1318 /// | +------------------------------------+ |
1319 /// | | |
1320 /// | V |
1321 /// | +------------------------------------+ |
1322 /// | | ^continue: | |
1323 /// | | <continue code> | |
1324 /// | | llvm.br ^header | |
1325 /// | +------------------------------------+ |
1326 /// | | |
1327 /// +---------------+ +----------------------+
1328 /// |
1329 /// V
1330 /// +------------------------------------+
1331 /// | ^exit: |
1332 /// | llvm.br ^remaining |
1333 /// +------------------------------------+
1334 /// |
1335 /// V
1336 /// +------------------------------------+
1337 /// | ^remaining: |
1338 /// | <code after spirv.mlir.loop> |
1339 /// +------------------------------------+
1340 ///
1341 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1342 public:
1344 
1345  LogicalResult
1346  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1347  ConversionPatternRewriter &rewriter) const override {
1348  // There is no support of loop control at the moment.
1349  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1350  return failure();
1351 
1352  // `spirv.mlir.loop` with empty region is redundant and should be erased.
1353  if (loopOp.getBody().empty()) {
1354  rewriter.eraseOp(loopOp);
1355  return success();
1356  }
1357 
1358  Location loc = loopOp.getLoc();
1359 
1360  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1361  // be used in `endBlock`.
1362  Block *currentBlock = rewriter.getBlock();
1363  auto position = Block::iterator(loopOp);
1364  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1365 
1366  // Remove entry block and create a branch in the current block going to the
1367  // header block.
1368  Block *entryBlock = loopOp.getEntryBlock();
1369  assert(entryBlock->getOperations().size() == 1);
1370  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1371  if (!brOp)
1372  return failure();
1373  Block *headerBlock = loopOp.getHeaderBlock();
1374  rewriter.setInsertionPointToEnd(currentBlock);
1375  LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
1376  rewriter.eraseBlock(entryBlock);
1377 
1378  // Branch from merge block to end block.
1379  Block *mergeBlock = loopOp.getMergeBlock();
1380  Operation *terminator = mergeBlock->getTerminator();
1381  ValueRange terminatorOperands = terminator->getOperands();
1382  rewriter.setInsertionPointToEnd(mergeBlock);
1383  LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
1384 
1385  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1386  rewriter.replaceOp(loopOp, endBlock->getArguments());
1387  return success();
1388  }
1389 };
1390 
1391 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1392 /// block. All blocks within selection should be reachable for conversion to
1393 /// succeed.
1394 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1395 public:
1397 
1398  LogicalResult
1399  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1400  ConversionPatternRewriter &rewriter) const override {
1401  // There is no support for `Flatten` or `DontFlatten` selection control at
1402  // the moment. This are just compiler hints and can be performed during the
1403  // optimization passes.
1404  if (op.getSelectionControl() != spirv::SelectionControl::None)
1405  return failure();
1406 
1407  // `spirv.mlir.selection` should have at least two blocks: one selection
1408  // header block and one merge block. If no blocks are present, or control
1409  // flow branches straight to merge block (two blocks are present), the op is
1410  // redundant and it is erased.
1411  if (op.getBody().getBlocks().size() <= 2) {
1412  rewriter.eraseOp(op);
1413  return success();
1414  }
1415 
1416  Location loc = op.getLoc();
1417 
1418  // Split the current block after `spirv.mlir.selection`. The remaining ops
1419  // will be used in `continueBlock`.
1420  auto *currentBlock = rewriter.getInsertionBlock();
1421  rewriter.setInsertionPointAfter(op);
1422  auto position = rewriter.getInsertionPoint();
1423  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1424 
1425  // Extract conditional branch information from the header block. By SPIR-V
1426  // dialect spec, it should contain `spirv.BranchConditional` or
1427  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1428  // moment in the SPIR-V dialect. Remove this block when finished.
1429  auto *headerBlock = op.getHeaderBlock();
1430  assert(headerBlock->getOperations().size() == 1);
1431  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1432  headerBlock->getOperations().front());
1433  if (!condBrOp)
1434  return failure();
1435 
1436  // Branch from merge block to continue block.
1437  auto *mergeBlock = op.getMergeBlock();
1438  Operation *terminator = mergeBlock->getTerminator();
1439  ValueRange terminatorOperands = terminator->getOperands();
1440  rewriter.setInsertionPointToEnd(mergeBlock);
1441  LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
1442 
1443  // Link current block to `true` and `false` blocks within the selection.
1444  Block *trueBlock = condBrOp.getTrueBlock();
1445  Block *falseBlock = condBrOp.getFalseBlock();
1446  rewriter.setInsertionPointToEnd(currentBlock);
1447  LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
1448  condBrOp.getTrueTargetOperands(), falseBlock,
1449  condBrOp.getFalseTargetOperands());
1450 
1451  rewriter.eraseBlock(headerBlock);
1452  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1453  rewriter.replaceOp(op, continueBlock->getArguments());
1454  return success();
1455  }
1456 };
1457 
1458 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1459 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1460 /// `Shift` is zero or sign extended to match this specification. Cases when
1461 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1462 template <typename SPIRVOp, typename LLVMOp>
1463 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1464 public:
1466 
1467  LogicalResult
1468  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1469  ConversionPatternRewriter &rewriter) const override {
1470 
1471  auto dstType = this->getTypeConverter()->convertType(op.getType());
1472  if (!dstType)
1473  return rewriter.notifyMatchFailure(op, "type conversion failed");
1474 
1475  Type op1Type = op.getOperand1().getType();
1476  Type op2Type = op.getOperand2().getType();
1477 
1478  if (op1Type == op2Type) {
1479  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1480  adaptor.getOperands());
1481  return success();
1482  }
1483 
1484  std::optional<uint64_t> dstTypeWidth =
1486  std::optional<uint64_t> op2TypeWidth =
1488 
1489  if (!dstTypeWidth || !op2TypeWidth)
1490  return failure();
1491 
1492  Location loc = op.getLoc();
1493  Value extended;
1494  if (op2TypeWidth < dstTypeWidth) {
1495  if (isUnsignedIntegerOrVector(op2Type)) {
1496  extended =
1497  LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1498  } else {
1499  extended =
1500  LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1501  }
1502  } else if (op2TypeWidth == dstTypeWidth) {
1503  extended = adaptor.getOperand2();
1504  } else {
1505  return failure();
1506  }
1507 
1508  Value result =
1509  LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
1510  rewriter.replaceOp(op, result);
1511  return success();
1512  }
1513 };
1514 
1515 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1516 public:
1518 
1519  LogicalResult
1520  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1521  ConversionPatternRewriter &rewriter) const override {
1522  auto dstType = getTypeConverter()->convertType(tanOp.getType());
1523  if (!dstType)
1524  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1525 
1526  Location loc = tanOp.getLoc();
1527  Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
1528  Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
1529  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1530  return success();
1531  }
1532 };
1533 
1534 /// Convert `spirv.Tanh` to
1535 ///
1536 /// exp(2x) - 1
1537 /// -----------
1538 /// exp(2x) + 1
1539 ///
1540 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1541 public:
1543 
1544  LogicalResult
1545  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1546  ConversionPatternRewriter &rewriter) const override {
1547  auto srcType = tanhOp.getType();
1548  auto dstType = getTypeConverter()->convertType(srcType);
1549  if (!dstType)
1550  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1551 
1552  Location loc = tanhOp.getLoc();
1553  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1554  Value multiplied =
1555  LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
1556  Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
1557  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1558  Value numerator =
1559  LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
1560  Value denominator =
1561  LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
1562  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1563  denominator);
1564  return success();
1565  }
1566 };
1567 
1568 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1569 public:
1571 
1572  LogicalResult
1573  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1574  ConversionPatternRewriter &rewriter) const override {
1575  auto srcType = varOp.getType();
1576  // Initialization is supported for scalars and vectors only.
1577  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1578  auto init = varOp.getInitializer();
1579  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1580  return failure();
1581 
1582  auto dstType = getTypeConverter()->convertType(srcType);
1583  if (!dstType)
1584  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1585 
1586  Location loc = varOp.getLoc();
1587  Value size = createI32ConstantOf(loc, rewriter, 1);
1588  if (!init) {
1589  auto elementType = getTypeConverter()->convertType(pointerTo);
1590  if (!elementType)
1591  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1592  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1593  size);
1594  return success();
1595  }
1596  auto elementType = getTypeConverter()->convertType(pointerTo);
1597  if (!elementType)
1598  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1599  Value allocated =
1600  LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
1601  LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
1602  rewriter.replaceOp(varOp, allocated);
1603  return success();
1604  }
1605 };
1606 
1607 //===----------------------------------------------------------------------===//
1608 // BitcastOp conversion
1609 //===----------------------------------------------------------------------===//
1610 
1611 class BitcastConversionPattern
1612  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1613 public:
1615 
1616  LogicalResult
1617  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1618  ConversionPatternRewriter &rewriter) const override {
1619  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1620  if (!dstType)
1621  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1622 
1623  // LLVM's opaque pointers do not require bitcasts.
1624  if (isa<LLVM::LLVMPointerType>(dstType)) {
1625  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1626  return success();
1627  }
1628 
1629  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1630  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1631  return success();
1632  }
1633 };
1634 
1635 //===----------------------------------------------------------------------===//
1636 // FuncOp conversion
1637 //===----------------------------------------------------------------------===//
1638 
1639 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1640 public:
1642 
1643  LogicalResult
1644  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1645  ConversionPatternRewriter &rewriter) const override {
1646 
1647  // Convert function signature. At the moment LLVMType converter is enough
1648  // for currently supported types.
1649  auto funcType = funcOp.getFunctionType();
1650  TypeConverter::SignatureConversion signatureConverter(
1651  funcType.getNumInputs());
1652  auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1653  ->convertFunctionSignature(
1654  funcType, /*isVariadic=*/false,
1655  /*useBarePtrCallConv=*/false, signatureConverter);
1656  if (!llvmType)
1657  return failure();
1658 
1659  // Create a new `LLVMFuncOp`
1660  Location loc = funcOp.getLoc();
1661  StringRef name = funcOp.getName();
1662  auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
1663 
1664  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1665  MLIRContext *context = funcOp.getContext();
1666  switch (funcOp.getFunctionControl()) {
1667  case spirv::FunctionControl::Inline:
1668  newFuncOp.setAlwaysInline(true);
1669  break;
1670  case spirv::FunctionControl::DontInline:
1671  newFuncOp.setNoInline(true);
1672  break;
1673 
1674 #define DISPATCH(functionControl, llvmAttr) \
1675  case functionControl: \
1676  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1677  break;
1678 
1679  DISPATCH(spirv::FunctionControl::Pure,
1680  StringAttr::get(context, "readonly"));
1681  DISPATCH(spirv::FunctionControl::Const,
1682  StringAttr::get(context, "readnone"));
1683 
1684 #undef DISPATCH
1685 
1686  // Default: if `spirv::FunctionControl::None`, then no attributes are
1687  // needed.
1688  default:
1689  break;
1690  }
1691 
1692  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1693  newFuncOp.end());
1694  if (failed(rewriter.convertRegionTypes(
1695  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1696  return failure();
1697  }
1698  rewriter.eraseOp(funcOp);
1699  return success();
1700  }
1701 };
1702 
1703 //===----------------------------------------------------------------------===//
1704 // ModuleOp conversion
1705 //===----------------------------------------------------------------------===//
1706 
1707 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1708 public:
1710 
1711  LogicalResult
1712  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1713  ConversionPatternRewriter &rewriter) const override {
1714 
1715  auto newModuleOp =
1716  ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
1717  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1718 
1719  // Remove the terminator block that was automatically added by builder
1720  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1721  rewriter.eraseOp(spvModuleOp);
1722  return success();
1723  }
1724 };
1725 
1726 //===----------------------------------------------------------------------===//
1727 // VectorShuffleOp conversion
1728 //===----------------------------------------------------------------------===//
1729 
1730 class VectorShufflePattern
1731  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1732 public:
1734  LogicalResult
1735  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1736  ConversionPatternRewriter &rewriter) const override {
1737  Location loc = op.getLoc();
1738  auto components = adaptor.getComponents();
1739  auto vector1 = adaptor.getVector1();
1740  auto vector2 = adaptor.getVector2();
1741  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1742  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1743  if (vector1Size == vector2Size) {
1744  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1745  op, vector1, vector2,
1746  LLVM::convertArrayToIndices<int32_t>(components));
1747  return success();
1748  }
1749 
1750  auto dstType = getTypeConverter()->convertType(op.getType());
1751  if (!dstType)
1752  return rewriter.notifyMatchFailure(op, "type conversion failed");
1753  auto scalarType = cast<VectorType>(dstType).getElementType();
1754  auto componentsArray = components.getValue();
1755  auto *context = rewriter.getContext();
1756  auto llvmI32Type = IntegerType::get(context, 32);
1757  Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
1758  for (unsigned i = 0; i < componentsArray.size(); i++) {
1759  if (!isa<IntegerAttr>(componentsArray[i]))
1760  return op.emitError("unable to support non-constant component");
1761 
1762  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1763  if (indexVal == -1)
1764  continue;
1765 
1766  int offsetVal = 0;
1767  Value baseVector = vector1;
1768  if (indexVal >= vector1Size) {
1769  offsetVal = vector1Size;
1770  baseVector = vector2;
1771  }
1772 
1773  Value dstIndex = LLVM::ConstantOp::create(
1774  rewriter, loc, llvmI32Type,
1775  rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1776  Value index = LLVM::ConstantOp::create(
1777  rewriter, loc, llvmI32Type,
1778  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1779 
1780  auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
1781  baseVector, index);
1782  targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
1783  extractOp, dstIndex);
1784  }
1785  rewriter.replaceOp(op, targetOp);
1786  return success();
1787  }
1788 };
1789 } // namespace
1790 
1791 //===----------------------------------------------------------------------===//
1792 // Pattern population
1793 //===----------------------------------------------------------------------===//
1794 
1796  spirv::ClientAPI clientAPI) {
1797  typeConverter.addConversion([&](spirv::ArrayType type) {
1798  return convertArrayType(type, typeConverter);
1799  });
1800  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1801  return convertPointerType(type, typeConverter, clientAPI);
1802  });
1803  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1804  return convertRuntimeArrayType(type, typeConverter);
1805  });
1806  typeConverter.addConversion([&](spirv::StructType type) {
1807  return convertStructType(type, typeConverter);
1808  });
1809 }
1810 
1812  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1813  spirv::ClientAPI clientAPI) {
1814  patterns.add<
1815  // Arithmetic ops
1816  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1817  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1818  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1819  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1820  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1821  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1822  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1823  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1824  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1825  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1826  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1827  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1828  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1829 
1830  // Bitwise ops
1831  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1832  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1833  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1834  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1835  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1836  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1837  NotPattern<spirv::NotOp>,
1838 
1839  // Cast ops
1840  BitcastConversionPattern,
1841  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1842  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1843  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1844  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1845  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1846  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1847  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1848 
1849  // Comparison ops
1850  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1851  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1852  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1853  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1854  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1855  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1856  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1857  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1858  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1859  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1860  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1861  LLVM::FCmpPredicate::uge>,
1862  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1863  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1864  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1865  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1866  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1867  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1868  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1869  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1870  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1871  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1872  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1873 
1874  // Constant op
1875  ConstantScalarAndVectorPattern,
1876 
1877  // Control Flow ops
1878  BranchConversionPattern, BranchConditionalConversionPattern,
1879  FunctionCallPattern, LoopPattern, SelectionPattern,
1880  ErasePattern<spirv::MergeOp>,
1881 
1882  // Entry points and execution mode are handled separately.
1883  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1884 
1885  // GLSL extended instruction set ops
1886  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1887  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1888  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1889  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1890  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1891  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1892  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1893  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1894  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1895  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1896  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1897  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1898  InverseSqrtPattern, TanPattern, TanhPattern,
1899 
1900  // Logical ops
1901  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1902  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1903  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1904  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1905  NotPattern<spirv::LogicalNotOp>,
1906 
1907  // Memory ops
1908  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1909  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1910 
1911  // Miscellaneous ops
1912  CompositeExtractPattern, CompositeInsertPattern,
1913  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1914  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1915  VectorShufflePattern,
1916 
1917  // Shift ops
1918  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1919  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1920  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1921 
1922  // Return ops
1923  ReturnPattern, ReturnValuePattern,
1924 
1925  // Barrier ops
1926  ControlBarrierPattern<spirv::ControlBarrierOp>,
1927  ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1928  ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1929 
1930  // Group reduction operations
1931  GroupReducePattern<spirv::GroupIAddOp>,
1932  GroupReducePattern<spirv::GroupFAddOp>,
1933  GroupReducePattern<spirv::GroupFMinOp>,
1934  GroupReducePattern<spirv::GroupUMinOp>,
1935  GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1936  GroupReducePattern<spirv::GroupFMaxOp>,
1937  GroupReducePattern<spirv::GroupUMaxOp>,
1938  GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1939  GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1940  /*NonUniform=*/true>,
1941  GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1942  /*NonUniform=*/true>,
1943  GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1944  /*NonUniform=*/true>,
1945  GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1946  /*NonUniform=*/true>,
1947  GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1948  /*NonUniform=*/true>,
1949  GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1950  /*NonUniform=*/true>,
1951  GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1952  /*NonUniform=*/true>,
1953  GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1954  /*NonUniform=*/true>,
1955  GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1956  /*NonUniform=*/true>,
1957  GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1958  /*NonUniform=*/true>,
1959  GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1960  /*NonUniform=*/true>,
1961  GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1962  /*NonUniform=*/true>,
1963  GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1964  /*NonUniform=*/true>,
1965  GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1966  /*NonUniform=*/true>,
1967  GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1968  /*NonUniform=*/true>,
1969  GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1970  /*NonUniform=*/true>>(patterns.getContext(),
1971  typeConverter);
1972 
1973  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1974  typeConverter);
1975 }
1976 
1978  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1979  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1980 }
1981 
1983  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1984  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1985 }
1986 
1987 //===----------------------------------------------------------------------===//
1988 // Pre-conversion hooks
1989 //===----------------------------------------------------------------------===//
1990 
1991 /// Hook for descriptor set and binding number encoding.
1992 static constexpr StringRef kBinding = "binding";
1993 static constexpr StringRef kDescriptorSet = "descriptor_set";
1994 void mlir::encodeBindAttribute(ModuleOp module) {
1995  auto spvModules = module.getOps<spirv::ModuleOp>();
1996  for (auto spvModule : spvModules) {
1997  spvModule.walk([&](spirv::GlobalVariableOp op) {
1998  IntegerAttr descriptorSet =
1999  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
2000  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
2001  // For every global variable in the module, get the ones with descriptor
2002  // set and binding numbers.
2003  if (descriptorSet && binding) {
2004  // Encode these numbers into the variable's symbolic name. If the
2005  // SPIR-V module has a name, add it at the beginning.
2006  auto moduleAndName =
2007  spvModule.getName().has_value()
2008  ? spvModule.getName()->str() + "_" + op.getSymName().str()
2009  : op.getSymName().str();
2010  std::string name =
2011  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2012  std::to_string(descriptorSet.getInt()),
2013  std::to_string(binding.getInt()));
2014  auto nameAttr = StringAttr::get(op->getContext(), name);
2015 
2016  // Replace all symbol uses and set the new symbol name. Finally, remove
2017  // descriptor set and binding attributes.
2018  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2019  op.emitError("unable to replace all symbol uses for ") << name;
2020  SymbolTable::setSymbolName(op, nameAttr);
2021  op->removeAttr(kDescriptorSet);
2022  op->removeAttr(kBinding);
2023  }
2024  });
2025  }
2026 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:64
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:35
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:44
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:84
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:94
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:77
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:54
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:20
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:453
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:514
SPIR-V struct type.
Definition: SPIRVTypes.h:295
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:795
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:227
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.