MLIR  20.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 #define DEBUG_TYPE "spirv-to-llvm-pattern"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Utility functions
35 //===----------------------------------------------------------------------===//
36 
37 /// Returns true if the given type is a signed integer or vector type.
38 static bool isSignedIntegerOrVector(Type type) {
39  if (type.isSignedInteger())
40  return true;
41  if (auto vecType = dyn_cast<VectorType>(type))
42  return vecType.getElementType().isSignedInteger();
43  return false;
44 }
45 
46 /// Returns true if the given type is an unsigned integer or vector type
47 static bool isUnsignedIntegerOrVector(Type type) {
48  if (type.isUnsignedInteger())
49  return true;
50  if (auto vecType = dyn_cast<VectorType>(type))
51  return vecType.getElementType().isUnsignedInteger();
52  return false;
53 }
54 
55 /// Returns the width of an integer or of the element type of an integer vector,
56 /// if applicable.
57 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
58  if (auto intType = dyn_cast<IntegerType>(type))
59  return intType.getWidth();
60  if (auto vecType = dyn_cast<VectorType>(type))
61  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
62  return intType.getWidth();
63  return std::nullopt;
64 }
65 
66 /// Returns the bit width of integer, float or vector of float or integer values
67 static unsigned getBitWidth(Type type) {
68  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
69  "bitwidth is not supported for this type");
70  if (type.isIntOrFloat())
71  return type.getIntOrFloatBitWidth();
72  auto vecType = dyn_cast<VectorType>(type);
73  auto elementType = vecType.getElementType();
74  assert(elementType.isIntOrFloat() &&
75  "only integers and floats have a bitwidth");
76  return elementType.getIntOrFloatBitWidth();
77 }
78 
79 /// Returns the bit width of LLVMType integer or vector.
80 static unsigned getLLVMTypeBitWidth(Type type) {
81  return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
83  : type))
84  .getWidth();
85 }
86 
87 /// Creates `IntegerAttribute` with all bits set for given type
88 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
89  if (auto vecType = dyn_cast<VectorType>(type)) {
90  auto integerType = cast<IntegerType>(vecType.getElementType());
91  return builder.getIntegerAttr(integerType, -1);
92  }
93  auto integerType = cast<IntegerType>(type);
94  return builder.getIntegerAttr(integerType, -1);
95 }
96 
97 /// Creates `llvm.mlir.constant` with all bits set for the given type.
98 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
99  PatternRewriter &rewriter) {
100  if (isa<VectorType>(srcType)) {
101  return rewriter.create<LLVM::ConstantOp>(
102  loc, dstType,
103  SplatElementsAttr::get(cast<ShapedType>(srcType),
104  minusOneIntegerAttribute(srcType, rewriter)));
105  }
106  return rewriter.create<LLVM::ConstantOp>(
107  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
108 }
109 
110 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
111 static Value createFPConstant(Location loc, Type srcType, Type dstType,
112  PatternRewriter &rewriter, double value) {
113  if (auto vecType = dyn_cast<VectorType>(srcType)) {
114  auto floatType = cast<FloatType>(vecType.getElementType());
115  return rewriter.create<LLVM::ConstantOp>(
116  loc, dstType,
117  SplatElementsAttr::get(vecType,
118  rewriter.getFloatAttr(floatType, value)));
119  }
120  auto floatType = cast<FloatType>(srcType);
121  return rewriter.create<LLVM::ConstantOp>(
122  loc, dstType, rewriter.getFloatAttr(floatType, value));
123 }
124 
125 /// Utility function for bitfield ops:
126 /// - `BitFieldInsert`
127 /// - `BitFieldSExtract`
128 /// - `BitFieldUExtract`
129 /// Truncates or extends the value. If the bitwidth of the value is the same as
130 /// `llvmType` bitwidth, the value remains unchanged.
132  Type llvmType,
133  PatternRewriter &rewriter) {
134  auto srcType = value.getType();
135  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
136  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
137  ? getLLVMTypeBitWidth(srcType)
138  : getBitWidth(srcType);
139 
140  if (valueBitWidth < targetBitWidth)
141  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
142  // If the bit widths of `Count` and `Offset` are greater than the bit width
143  // of the target type, they are truncated. Truncation is safe since `Count`
144  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
145  // both values can be expressed in 8 bits.
146  if (valueBitWidth > targetBitWidth)
147  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
148  return value;
149 }
150 
151 /// Broadcasts the value to vector with `numElements` number of elements.
152 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
153  const TypeConverter &typeConverter,
154  ConversionPatternRewriter &rewriter) {
155  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
156  auto llvmVectorType = typeConverter.convertType(vectorType);
157  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
158  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
159  for (unsigned i = 0; i < numElements; ++i) {
160  auto index = rewriter.create<LLVM::ConstantOp>(
161  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
162  broadcasted = rewriter.create<LLVM::InsertElementOp>(
163  loc, llvmVectorType, broadcasted, toBroadcast, index);
164  }
165  return broadcasted;
166 }
167 
168 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
169 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
170  const TypeConverter &typeConverter,
171  ConversionPatternRewriter &rewriter) {
172  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
173  unsigned numElements = vectorType.getNumElements();
174  return broadcast(loc, value, numElements, typeConverter, rewriter);
175  }
176  return value;
177 }
178 
179 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
180 /// `BitFieldUExtract`.
181 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
182 /// a vector type, construct a vector that has:
183 /// - same number of elements as `Base`
184 /// - each element has the type that is the same as the type of `Offset` or
185 /// `Count`
186 /// - each element has the same value as `Offset` or `Count`
187 /// Then cast `Offset` and `Count` if their bit width is different
188 /// from `Base` bit width.
189 static Value processCountOrOffset(Location loc, Value value, Type srcType,
190  Type dstType, const TypeConverter &converter,
191  ConversionPatternRewriter &rewriter) {
192  Value broadcasted =
193  optionallyBroadcast(loc, value, srcType, converter, rewriter);
194  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
195 }
196 
197 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
198 /// offset to LLVM struct. Otherwise, the conversion is not supported.
200  const TypeConverter &converter) {
201  if (type != VulkanLayoutUtils::decorateType(type))
202  return nullptr;
203 
204  SmallVector<Type> elementsVector;
205  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
206  return nullptr;
207  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
208  /*isPacked=*/false);
209 }
210 
211 /// Converts SPIR-V struct with no offset to packed LLVM struct.
213  const TypeConverter &converter) {
214  SmallVector<Type> elementsVector;
215  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
216  return nullptr;
217  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
218  /*isPacked=*/true);
219 }
220 
221 /// Creates LLVM dialect constant with the given value.
223  unsigned value) {
224  return rewriter.create<LLVM::ConstantOp>(
225  loc, IntegerType::get(rewriter.getContext(), 32),
226  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
227 }
228 
229 /// Utility for `spirv.Load` and `spirv.Store` conversion.
230 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
231  ConversionPatternRewriter &rewriter,
232  const TypeConverter &typeConverter,
233  unsigned alignment, bool isVolatile,
234  bool isNonTemporal) {
235  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
236  auto dstType = typeConverter.convertType(loadOp.getType());
237  if (!dstType)
238  return rewriter.notifyMatchFailure(op, "type conversion failed");
239  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
240  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
241  isVolatile, isNonTemporal);
242  return success();
243  }
244  auto storeOp = cast<spirv::StoreOp>(op);
245  spirv::StoreOpAdaptor adaptor(operands);
246  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
247  adaptor.getPtr(), alignment,
248  isVolatile, isNonTemporal);
249  return success();
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // Type conversion
254 //===----------------------------------------------------------------------===//
255 
256 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
257 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
258 /// when converting ops that manipulate array types.
259 static std::optional<Type> convertArrayType(spirv::ArrayType type,
260  TypeConverter &converter) {
261  unsigned stride = type.getArrayStride();
262  Type elementType = type.getElementType();
263  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
264  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
265  return std::nullopt;
266 
267  auto llvmElementType = converter.convertType(elementType);
268  unsigned numElements = type.getNumElements();
269  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
270 }
271 
272 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
273 /// modelled at the moment.
275  const TypeConverter &converter,
276  spirv::ClientAPI clientAPI) {
277  unsigned addressSpace =
278  storageClassToAddressSpace(clientAPI, type.getStorageClass());
279  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
280 }
281 
282 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
283 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
284 /// no modelling of array stride at the moment.
285 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
286  TypeConverter &converter) {
287  if (type.getArrayStride() != 0)
288  return std::nullopt;
289  auto elementType = converter.convertType(type.getElementType());
290  return LLVM::LLVMArrayType::get(elementType, 0);
291 }
292 
293 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
294 /// member decorations. Also, only natural offset is supported.
296  const TypeConverter &converter) {
298  type.getMemberDecorations(memberDecorations);
299  if (!memberDecorations.empty())
300  return nullptr;
301  if (type.hasOffset())
302  return convertStructTypeWithOffset(type, converter);
303  return convertStructTypePacked(type, converter);
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Operation conversion
308 //===----------------------------------------------------------------------===//
309 
310 namespace {
311 
312 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
313 public:
315 
316  LogicalResult
317  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
318  ConversionPatternRewriter &rewriter) const override {
319  auto dstType =
320  getTypeConverter()->convertType(op.getComponentPtr().getType());
321  if (!dstType)
322  return rewriter.notifyMatchFailure(op, "type conversion failed");
323  // To use GEP we need to add a first 0 index to go through the pointer.
324  auto indices = llvm::to_vector<4>(adaptor.getIndices());
325  Type indexType = op.getIndices().front().getType();
326  auto llvmIndexType = getTypeConverter()->convertType(indexType);
327  if (!llvmIndexType)
328  return rewriter.notifyMatchFailure(op, "type conversion failed");
329  Value zero = rewriter.create<LLVM::ConstantOp>(
330  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
331  indices.insert(indices.begin(), zero);
332 
333  auto elementType = getTypeConverter()->convertType(
334  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
335  if (!elementType)
336  return rewriter.notifyMatchFailure(op, "type conversion failed");
337  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
338  adaptor.getBasePtr(), indices);
339  return success();
340  }
341 };
342 
343 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
344 public:
346 
347  LogicalResult
348  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
349  ConversionPatternRewriter &rewriter) const override {
350  auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
351  if (!dstType)
352  return rewriter.notifyMatchFailure(op, "type conversion failed");
353  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
354  op.getVariable());
355  return success();
356  }
357 };
358 
359 class BitFieldInsertPattern
360  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
361 public:
363 
364  LogicalResult
365  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
366  ConversionPatternRewriter &rewriter) const override {
367  auto srcType = op.getType();
368  auto dstType = getTypeConverter()->convertType(srcType);
369  if (!dstType)
370  return rewriter.notifyMatchFailure(op, "type conversion failed");
371  Location loc = op.getLoc();
372 
373  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
374  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
375  *getTypeConverter(), rewriter);
376  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
377  *getTypeConverter(), rewriter);
378 
379  // Create a mask with bits set outside [Offset, Offset + Count - 1].
380  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
381  Value maskShiftedByCount =
382  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
383  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
384  maskShiftedByCount, minusOne);
385  Value maskShiftedByCountAndOffset =
386  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
387  Value mask = rewriter.create<LLVM::XOrOp>(
388  loc, dstType, maskShiftedByCountAndOffset, minusOne);
389 
390  // Extract unchanged bits from the `Base` that are outside of
391  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
392  Value baseAndMask =
393  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
394  Value insertShiftedByOffset =
395  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
396  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
397  insertShiftedByOffset);
398  return success();
399  }
400 };
401 
402 /// Converts SPIR-V ConstantOp with scalar or vector type.
403 class ConstantScalarAndVectorPattern
404  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
405 public:
407 
408  LogicalResult
409  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
410  ConversionPatternRewriter &rewriter) const override {
411  auto srcType = constOp.getType();
412  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
413  return failure();
414 
415  auto dstType = getTypeConverter()->convertType(srcType);
416  if (!dstType)
417  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
418 
419  // SPIR-V constant can be a signed/unsigned integer, which has to be
420  // casted to signless integer when converting to LLVM dialect. Removing the
421  // sign bit may have unexpected behaviour. However, it is better to handle
422  // it case-by-case, given that the purpose of the conversion is not to
423  // cover all possible corner cases.
424  if (isSignedIntegerOrVector(srcType) ||
425  isUnsignedIntegerOrVector(srcType)) {
426  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
427 
428  if (isa<VectorType>(srcType)) {
429  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
430  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
431  constOp, dstType,
432  dstElementsAttr.mapValues(
433  signlessType, [&](const APInt &value) { return value; }));
434  return success();
435  }
436  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
437  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
438  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
439  return success();
440  }
441  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
442  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
443  return success();
444  }
445 };
446 
447 class BitFieldSExtractPattern
448  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
449 public:
451 
452  LogicalResult
453  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
454  ConversionPatternRewriter &rewriter) const override {
455  auto srcType = op.getType();
456  auto dstType = getTypeConverter()->convertType(srcType);
457  if (!dstType)
458  return rewriter.notifyMatchFailure(op, "type conversion failed");
459  Location loc = op.getLoc();
460 
461  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
462  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
463  *getTypeConverter(), rewriter);
464  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
465  *getTypeConverter(), rewriter);
466 
467  // Create a constant that holds the size of the `Base`.
468  IntegerType integerType;
469  if (auto vecType = dyn_cast<VectorType>(srcType))
470  integerType = cast<IntegerType>(vecType.getElementType());
471  else
472  integerType = cast<IntegerType>(srcType);
473 
474  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
475  Value size =
476  isa<VectorType>(srcType)
477  ? rewriter.create<LLVM::ConstantOp>(
478  loc, dstType,
479  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
480  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
481 
482  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
483  // at Offset + Count - 1 is the most significant bit now.
484  Value countPlusOffset =
485  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
486  Value amountToShiftLeft =
487  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
488  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
489  loc, dstType, op.getBase(), amountToShiftLeft);
490 
491  // Shift the result right, filling the bits with the sign bit.
492  Value amountToShiftRight =
493  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
494  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
495  amountToShiftRight);
496  return success();
497  }
498 };
499 
500 class BitFieldUExtractPattern
501  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
502 public:
504 
505  LogicalResult
506  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
507  ConversionPatternRewriter &rewriter) const override {
508  auto srcType = op.getType();
509  auto dstType = getTypeConverter()->convertType(srcType);
510  if (!dstType)
511  return rewriter.notifyMatchFailure(op, "type conversion failed");
512  Location loc = op.getLoc();
513 
514  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
515  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
516  *getTypeConverter(), rewriter);
517  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
518  *getTypeConverter(), rewriter);
519 
520  // Create a mask with bits set at [0, Count - 1].
521  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
522  Value maskShiftedByCount =
523  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
524  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
525  minusOne);
526 
527  // Shift `Base` by `Offset` and apply the mask on it.
528  Value shiftedBase =
529  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
530  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
531  return success();
532  }
533 };
534 
535 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
536 public:
538 
539  LogicalResult
540  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
541  ConversionPatternRewriter &rewriter) const override {
542  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
543  branchOp.getTarget());
544  return success();
545  }
546 };
547 
548 class BranchConditionalConversionPattern
549  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
550 public:
551  using SPIRVToLLVMConversion<
552  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
553 
554  LogicalResult
555  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
556  ConversionPatternRewriter &rewriter) const override {
557  // If branch weights exist, map them to 32-bit integer vector.
558  DenseI32ArrayAttr branchWeights = nullptr;
559  if (auto weights = op.getBranchWeights()) {
560  SmallVector<int32_t> weightValues;
561  for (auto weight : weights->getAsRange<IntegerAttr>())
562  weightValues.push_back(weight.getInt());
563  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
564  }
565 
566  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
567  op, op.getCondition(), op.getTrueBlockArguments(),
568  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
569  op.getFalseBlock());
570  return success();
571  }
572 };
573 
574 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
575 /// type is an aggregate type (struct or array). Otherwise, converts to
576 /// `llvm.extractelement` that operates on vectors.
577 class CompositeExtractPattern
578  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
579 public:
581 
582  LogicalResult
583  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
584  ConversionPatternRewriter &rewriter) const override {
585  auto dstType = this->getTypeConverter()->convertType(op.getType());
586  if (!dstType)
587  return rewriter.notifyMatchFailure(op, "type conversion failed");
588 
589  Type containerType = op.getComposite().getType();
590  if (isa<VectorType>(containerType)) {
591  Location loc = op.getLoc();
592  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
593  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
594  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
595  op, dstType, adaptor.getComposite(), index);
596  return success();
597  }
598 
599  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
600  op, adaptor.getComposite(),
601  LLVM::convertArrayToIndices(op.getIndices()));
602  return success();
603  }
604 };
605 
606 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
607 /// type is an aggregate type (struct or array). Otherwise, converts to
608 /// `llvm.insertelement` that operates on vectors.
609 class CompositeInsertPattern
610  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
611 public:
613 
614  LogicalResult
615  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
616  ConversionPatternRewriter &rewriter) const override {
617  auto dstType = this->getTypeConverter()->convertType(op.getType());
618  if (!dstType)
619  return rewriter.notifyMatchFailure(op, "type conversion failed");
620 
621  Type containerType = op.getComposite().getType();
622  if (isa<VectorType>(containerType)) {
623  Location loc = op.getLoc();
624  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
625  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
626  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
627  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628  return success();
629  }
630 
631  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
632  op, adaptor.getComposite(), adaptor.getObject(),
633  LLVM::convertArrayToIndices(op.getIndices()));
634  return success();
635  }
636 };
637 
638 /// Converts SPIR-V operations that have straightforward LLVM equivalent
639 /// into LLVM dialect operations.
640 template <typename SPIRVOp, typename LLVMOp>
641 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
642 public:
644 
645  LogicalResult
646  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
647  ConversionPatternRewriter &rewriter) const override {
648  auto dstType = this->getTypeConverter()->convertType(op.getType());
649  if (!dstType)
650  return rewriter.notifyMatchFailure(op, "type conversion failed");
651  rewriter.template replaceOpWithNewOp<LLVMOp>(
652  op, dstType, adaptor.getOperands(), op->getAttrs());
653  return success();
654  }
655 };
656 
657 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
658 /// execution mode information.
659 class ExecutionModePattern
660  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
661 public:
663 
664  LogicalResult
665  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
666  ConversionPatternRewriter &rewriter) const override {
667  // First, create the global struct's name that would be associated with
668  // this entry point's execution mode. We set it to be:
669  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
670  ModuleOp module = op->getParentOfType<ModuleOp>();
671  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
672  std::string moduleName;
673  if (module.getName().has_value())
674  moduleName = "_" + module.getName()->str();
675  else
676  moduleName = "";
677  std::string executionModeInfoName = llvm::formatv(
678  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
679  static_cast<uint32_t>(executionModeAttr.getValue()));
680 
681  MLIRContext *context = rewriter.getContext();
682  OpBuilder::InsertionGuard guard(rewriter);
683  rewriter.setInsertionPointToStart(module.getBody());
684 
685  // Create a struct type, corresponding to the C struct below.
686  // struct {
687  // int32_t executionMode;
688  // int32_t values[]; // optional values
689  // };
690  auto llvmI32Type = IntegerType::get(context, 32);
691  SmallVector<Type, 2> fields;
692  fields.push_back(llvmI32Type);
693  ArrayAttr values = op.getValues();
694  if (!values.empty()) {
695  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
696  fields.push_back(arrayType);
697  }
698  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
699 
700  // Create `llvm.mlir.global` with initializer region containing one block.
701  auto global = rewriter.create<LLVM::GlobalOp>(
702  UnknownLoc::get(context), structType, /*isConstant=*/true,
703  LLVM::Linkage::External, executionModeInfoName, Attribute(),
704  /*alignment=*/0);
705  Location loc = global.getLoc();
706  Region &region = global.getInitializerRegion();
707  Block *block = rewriter.createBlock(&region);
708 
709  // Initialize the struct and set the execution mode value.
710  rewriter.setInsertionPointToStart(block);
711  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
712  Value executionMode = rewriter.create<LLVM::ConstantOp>(
713  loc, llvmI32Type,
714  rewriter.getI32IntegerAttr(
715  static_cast<uint32_t>(executionModeAttr.getValue())));
716  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
717  executionMode, 0);
718 
719  // Insert extra operands if they exist into execution mode info struct.
720  for (unsigned i = 0, e = values.size(); i < e; ++i) {
721  auto attr = values.getValue()[i];
722  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
723  structValue = rewriter.create<LLVM::InsertValueOp>(
724  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
725  }
726  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
727  rewriter.eraseOp(op);
728  return success();
729  }
730 };
731 
732 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
733 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
734 /// value. This difference is handled by `spirv.mlir.addressof` and
735 /// `llvm.mlir.addressof`ops that both return a pointer.
736 class GlobalVariablePattern
737  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
738 public:
739  template <typename... Args>
740  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
741  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
742  std::forward<Args>(args)...),
743  clientAPI(clientAPI) {}
744 
745  LogicalResult
746  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
747  ConversionPatternRewriter &rewriter) const override {
748  // Currently, there is no support of initialization with a constant value in
749  // SPIR-V dialect. Specialization constants are not considered as well.
750  if (op.getInitializer())
751  return failure();
752 
753  auto srcType = cast<spirv::PointerType>(op.getType());
754  auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
755  if (!dstType)
756  return rewriter.notifyMatchFailure(op, "type conversion failed");
757 
758  // Limit conversion to the current invocation only or `StorageBuffer`
759  // required by SPIR-V runner.
760  // This is okay because multiple invocations are not supported yet.
761  auto storageClass = srcType.getStorageClass();
762  switch (storageClass) {
763  case spirv::StorageClass::Input:
764  case spirv::StorageClass::Private:
765  case spirv::StorageClass::Output:
766  case spirv::StorageClass::StorageBuffer:
767  case spirv::StorageClass::UniformConstant:
768  break;
769  default:
770  return failure();
771  }
772 
773  // LLVM dialect spec: "If the global value is a constant, storing into it is
774  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
775  // storage class that is read-only.
776  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
777  (storageClass == spirv::StorageClass::UniformConstant);
778  // SPIR-V spec: "By default, functions and global variables are private to a
779  // module and cannot be accessed by other modules. However, a module may be
780  // written to export or import functions and global (module scope)
781  // variables.". Therefore, map 'Private' storage class to private linkage,
782  // 'Input' and 'Output' to external linkage.
783  auto linkage = storageClass == spirv::StorageClass::Private
784  ? LLVM::Linkage::Private
785  : LLVM::Linkage::External;
786  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788  /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
789 
790  // Attach location attribute if applicable
791  if (op.getLocationAttr())
792  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
793 
794  return success();
795  }
796 
797 private:
798  spirv::ClientAPI clientAPI;
799 };
800 
801 /// Converts SPIR-V cast ops that do not have straightforward LLVM
802 /// equivalent in LLVM dialect.
803 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
805 public:
807 
808  LogicalResult
809  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810  ConversionPatternRewriter &rewriter) const override {
811 
812  Type fromType = op.getOperand().getType();
813  Type toType = op.getType();
814 
815  auto dstType = this->getTypeConverter()->convertType(toType);
816  if (!dstType)
817  return rewriter.notifyMatchFailure(op, "type conversion failed");
818 
819  if (getBitWidth(fromType) < getBitWidth(toType)) {
820  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821  adaptor.getOperands());
822  return success();
823  }
824  if (getBitWidth(fromType) > getBitWidth(toType)) {
825  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826  adaptor.getOperands());
827  return success();
828  }
829  return failure();
830  }
831 };
832 
833 class FunctionCallPattern
834  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
835 public:
837 
838  LogicalResult
839  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840  ConversionPatternRewriter &rewriter) const override {
841  if (callOp.getNumResults() == 0) {
842  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843  callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
844  newOp.getProperties().operandSegmentSizes = {
845  static_cast<int32_t>(adaptor.getOperands().size()), 0};
846  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
847  return success();
848  }
849 
850  // Function returns a single result.
851  auto dstType = getTypeConverter()->convertType(callOp.getType(0));
852  if (!dstType)
853  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856  newOp.getProperties().operandSegmentSizes = {
857  static_cast<int32_t>(adaptor.getOperands().size()), 0};
858  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
859  return success();
860  }
861 };
862 
863 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
866 public:
868 
869  LogicalResult
870  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871  ConversionPatternRewriter &rewriter) const override {
872 
873  auto dstType = this->getTypeConverter()->convertType(op.getType());
874  if (!dstType)
875  return rewriter.notifyMatchFailure(op, "type conversion failed");
876 
877  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878  op, dstType, predicate, op.getOperand1(), op.getOperand2());
879  return success();
880  }
881 };
882 
883 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
886 public:
888 
889  LogicalResult
890  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891  ConversionPatternRewriter &rewriter) const override {
892 
893  auto dstType = this->getTypeConverter()->convertType(op.getType());
894  if (!dstType)
895  return rewriter.notifyMatchFailure(op, "type conversion failed");
896 
897  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898  op, dstType, predicate, op.getOperand1(), op.getOperand2());
899  return success();
900  }
901 };
902 
903 class InverseSqrtPattern
904  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
905 public:
907 
908  LogicalResult
909  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910  ConversionPatternRewriter &rewriter) const override {
911  auto srcType = op.getType();
912  auto dstType = getTypeConverter()->convertType(srcType);
913  if (!dstType)
914  return rewriter.notifyMatchFailure(op, "type conversion failed");
915 
916  Location loc = op.getLoc();
917  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
918  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
919  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
920  return success();
921  }
922 };
923 
924 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925 template <typename SPIRVOp>
926 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
927 public:
929 
930  LogicalResult
931  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932  ConversionPatternRewriter &rewriter) const override {
933  if (!op.getMemoryAccess()) {
934  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
935  *this->getTypeConverter(), /*alignment=*/0,
936  /*isVolatile=*/false,
937  /*isNonTemporal=*/false);
938  }
939  auto memoryAccess = *op.getMemoryAccess();
940  switch (memoryAccess) {
941  case spirv::MemoryAccess::Aligned:
943  case spirv::MemoryAccess::Nontemporal:
944  case spirv::MemoryAccess::Volatile: {
945  unsigned alignment =
946  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
950  *this->getTypeConverter(), alignment,
951  isVolatile, isNonTemporal);
952  }
953  default:
954  // There is no support of other memory access attributes.
955  return failure();
956  }
957  }
958 };
959 
960 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961 template <typename SPIRVOp>
962 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
963 public:
965 
966  LogicalResult
967  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968  ConversionPatternRewriter &rewriter) const override {
969  auto srcType = notOp.getType();
970  auto dstType = this->getTypeConverter()->convertType(srcType);
971  if (!dstType)
972  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
973 
974  Location loc = notOp.getLoc();
975  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
976  auto mask =
977  isa<VectorType>(srcType)
978  ? rewriter.create<LLVM::ConstantOp>(
979  loc, dstType,
980  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
982  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983  notOp.getOperand(), mask);
984  return success();
985  }
986 };
987 
988 /// A template pattern that erases the given `SPIRVOp`.
989 template <typename SPIRVOp>
990 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
991 public:
993 
994  LogicalResult
995  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
996  ConversionPatternRewriter &rewriter) const override {
997  rewriter.eraseOp(op);
998  return success();
999  }
1000 };
1001 
1002 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1003 public:
1005 
1006  LogicalResult
1007  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008  ConversionPatternRewriter &rewriter) const override {
1009  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1010  ArrayRef<Value>());
1011  return success();
1012  }
1013 };
1014 
1015 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1016 public:
1018 
1019  LogicalResult
1020  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021  ConversionPatternRewriter &rewriter) const override {
1022  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023  adaptor.getOperands());
1024  return success();
1025  }
1026 };
1027 
1028 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1029  StringRef name,
1030  ArrayRef<Type> paramTypes,
1031  Type resultType,
1032  bool convergent = true) {
1033  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1034  SymbolTable::lookupSymbolIn(symbolTable, name));
1035  if (func)
1036  return func;
1037 
1038  OpBuilder b(symbolTable->getRegion(0));
1039  func = b.create<LLVM::LLVMFuncOp>(
1040  symbolTable->getLoc(), name,
1041  LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043  func.setConvergent(convergent);
1044  func.setNoUnwind(true);
1045  func.setWillReturn(true);
1046  return func;
1047 }
1048 
1049 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1050  LLVM::LLVMFuncOp func,
1051  ValueRange args) {
1052  auto call = builder.create<LLVM::CallOp>(loc, func, args);
1053  call.setCConv(func.getCConv());
1054  call.setConvergentAttr(func.getConvergentAttr());
1055  call.setNoUnwindAttr(func.getNoUnwindAttr());
1056  call.setWillReturnAttr(func.getWillReturnAttr());
1057  return call;
1058 }
1059 
1060 template <typename BarrierOpTy>
1061 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1062 public:
1063  using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1064 
1066 
1067  static constexpr StringRef getFuncName();
1068 
1069  LogicalResult
1070  matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071  ConversionPatternRewriter &rewriter) const override {
1072  constexpr StringRef funcName = getFuncName();
1073  Operation *symbolTable =
1074  controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1075 
1076  Type i32 = rewriter.getI32Type();
1077 
1078  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079  LLVM::LLVMFuncOp func =
1080  lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1081 
1082  Location loc = controlBarrierOp->getLoc();
1083  Value execution = rewriter.create<LLVM::ConstantOp>(
1084  loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1085  Value memory = rewriter.create<LLVM::ConstantOp>(
1086  loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1087  Value semantics = rewriter.create<LLVM::ConstantOp>(
1088  loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1089 
1090  auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1091  {execution, memory, semantics});
1092 
1093  rewriter.replaceOp(controlBarrierOp, call);
1094  return success();
1095  }
1096 };
1097 
1098 namespace {
1099 
1100 StringRef getTypeMangling(Type type, bool isSigned) {
1102  .Case<Float16Type>([](auto) { return "Dh"; })
1103  .Case<Float32Type>([](auto) { return "f"; })
1104  .Case<Float64Type>([](auto) { return "d"; })
1105  .Case<IntegerType>([isSigned](IntegerType intTy) {
1106  switch (intTy.getWidth()) {
1107  case 1:
1108  return "b";
1109  case 8:
1110  return (isSigned) ? "a" : "c";
1111  case 16:
1112  return (isSigned) ? "s" : "t";
1113  case 32:
1114  return (isSigned) ? "i" : "j";
1115  case 64:
1116  return (isSigned) ? "l" : "m";
1117  default:
1118  llvm_unreachable("Unsupported integer width");
1119  }
1120  })
1121  .Default([](auto) {
1122  llvm_unreachable("No mangling defined");
1123  return "";
1124  });
1125 }
1126 
1127 template <typename ReduceOp>
1128 constexpr StringLiteral getGroupFuncName();
1129 
1130 template <>
1131 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1132  return "_Z17__spirv_GroupIAddii";
1133 }
1134 template <>
1135 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1136  return "_Z17__spirv_GroupFAddii";
1137 }
1138 template <>
1139 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1140  return "_Z17__spirv_GroupSMinii";
1141 }
1142 template <>
1143 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1144  return "_Z17__spirv_GroupUMinii";
1145 }
1146 template <>
1147 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1148  return "_Z17__spirv_GroupFMinii";
1149 }
1150 template <>
1151 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1152  return "_Z17__spirv_GroupSMaxii";
1153 }
1154 template <>
1155 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1156  return "_Z17__spirv_GroupUMaxii";
1157 }
1158 template <>
1159 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1160  return "_Z17__spirv_GroupFMaxii";
1161 }
1162 template <>
1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1164  return "_Z27__spirv_GroupNonUniformIAddii";
1165 }
1166 template <>
1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1168  return "_Z27__spirv_GroupNonUniformFAddii";
1169 }
1170 template <>
1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1172  return "_Z27__spirv_GroupNonUniformIMulii";
1173 }
1174 template <>
1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1176  return "_Z27__spirv_GroupNonUniformFMulii";
1177 }
1178 template <>
1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1180  return "_Z27__spirv_GroupNonUniformSMinii";
1181 }
1182 template <>
1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1184  return "_Z27__spirv_GroupNonUniformUMinii";
1185 }
1186 template <>
1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1188  return "_Z27__spirv_GroupNonUniformFMinii";
1189 }
1190 template <>
1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1192  return "_Z27__spirv_GroupNonUniformSMaxii";
1193 }
1194 template <>
1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1196  return "_Z27__spirv_GroupNonUniformUMaxii";
1197 }
1198 template <>
1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1200  return "_Z27__spirv_GroupNonUniformFMaxii";
1201 }
1202 template <>
1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1204  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1205 }
1206 template <>
1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1208  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1209 }
1210 template <>
1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1212  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1213 }
1214 template <>
1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1216  return "_Z33__spirv_GroupNonUniformLogicalAndii";
1217 }
1218 template <>
1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1220  return "_Z32__spirv_GroupNonUniformLogicalOrii";
1221 }
1222 template <>
1223 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1224  return "_Z33__spirv_GroupNonUniformLogicalXorii";
1225 }
1226 } // namespace
1227 
1228 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1229 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1230 public:
1232 
1233  LogicalResult
1234  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1235  ConversionPatternRewriter &rewriter) const override {
1236 
1237  Type retTy = op.getResult().getType();
1238  if (!retTy.isIntOrFloat()) {
1239  return failure();
1240  }
1241  SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1242  funcName += getTypeMangling(retTy, false);
1243 
1244  Type i32Ty = rewriter.getI32Type();
1245  SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1246  if constexpr (NonUniform) {
1247  if (adaptor.getClusterSize()) {
1248  funcName += "j";
1249  paramTypes.push_back(i32Ty);
1250  }
1251  }
1252 
1253  Operation *symbolTable =
1254  op->template getParentWithTrait<OpTrait::SymbolTable>();
1255 
1256  LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
1257  symbolTable, funcName, paramTypes, retTy, !NonUniform);
1258 
1259  Location loc = op.getLoc();
1260  Value scope = rewriter.create<LLVM::ConstantOp>(
1261  loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1262  Value groupOp = rewriter.create<LLVM::ConstantOp>(
1263  loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1264  SmallVector<Value> operands{scope, groupOp};
1265  operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1266 
1267  auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1268  rewriter.replaceOp(op, call);
1269  return success();
1270  }
1271 };
1272 
1273 template <>
1274 constexpr StringRef
1275 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1276  return "_Z22__spirv_ControlBarrieriii";
1277 }
1278 
1279 template <>
1280 constexpr StringRef
1281 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1282  return "_Z33__spirv_ControlBarrierArriveINTELiii";
1283 }
1284 
1285 template <>
1286 constexpr StringRef
1287 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1288  return "_Z31__spirv_ControlBarrierWaitINTELiii";
1289 }
1290 
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
///
///      +------------------------------------+
///      | <code before spirv.mlir.loop>      |
///      | llvm.br ^header                    |
///      +------------------------------------+
///                           |
///   +----------------+      |
///   |                |      |
///   |                V      V
///   |  +------------------------------------+
///   |  | ^header:                           |
///   |  |   <header code>                    |
///   |  |   llvm.cond_br %cond, ^body, ^exit |
///   |  +------------------------------------+
///   |                    |
///   |                    |----------------------+
///   |                    |                      |
///   |                    V                      |
///   |  +------------------------------------+   |
///   |  | ^body:                             |   |
///   |  |   <body code>                      |   |
///   |  |   llvm.br ^continue                |   |
///   |  +------------------------------------+   |
///   |                    |                      |
///   |                    V                      |
///   |  +------------------------------------+   |
///   |  | ^continue:                         |   |
///   |  |   <continue code>                  |   |
///   |  |   llvm.br ^header                  |   |
///   |  +------------------------------------+   |
///   |                    |                      |
///   +--------------------+    +-----------------+
///                             |
///                             V
///      +------------------------------------+
///      | ^exit:                             |
///      |   llvm.br ^remaining               |
///      +------------------------------------+
///                        |
///                        V
///      +------------------------------------+
///      | ^remaining:                        |
///      |   <code after spirv.mlir.loop>     |
///      +------------------------------------+
///
1339 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1340 public:
1342 
1343  LogicalResult
1344  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1345  ConversionPatternRewriter &rewriter) const override {
1346  // There is no support of loop control at the moment.
1347  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1348  return failure();
1349 
1350  // `spirv.mlir.loop` with empty region is redundant and should be erased.
1351  if (loopOp.getBody().empty()) {
1352  rewriter.eraseOp(loopOp);
1353  return success();
1354  }
1355 
1356  Location loc = loopOp.getLoc();
1357 
1358  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1359  // be used in `endBlock`.
1360  Block *currentBlock = rewriter.getBlock();
1361  auto position = Block::iterator(loopOp);
1362  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1363 
1364  // Remove entry block and create a branch in the current block going to the
1365  // header block.
1366  Block *entryBlock = loopOp.getEntryBlock();
1367  assert(entryBlock->getOperations().size() == 1);
1368  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1369  if (!brOp)
1370  return failure();
1371  Block *headerBlock = loopOp.getHeaderBlock();
1372  rewriter.setInsertionPointToEnd(currentBlock);
1373  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1374  rewriter.eraseBlock(entryBlock);
1375 
1376  // Branch from merge block to end block.
1377  Block *mergeBlock = loopOp.getMergeBlock();
1378  Operation *terminator = mergeBlock->getTerminator();
1379  ValueRange terminatorOperands = terminator->getOperands();
1380  rewriter.setInsertionPointToEnd(mergeBlock);
1381  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1382 
1383  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1384  rewriter.replaceOp(loopOp, endBlock->getArguments());
1385  return success();
1386  }
1387 };
1388 
1389 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1390 /// block. All blocks within selection should be reachable for conversion to
1391 /// succeed.
1392 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1393 public:
1395 
1396  LogicalResult
1397  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1398  ConversionPatternRewriter &rewriter) const override {
1399  // There is no support for `Flatten` or `DontFlatten` selection control at
1400  // the moment. This are just compiler hints and can be performed during the
1401  // optimization passes.
1402  if (op.getSelectionControl() != spirv::SelectionControl::None)
1403  return failure();
1404 
1405  // `spirv.mlir.selection` should have at least two blocks: one selection
1406  // header block and one merge block. If no blocks are present, or control
1407  // flow branches straight to merge block (two blocks are present), the op is
1408  // redundant and it is erased.
1409  if (op.getBody().getBlocks().size() <= 2) {
1410  rewriter.eraseOp(op);
1411  return success();
1412  }
1413 
1414  Location loc = op.getLoc();
1415 
1416  // Split the current block after `spirv.mlir.selection`. The remaining ops
1417  // will be used in `continueBlock`.
1418  auto *currentBlock = rewriter.getInsertionBlock();
1419  rewriter.setInsertionPointAfter(op);
1420  auto position = rewriter.getInsertionPoint();
1421  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1422 
1423  // Extract conditional branch information from the header block. By SPIR-V
1424  // dialect spec, it should contain `spirv.BranchConditional` or
1425  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1426  // moment in the SPIR-V dialect. Remove this block when finished.
1427  auto *headerBlock = op.getHeaderBlock();
1428  assert(headerBlock->getOperations().size() == 1);
1429  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1430  headerBlock->getOperations().front());
1431  if (!condBrOp)
1432  return failure();
1433  rewriter.eraseBlock(headerBlock);
1434 
1435  // Branch from merge block to continue block.
1436  auto *mergeBlock = op.getMergeBlock();
1437  Operation *terminator = mergeBlock->getTerminator();
1438  ValueRange terminatorOperands = terminator->getOperands();
1439  rewriter.setInsertionPointToEnd(mergeBlock);
1440  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1441 
1442  // Link current block to `true` and `false` blocks within the selection.
1443  Block *trueBlock = condBrOp.getTrueBlock();
1444  Block *falseBlock = condBrOp.getFalseBlock();
1445  rewriter.setInsertionPointToEnd(currentBlock);
1446  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1447  condBrOp.getTrueTargetOperands(),
1448  falseBlock,
1449  condBrOp.getFalseTargetOperands());
1450 
1451  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1452  rewriter.replaceOp(op, continueBlock->getArguments());
1453  return success();
1454  }
1455 };
1456 
1457 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1458 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1459 /// `Shift` is zero or sign extended to match this specification. Cases when
1460 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1461 template <typename SPIRVOp, typename LLVMOp>
1462 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1463 public:
1465 
1466  LogicalResult
1467  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1468  ConversionPatternRewriter &rewriter) const override {
1469 
1470  auto dstType = this->getTypeConverter()->convertType(op.getType());
1471  if (!dstType)
1472  return rewriter.notifyMatchFailure(op, "type conversion failed");
1473 
1474  Type op1Type = op.getOperand1().getType();
1475  Type op2Type = op.getOperand2().getType();
1476 
1477  if (op1Type == op2Type) {
1478  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1479  adaptor.getOperands());
1480  return success();
1481  }
1482 
1483  std::optional<uint64_t> dstTypeWidth =
1485  std::optional<uint64_t> op2TypeWidth =
1487 
1488  if (!dstTypeWidth || !op2TypeWidth)
1489  return failure();
1490 
1491  Location loc = op.getLoc();
1492  Value extended;
1493  if (op2TypeWidth < dstTypeWidth) {
1494  if (isUnsignedIntegerOrVector(op2Type)) {
1495  extended = rewriter.template create<LLVM::ZExtOp>(
1496  loc, dstType, adaptor.getOperand2());
1497  } else {
1498  extended = rewriter.template create<LLVM::SExtOp>(
1499  loc, dstType, adaptor.getOperand2());
1500  }
1501  } else if (op2TypeWidth == dstTypeWidth) {
1502  extended = adaptor.getOperand2();
1503  } else {
1504  return failure();
1505  }
1506 
1507  Value result = rewriter.template create<LLVMOp>(
1508  loc, dstType, adaptor.getOperand1(), extended);
1509  rewriter.replaceOp(op, result);
1510  return success();
1511  }
1512 };
1513 
1514 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1515 public:
1517 
1518  LogicalResult
1519  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1520  ConversionPatternRewriter &rewriter) const override {
1521  auto dstType = getTypeConverter()->convertType(tanOp.getType());
1522  if (!dstType)
1523  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1524 
1525  Location loc = tanOp.getLoc();
1526  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1527  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1528  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1529  return success();
1530  }
1531 };
1532 
1533 /// Convert `spirv.Tanh` to
1534 ///
1535 /// exp(2x) - 1
1536 /// -----------
1537 /// exp(2x) + 1
1538 ///
1539 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1540 public:
1542 
1543  LogicalResult
1544  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1545  ConversionPatternRewriter &rewriter) const override {
1546  auto srcType = tanhOp.getType();
1547  auto dstType = getTypeConverter()->convertType(srcType);
1548  if (!dstType)
1549  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1550 
1551  Location loc = tanhOp.getLoc();
1552  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1553  Value multiplied =
1554  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1555  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1556  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1557  Value numerator =
1558  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1559  Value denominator =
1560  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1561  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1562  denominator);
1563  return success();
1564  }
1565 };
1566 
1567 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1568 public:
1570 
1571  LogicalResult
1572  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1573  ConversionPatternRewriter &rewriter) const override {
1574  auto srcType = varOp.getType();
1575  // Initialization is supported for scalars and vectors only.
1576  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1577  auto init = varOp.getInitializer();
1578  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1579  return failure();
1580 
1581  auto dstType = getTypeConverter()->convertType(srcType);
1582  if (!dstType)
1583  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1584 
1585  Location loc = varOp.getLoc();
1586  Value size = createI32ConstantOf(loc, rewriter, 1);
1587  if (!init) {
1588  auto elementType = getTypeConverter()->convertType(pointerTo);
1589  if (!elementType)
1590  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1591  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1592  size);
1593  return success();
1594  }
1595  auto elementType = getTypeConverter()->convertType(pointerTo);
1596  if (!elementType)
1597  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1598  Value allocated =
1599  rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1600  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1601  rewriter.replaceOp(varOp, allocated);
1602  return success();
1603  }
1604 };
1605 
1606 //===----------------------------------------------------------------------===//
1607 // BitcastOp conversion
1608 //===----------------------------------------------------------------------===//
1609 
1610 class BitcastConversionPattern
1611  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1612 public:
1614 
1615  LogicalResult
1616  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1617  ConversionPatternRewriter &rewriter) const override {
1618  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1619  if (!dstType)
1620  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1621 
1622  // LLVM's opaque pointers do not require bitcasts.
1623  if (isa<LLVM::LLVMPointerType>(dstType)) {
1624  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1625  return success();
1626  }
1627 
1628  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1629  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1630  return success();
1631  }
1632 };
1633 
1634 //===----------------------------------------------------------------------===//
1635 // FuncOp conversion
1636 //===----------------------------------------------------------------------===//
1637 
1638 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1639 public:
1641 
1642  LogicalResult
1643  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1644  ConversionPatternRewriter &rewriter) const override {
1645 
1646  // Convert function signature. At the moment LLVMType converter is enough
1647  // for currently supported types.
1648  auto funcType = funcOp.getFunctionType();
1649  TypeConverter::SignatureConversion signatureConverter(
1650  funcType.getNumInputs());
1651  auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1652  ->convertFunctionSignature(
1653  funcType, /*isVariadic=*/false,
1654  /*useBarePtrCallConv=*/false, signatureConverter);
1655  if (!llvmType)
1656  return failure();
1657 
1658  // Create a new `LLVMFuncOp`
1659  Location loc = funcOp.getLoc();
1660  StringRef name = funcOp.getName();
1661  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1662 
1663  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1664  MLIRContext *context = funcOp.getContext();
1665  switch (funcOp.getFunctionControl()) {
1666  case spirv::FunctionControl::Inline:
1667  newFuncOp.setAlwaysInline(true);
1668  break;
1669  case spirv::FunctionControl::DontInline:
1670  newFuncOp.setNoInline(true);
1671  break;
1672 
1673 #define DISPATCH(functionControl, llvmAttr) \
1674  case functionControl: \
1675  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1676  break;
1677 
1678  DISPATCH(spirv::FunctionControl::Pure,
1679  StringAttr::get(context, "readonly"));
1680  DISPATCH(spirv::FunctionControl::Const,
1681  StringAttr::get(context, "readnone"));
1682 
1683 #undef DISPATCH
1684 
1685  // Default: if `spirv::FunctionControl::None`, then no attributes are
1686  // needed.
1687  default:
1688  break;
1689  }
1690 
1691  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1692  newFuncOp.end());
1693  if (failed(rewriter.convertRegionTypes(
1694  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1695  return failure();
1696  }
1697  rewriter.eraseOp(funcOp);
1698  return success();
1699  }
1700 };
1701 
1702 //===----------------------------------------------------------------------===//
1703 // ModuleOp conversion
1704 //===----------------------------------------------------------------------===//
1705 
1706 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1707 public:
1709 
1710  LogicalResult
1711  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1712  ConversionPatternRewriter &rewriter) const override {
1713 
1714  auto newModuleOp =
1715  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1716  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1717 
1718  // Remove the terminator block that was automatically added by builder
1719  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1720  rewriter.eraseOp(spvModuleOp);
1721  return success();
1722  }
1723 };
1724 
1725 //===----------------------------------------------------------------------===//
1726 // VectorShuffleOp conversion
1727 //===----------------------------------------------------------------------===//
1728 
1729 class VectorShufflePattern
1730  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1731 public:
1733  LogicalResult
1734  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1735  ConversionPatternRewriter &rewriter) const override {
1736  Location loc = op.getLoc();
1737  auto components = adaptor.getComponents();
1738  auto vector1 = adaptor.getVector1();
1739  auto vector2 = adaptor.getVector2();
1740  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1741  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1742  if (vector1Size == vector2Size) {
1743  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1744  op, vector1, vector2,
1745  LLVM::convertArrayToIndices<int32_t>(components));
1746  return success();
1747  }
1748 
1749  auto dstType = getTypeConverter()->convertType(op.getType());
1750  if (!dstType)
1751  return rewriter.notifyMatchFailure(op, "type conversion failed");
1752  auto scalarType = cast<VectorType>(dstType).getElementType();
1753  auto componentsArray = components.getValue();
1754  auto *context = rewriter.getContext();
1755  auto llvmI32Type = IntegerType::get(context, 32);
1756  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1757  for (unsigned i = 0; i < componentsArray.size(); i++) {
1758  if (!isa<IntegerAttr>(componentsArray[i]))
1759  return op.emitError("unable to support non-constant component");
1760 
1761  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1762  if (indexVal == -1)
1763  continue;
1764 
1765  int offsetVal = 0;
1766  Value baseVector = vector1;
1767  if (indexVal >= vector1Size) {
1768  offsetVal = vector1Size;
1769  baseVector = vector2;
1770  }
1771 
1772  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1773  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1774  Value index = rewriter.create<LLVM::ConstantOp>(
1775  loc, llvmI32Type,
1776  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1777 
1778  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1779  loc, scalarType, baseVector, index);
1780  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1781  extractOp, dstIndex);
1782  }
1783  rewriter.replaceOp(op, targetOp);
1784  return success();
1785  }
1786 };
1787 } // namespace
1788 
1789 //===----------------------------------------------------------------------===//
1790 // Pattern population
1791 //===----------------------------------------------------------------------===//
1792 
1794  spirv::ClientAPI clientAPI) {
1795  typeConverter.addConversion([&](spirv::ArrayType type) {
1796  return convertArrayType(type, typeConverter);
1797  });
1798  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1799  return convertPointerType(type, typeConverter, clientAPI);
1800  });
1801  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1802  return convertRuntimeArrayType(type, typeConverter);
1803  });
1804  typeConverter.addConversion([&](spirv::StructType type) {
1805  return convertStructType(type, typeConverter);
1806  });
1807 }
1808 
1810  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1811  spirv::ClientAPI clientAPI) {
1812  patterns.add<
1813  // Arithmetic ops
1814  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1815  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1816  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1817  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1818  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1819  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1820  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1821  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1822  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1823  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1824  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1825  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1826  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1827 
1828  // Bitwise ops
1829  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1830  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1831  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1832  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1833  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1834  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1835  NotPattern<spirv::NotOp>,
1836 
1837  // Cast ops
1838  BitcastConversionPattern,
1839  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1840  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1841  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1842  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1843  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1844  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1845  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1846 
1847  // Comparison ops
1848  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1849  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1850  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1851  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1852  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1853  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1854  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1855  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1856  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1857  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1858  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1859  LLVM::FCmpPredicate::uge>,
1860  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1861  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1862  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1863  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1864  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1865  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1866  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1867  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1868  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1869  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1870  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1871 
1872  // Constant op
1873  ConstantScalarAndVectorPattern,
1874 
1875  // Control Flow ops
1876  BranchConversionPattern, BranchConditionalConversionPattern,
1877  FunctionCallPattern, LoopPattern, SelectionPattern,
1878  ErasePattern<spirv::MergeOp>,
1879 
1880  // Entry points and execution mode are handled separately.
1881  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1882 
1883  // GLSL extended instruction set ops
1884  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1885  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1886  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1887  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1888  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1889  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1890  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1891  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1892  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1893  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1894  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1895  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1896  InverseSqrtPattern, TanPattern, TanhPattern,
1897 
1898  // Logical ops
1899  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1900  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1901  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1902  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1903  NotPattern<spirv::LogicalNotOp>,
1904 
1905  // Memory ops
1906  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1907  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1908 
1909  // Miscellaneous ops
1910  CompositeExtractPattern, CompositeInsertPattern,
1911  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1912  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1913  VectorShufflePattern,
1914 
1915  // Shift ops
1916  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1917  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1918  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1919 
1920  // Return ops
1921  ReturnPattern, ReturnValuePattern,
1922 
1923  // Barrier ops
1924  ControlBarrierPattern<spirv::ControlBarrierOp>,
1925  ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1926  ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1927 
1928  // Group reduction operations
1929  GroupReducePattern<spirv::GroupIAddOp>,
1930  GroupReducePattern<spirv::GroupFAddOp>,
1931  GroupReducePattern<spirv::GroupFMinOp>,
1932  GroupReducePattern<spirv::GroupUMinOp>,
1933  GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1934  GroupReducePattern<spirv::GroupFMaxOp>,
1935  GroupReducePattern<spirv::GroupUMaxOp>,
1936  GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1937  GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1938  /*NonUniform=*/true>,
1939  GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1940  /*NonUniform=*/true>,
1941  GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1942  /*NonUniform=*/true>,
1943  GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1944  /*NonUniform=*/true>,
1945  GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1946  /*NonUniform=*/true>,
1947  GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1948  /*NonUniform=*/true>,
1949  GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1950  /*NonUniform=*/true>,
1951  GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1952  /*NonUniform=*/true>,
1953  GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1954  /*NonUniform=*/true>,
1955  GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1956  /*NonUniform=*/true>,
1957  GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1958  /*NonUniform=*/true>,
1959  GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1960  /*NonUniform=*/true>,
1961  GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1962  /*NonUniform=*/true>,
1963  GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1964  /*NonUniform=*/true>,
1965  GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1966  /*NonUniform=*/true>,
1967  GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1968  /*NonUniform=*/true>>(patterns.getContext(),
1969  typeConverter);
1970 
1971  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1972  typeConverter);
1973 }
1974 
1976  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1977  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1978 }
1979 
1981  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1982  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1983 }
1984 
1985 //===----------------------------------------------------------------------===//
1986 // Pre-conversion hooks
1987 //===----------------------------------------------------------------------===//
1988 
1989 /// Hook for descriptor set and binding number encoding.
1990 static constexpr StringRef kBinding = "binding";
1991 static constexpr StringRef kDescriptorSet = "descriptor_set";
1992 void mlir::encodeBindAttribute(ModuleOp module) {
1993  auto spvModules = module.getOps<spirv::ModuleOp>();
1994  for (auto spvModule : spvModules) {
1995  spvModule.walk([&](spirv::GlobalVariableOp op) {
1996  IntegerAttr descriptorSet =
1997  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1998  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1999  // For every global variable in the module, get the ones with descriptor
2000  // set and binding numbers.
2001  if (descriptorSet && binding) {
2002  // Encode these numbers into the variable's symbolic name. If the
2003  // SPIR-V module has a name, add it at the beginning.
2004  auto moduleAndName =
2005  spvModule.getName().has_value()
2006  ? spvModule.getName()->str() + "_" + op.getSymName().str()
2007  : op.getSymName().str();
2008  std::string name =
2009  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2010  std::to_string(descriptorSet.getInt()),
2011  std::to_string(binding.getInt()));
2012  auto nameAttr = StringAttr::get(op->getContext(), name);
2013 
2014  // Replace all symbol uses and set the new symbol name. Finally, remove
2015  // descriptor set and binding attributes.
2016  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2017  op.emitError("unable to replace all symbol uses for ") << name;
2018  SymbolTable::setSymbolName(op, nameAttr);
2019  op->removeAttr(kDescriptorSet);
2020  op->removeAttr(kBinding);
2021  }
2022  });
2023  }
2024 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:67
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:38
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:47
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:88
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:98
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:80
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:57
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:99
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:452
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:456
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:450
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:99
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:412
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:473
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:876
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:229
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:892
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.