MLIR  21.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 #define DEBUG_TYPE "spirv-to-llvm-pattern"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Utility functions
35 //===----------------------------------------------------------------------===//
36 
37 /// Returns true if the given type is a signed integer or vector type.
38 static bool isSignedIntegerOrVector(Type type) {
39  if (type.isSignedInteger())
40  return true;
41  if (auto vecType = dyn_cast<VectorType>(type))
42  return vecType.getElementType().isSignedInteger();
43  return false;
44 }
45 
46 /// Returns true if the given type is an unsigned integer or vector type
47 static bool isUnsignedIntegerOrVector(Type type) {
48  if (type.isUnsignedInteger())
49  return true;
50  if (auto vecType = dyn_cast<VectorType>(type))
51  return vecType.getElementType().isUnsignedInteger();
52  return false;
53 }
54 
55 /// Returns the width of an integer or of the element type of an integer vector,
56 /// if applicable.
57 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
58  if (auto intType = dyn_cast<IntegerType>(type))
59  return intType.getWidth();
60  if (auto vecType = dyn_cast<VectorType>(type))
61  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
62  return intType.getWidth();
63  return std::nullopt;
64 }
65 
66 /// Returns the bit width of integer, float or vector of float or integer values
67 static unsigned getBitWidth(Type type) {
68  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
69  "bitwidth is not supported for this type");
70  if (type.isIntOrFloat())
71  return type.getIntOrFloatBitWidth();
72  auto vecType = dyn_cast<VectorType>(type);
73  auto elementType = vecType.getElementType();
74  assert(elementType.isIntOrFloat() &&
75  "only integers and floats have a bitwidth");
76  return elementType.getIntOrFloatBitWidth();
77 }
78 
79 /// Returns the bit width of LLVMType integer or vector.
80 static unsigned getLLVMTypeBitWidth(Type type) {
81  if (auto vecTy = dyn_cast<VectorType>(type))
82  type = vecTy.getElementType();
83  return cast<IntegerType>(type).getWidth();
84 }
85 
86 /// Creates `IntegerAttribute` with all bits set for given type
87 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
88  if (auto vecType = dyn_cast<VectorType>(type)) {
89  auto integerType = cast<IntegerType>(vecType.getElementType());
90  return builder.getIntegerAttr(integerType, -1);
91  }
92  auto integerType = cast<IntegerType>(type);
93  return builder.getIntegerAttr(integerType, -1);
94 }
95 
96 /// Creates `llvm.mlir.constant` with all bits set for the given type.
97 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
98  PatternRewriter &rewriter) {
99  if (isa<VectorType>(srcType)) {
100  return rewriter.create<LLVM::ConstantOp>(
101  loc, dstType,
102  SplatElementsAttr::get(cast<ShapedType>(srcType),
103  minusOneIntegerAttribute(srcType, rewriter)));
104  }
105  return rewriter.create<LLVM::ConstantOp>(
106  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
107 }
108 
109 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
110 static Value createFPConstant(Location loc, Type srcType, Type dstType,
111  PatternRewriter &rewriter, double value) {
112  if (auto vecType = dyn_cast<VectorType>(srcType)) {
113  auto floatType = cast<FloatType>(vecType.getElementType());
114  return rewriter.create<LLVM::ConstantOp>(
115  loc, dstType,
116  SplatElementsAttr::get(vecType,
117  rewriter.getFloatAttr(floatType, value)));
118  }
119  auto floatType = cast<FloatType>(srcType);
120  return rewriter.create<LLVM::ConstantOp>(
121  loc, dstType, rewriter.getFloatAttr(floatType, value));
122 }
123 
124 /// Utility function for bitfield ops:
125 /// - `BitFieldInsert`
126 /// - `BitFieldSExtract`
127 /// - `BitFieldUExtract`
128 /// Truncates or extends the value. If the bitwidth of the value is the same as
129 /// `llvmType` bitwidth, the value remains unchanged.
131  Type llvmType,
132  PatternRewriter &rewriter) {
133  auto srcType = value.getType();
134  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
135  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
136  ? getLLVMTypeBitWidth(srcType)
137  : getBitWidth(srcType);
138 
139  if (valueBitWidth < targetBitWidth)
140  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
141  // If the bit widths of `Count` and `Offset` are greater than the bit width
142  // of the target type, they are truncated. Truncation is safe since `Count`
143  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
144  // both values can be expressed in 8 bits.
145  if (valueBitWidth > targetBitWidth)
146  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
147  return value;
148 }
149 
150 /// Broadcasts the value to vector with `numElements` number of elements.
151 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
152  const TypeConverter &typeConverter,
153  ConversionPatternRewriter &rewriter) {
154  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
155  auto llvmVectorType = typeConverter.convertType(vectorType);
156  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
157  Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
158  for (unsigned i = 0; i < numElements; ++i) {
159  auto index = rewriter.create<LLVM::ConstantOp>(
160  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
161  broadcasted = rewriter.create<LLVM::InsertElementOp>(
162  loc, llvmVectorType, broadcasted, toBroadcast, index);
163  }
164  return broadcasted;
165 }
166 
167 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
168 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
169  const TypeConverter &typeConverter,
170  ConversionPatternRewriter &rewriter) {
171  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
172  unsigned numElements = vectorType.getNumElements();
173  return broadcast(loc, value, numElements, typeConverter, rewriter);
174  }
175  return value;
176 }
177 
178 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
179 /// `BitFieldUExtract`.
180 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
181 /// a vector type, construct a vector that has:
182 /// - same number of elements as `Base`
183 /// - each element has the type that is the same as the type of `Offset` or
184 /// `Count`
185 /// - each element has the same value as `Offset` or `Count`
186 /// Then cast `Offset` and `Count` if their bit width is different
187 /// from `Base` bit width.
188 static Value processCountOrOffset(Location loc, Value value, Type srcType,
189  Type dstType, const TypeConverter &converter,
190  ConversionPatternRewriter &rewriter) {
191  Value broadcasted =
192  optionallyBroadcast(loc, value, srcType, converter, rewriter);
193  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
194 }
195 
196 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
197 /// offset to LLVM struct. Otherwise, the conversion is not supported.
199  const TypeConverter &converter) {
200  if (type != VulkanLayoutUtils::decorateType(type))
201  return nullptr;
202 
203  SmallVector<Type> elementsVector;
204  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
205  return nullptr;
206  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
207  /*isPacked=*/false);
208 }
209 
210 /// Converts SPIR-V struct with no offset to packed LLVM struct.
212  const TypeConverter &converter) {
213  SmallVector<Type> elementsVector;
214  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
215  return nullptr;
216  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
217  /*isPacked=*/true);
218 }
219 
220 /// Creates LLVM dialect constant with the given value.
222  unsigned value) {
223  return rewriter.create<LLVM::ConstantOp>(
224  loc, IntegerType::get(rewriter.getContext(), 32),
225  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
226 }
227 
228 /// Utility for `spirv.Load` and `spirv.Store` conversion.
229 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
230  ConversionPatternRewriter &rewriter,
231  const TypeConverter &typeConverter,
232  unsigned alignment, bool isVolatile,
233  bool isNonTemporal) {
234  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
235  auto dstType = typeConverter.convertType(loadOp.getType());
236  if (!dstType)
237  return rewriter.notifyMatchFailure(op, "type conversion failed");
238  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
239  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
240  isVolatile, isNonTemporal);
241  return success();
242  }
243  auto storeOp = cast<spirv::StoreOp>(op);
244  spirv::StoreOpAdaptor adaptor(operands);
245  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
246  adaptor.getPtr(), alignment,
247  isVolatile, isNonTemporal);
248  return success();
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Type conversion
253 //===----------------------------------------------------------------------===//
254 
255 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
256 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
257 /// when converting ops that manipulate array types.
258 static std::optional<Type> convertArrayType(spirv::ArrayType type,
259  TypeConverter &converter) {
260  unsigned stride = type.getArrayStride();
261  Type elementType = type.getElementType();
262  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
263  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
264  return std::nullopt;
265 
266  auto llvmElementType = converter.convertType(elementType);
267  unsigned numElements = type.getNumElements();
268  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
269 }
270 
271 /// Converts SPIR-V pointer type to an LLVM pointer type. Only the address
272 /// space is carried over: it is derived from the pointer's storage class via
/// `storageClassToAddressSpace` for the given `clientAPI`. The pointee type is
/// not part of the resulting pointer type, and `converter` is not used.
274  const TypeConverter &converter,
275  spirv::ClientAPI clientAPI) {
276  unsigned addressSpace =
277  storageClassToAddressSpace(clientAPI, type.getStorageClass());
278  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
279 }
280 
281 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
282 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
283 /// no modelling of array stride at the moment.
284 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
285  TypeConverter &converter) {
286  if (type.getArrayStride() != 0)
287  return std::nullopt;
288  auto elementType = converter.convertType(type.getElementType());
289  return LLVM::LLVMArrayType::get(elementType, 0);
290 }
291 
292 /// Converts SPIR-V struct to LLVM struct. Structs with member decorations are
293 /// not supported (`nullptr` is returned for them). Structs with explicit
/// member offsets are delegated to `convertStructTypeWithOffset`, which only
/// accepts the offsets `VulkanLayoutUtils` would assign. Structs without
/// offsets become packed LLVM structs via `convertStructTypePacked`.
295  const TypeConverter &converter) {
297  type.getMemberDecorations(memberDecorations);
298  if (!memberDecorations.empty())
299  return nullptr; // Member decorations are not supported.
300  if (type.hasOffset())
301  return convertStructTypeWithOffset(type, converter);
302  return convertStructTypePacked(type, converter);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // Operation conversion
307 //===----------------------------------------------------------------------===//
308 
309 namespace {
310 
311 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
312 public:
314 
315  LogicalResult
316  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
317  ConversionPatternRewriter &rewriter) const override {
318  auto dstType =
319  getTypeConverter()->convertType(op.getComponentPtr().getType());
320  if (!dstType)
321  return rewriter.notifyMatchFailure(op, "type conversion failed");
322  // To use GEP we need to add a first 0 index to go through the pointer.
323  auto indices = llvm::to_vector<4>(adaptor.getIndices());
324  Type indexType = op.getIndices().front().getType();
325  auto llvmIndexType = getTypeConverter()->convertType(indexType);
326  if (!llvmIndexType)
327  return rewriter.notifyMatchFailure(op, "type conversion failed");
328  Value zero = rewriter.create<LLVM::ConstantOp>(
329  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
330  indices.insert(indices.begin(), zero);
331 
332  auto elementType = getTypeConverter()->convertType(
333  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
334  if (!elementType)
335  return rewriter.notifyMatchFailure(op, "type conversion failed");
336  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
337  adaptor.getBasePtr(), indices);
338  return success();
339  }
340 };
341 
342 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
343 public:
345 
346  LogicalResult
347  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
348  ConversionPatternRewriter &rewriter) const override {
349  auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
350  if (!dstType)
351  return rewriter.notifyMatchFailure(op, "type conversion failed");
352  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
353  op.getVariable());
354  return success();
355  }
356 };
357 
358 class BitFieldInsertPattern
359  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
360 public:
362 
363  LogicalResult
364  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
365  ConversionPatternRewriter &rewriter) const override {
366  auto srcType = op.getType();
367  auto dstType = getTypeConverter()->convertType(srcType);
368  if (!dstType)
369  return rewriter.notifyMatchFailure(op, "type conversion failed");
370  Location loc = op.getLoc();
371 
372  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
373  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
374  *getTypeConverter(), rewriter);
375  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
376  *getTypeConverter(), rewriter);
377 
378  // Create a mask with bits set outside [Offset, Offset + Count - 1].
379  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
380  Value maskShiftedByCount =
381  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
382  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
383  maskShiftedByCount, minusOne);
384  Value maskShiftedByCountAndOffset =
385  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
386  Value mask = rewriter.create<LLVM::XOrOp>(
387  loc, dstType, maskShiftedByCountAndOffset, minusOne);
388 
389  // Extract unchanged bits from the `Base` that are outside of
390  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
391  Value baseAndMask =
392  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
393  Value insertShiftedByOffset =
394  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
395  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
396  insertShiftedByOffset);
397  return success();
398  }
399 };
400 
401 /// Converts SPIR-V ConstantOp with scalar or vector type.
402 class ConstantScalarAndVectorPattern
403  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
404 public:
406 
407  LogicalResult
408  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
409  ConversionPatternRewriter &rewriter) const override {
410  auto srcType = constOp.getType();
411  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
412  return failure();
413 
414  auto dstType = getTypeConverter()->convertType(srcType);
415  if (!dstType)
416  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
417 
418  // SPIR-V constant can be a signed/unsigned integer, which has to be
419  // casted to signless integer when converting to LLVM dialect. Removing the
420  // sign bit may have unexpected behaviour. However, it is better to handle
421  // it case-by-case, given that the purpose of the conversion is not to
422  // cover all possible corner cases.
423  if (isSignedIntegerOrVector(srcType) ||
424  isUnsignedIntegerOrVector(srcType)) {
425  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
426 
427  if (isa<VectorType>(srcType)) {
428  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
429  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
430  constOp, dstType,
431  dstElementsAttr.mapValues(
432  signlessType, [&](const APInt &value) { return value; }));
433  return success();
434  }
435  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
436  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
437  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
438  return success();
439  }
440  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
441  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
442  return success();
443  }
444 };
445 
446 class BitFieldSExtractPattern
447  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
448 public:
450 
451  LogicalResult
452  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
453  ConversionPatternRewriter &rewriter) const override {
454  auto srcType = op.getType();
455  auto dstType = getTypeConverter()->convertType(srcType);
456  if (!dstType)
457  return rewriter.notifyMatchFailure(op, "type conversion failed");
458  Location loc = op.getLoc();
459 
460  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
461  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
462  *getTypeConverter(), rewriter);
463  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
464  *getTypeConverter(), rewriter);
465 
466  // Create a constant that holds the size of the `Base`.
467  IntegerType integerType;
468  if (auto vecType = dyn_cast<VectorType>(srcType))
469  integerType = cast<IntegerType>(vecType.getElementType());
470  else
471  integerType = cast<IntegerType>(srcType);
472 
473  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
474  Value size =
475  isa<VectorType>(srcType)
476  ? rewriter.create<LLVM::ConstantOp>(
477  loc, dstType,
478  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
479  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
480 
481  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
482  // at Offset + Count - 1 is the most significant bit now.
483  Value countPlusOffset =
484  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
485  Value amountToShiftLeft =
486  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
487  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
488  loc, dstType, op.getBase(), amountToShiftLeft);
489 
490  // Shift the result right, filling the bits with the sign bit.
491  Value amountToShiftRight =
492  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
493  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
494  amountToShiftRight);
495  return success();
496  }
497 };
498 
499 class BitFieldUExtractPattern
500  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
501 public:
503 
504  LogicalResult
505  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
506  ConversionPatternRewriter &rewriter) const override {
507  auto srcType = op.getType();
508  auto dstType = getTypeConverter()->convertType(srcType);
509  if (!dstType)
510  return rewriter.notifyMatchFailure(op, "type conversion failed");
511  Location loc = op.getLoc();
512 
513  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
514  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
515  *getTypeConverter(), rewriter);
516  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
517  *getTypeConverter(), rewriter);
518 
519  // Create a mask with bits set at [0, Count - 1].
520  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
521  Value maskShiftedByCount =
522  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
523  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
524  minusOne);
525 
526  // Shift `Base` by `Offset` and apply the mask on it.
527  Value shiftedBase =
528  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
529  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
530  return success();
531  }
532 };
533 
534 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
535 public:
537 
538  LogicalResult
539  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
540  ConversionPatternRewriter &rewriter) const override {
541  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
542  branchOp.getTarget());
543  return success();
544  }
545 };
546 
547 class BranchConditionalConversionPattern
548  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
549 public:
550  using SPIRVToLLVMConversion<
551  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
552 
553  LogicalResult
554  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
555  ConversionPatternRewriter &rewriter) const override {
556  // If branch weights exist, map them to 32-bit integer vector.
557  DenseI32ArrayAttr branchWeights = nullptr;
558  if (auto weights = op.getBranchWeights()) {
559  SmallVector<int32_t> weightValues;
560  for (auto weight : weights->getAsRange<IntegerAttr>())
561  weightValues.push_back(weight.getInt());
562  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
563  }
564 
565  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
566  op, op.getCondition(), op.getTrueBlockArguments(),
567  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
568  op.getFalseBlock());
569  return success();
570  }
571 };
572 
573 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
574 /// type is an aggregate type (struct or array). Otherwise, converts to
575 /// `llvm.extractelement` that operates on vectors.
576 class CompositeExtractPattern
577  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
578 public:
580 
581  LogicalResult
582  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
583  ConversionPatternRewriter &rewriter) const override {
584  auto dstType = this->getTypeConverter()->convertType(op.getType());
585  if (!dstType)
586  return rewriter.notifyMatchFailure(op, "type conversion failed");
587 
588  Type containerType = op.getComposite().getType();
589  if (isa<VectorType>(containerType)) {
590  Location loc = op.getLoc();
591  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
592  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
593  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
594  op, dstType, adaptor.getComposite(), index);
595  return success();
596  }
597 
598  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
599  op, adaptor.getComposite(),
600  LLVM::convertArrayToIndices(op.getIndices()));
601  return success();
602  }
603 };
604 
605 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
606 /// type is an aggregate type (struct or array). Otherwise, converts to
607 /// `llvm.insertelement` that operates on vectors.
608 class CompositeInsertPattern
609  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
610 public:
612 
613  LogicalResult
614  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
615  ConversionPatternRewriter &rewriter) const override {
616  auto dstType = this->getTypeConverter()->convertType(op.getType());
617  if (!dstType)
618  return rewriter.notifyMatchFailure(op, "type conversion failed");
619 
620  Type containerType = op.getComposite().getType();
621  if (isa<VectorType>(containerType)) {
622  Location loc = op.getLoc();
623  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
624  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
625  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
626  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
627  return success();
628  }
629 
630  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
631  op, adaptor.getComposite(), adaptor.getObject(),
632  LLVM::convertArrayToIndices(op.getIndices()));
633  return success();
634  }
635 };
636 
637 /// Converts SPIR-V operations that have straightforward LLVM equivalent
638 /// into LLVM dialect operations.
639 template <typename SPIRVOp, typename LLVMOp>
640 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
641 public:
643 
644  LogicalResult
645  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
646  ConversionPatternRewriter &rewriter) const override {
647  auto dstType = this->getTypeConverter()->convertType(op.getType());
648  if (!dstType)
649  return rewriter.notifyMatchFailure(op, "type conversion failed");
650  rewriter.template replaceOpWithNewOp<LLVMOp>(
651  op, dstType, adaptor.getOperands(), op->getAttrs());
652  return success();
653  }
654 };
655 
656 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
657 /// execution mode information.
658 class ExecutionModePattern
659  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
660 public:
662 
663  LogicalResult
664  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
665  ConversionPatternRewriter &rewriter) const override {
666  // First, create the global struct's name that would be associated with
667  // this entry point's execution mode. We set it to be:
668  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
669  ModuleOp module = op->getParentOfType<ModuleOp>();
670  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
671  std::string moduleName;
672  if (module.getName().has_value())
673  moduleName = "_" + module.getName()->str();
674  else
675  moduleName = "";
676  std::string executionModeInfoName = llvm::formatv(
677  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
678  static_cast<uint32_t>(executionModeAttr.getValue()));
679 
680  MLIRContext *context = rewriter.getContext();
681  OpBuilder::InsertionGuard guard(rewriter);
682  rewriter.setInsertionPointToStart(module.getBody());
683 
684  // Create a struct type, corresponding to the C struct below.
685  // struct {
686  // int32_t executionMode;
687  // int32_t values[]; // optional values
688  // };
689  auto llvmI32Type = IntegerType::get(context, 32);
690  SmallVector<Type, 2> fields;
691  fields.push_back(llvmI32Type);
692  ArrayAttr values = op.getValues();
693  if (!values.empty()) {
694  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
695  fields.push_back(arrayType);
696  }
697  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698 
699  // Create `llvm.mlir.global` with initializer region containing one block.
700  auto global = rewriter.create<LLVM::GlobalOp>(
701  UnknownLoc::get(context), structType, /*isConstant=*/true,
702  LLVM::Linkage::External, executionModeInfoName, Attribute(),
703  /*alignment=*/0);
704  Location loc = global.getLoc();
705  Region &region = global.getInitializerRegion();
706  Block *block = rewriter.createBlock(&region);
707 
708  // Initialize the struct and set the execution mode value.
709  rewriter.setInsertionPointToStart(block);
710  Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType);
711  Value executionMode = rewriter.create<LLVM::ConstantOp>(
712  loc, llvmI32Type,
713  rewriter.getI32IntegerAttr(
714  static_cast<uint32_t>(executionModeAttr.getValue())));
715  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
716  executionMode, 0);
717 
718  // Insert extra operands if they exist into execution mode info struct.
719  for (unsigned i = 0, e = values.size(); i < e; ++i) {
720  auto attr = values.getValue()[i];
721  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
722  structValue = rewriter.create<LLVM::InsertValueOp>(
723  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
724  }
725  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
726  rewriter.eraseOp(op);
727  return success();
728  }
729 };
730 
731 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
732 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
733 /// value. This difference is handled by `spirv.mlir.addressof` and
734 /// `llvm.mlir.addressof`ops that both return a pointer.
735 class GlobalVariablePattern
736  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
737 public:
738  template <typename... Args>
739  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
740  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
741  std::forward<Args>(args)...),
742  clientAPI(clientAPI) {}
743 
744  LogicalResult
745  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
746  ConversionPatternRewriter &rewriter) const override {
747  // Currently, there is no support of initialization with a constant value in
748  // SPIR-V dialect. Specialization constants are not considered as well.
749  if (op.getInitializer())
750  return failure();
751 
752  auto srcType = cast<spirv::PointerType>(op.getType());
753  auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
754  if (!dstType)
755  return rewriter.notifyMatchFailure(op, "type conversion failed");
756 
757  // Limit conversion to the current invocation only or `StorageBuffer`
758  // required by SPIR-V runner.
759  // This is okay because multiple invocations are not supported yet.
760  auto storageClass = srcType.getStorageClass();
761  switch (storageClass) {
762  case spirv::StorageClass::Input:
763  case spirv::StorageClass::Private:
764  case spirv::StorageClass::Output:
765  case spirv::StorageClass::StorageBuffer:
766  case spirv::StorageClass::UniformConstant:
767  break;
768  default:
769  return failure();
770  }
771 
772  // LLVM dialect spec: "If the global value is a constant, storing into it is
773  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
774  // storage class that is read-only.
775  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
776  (storageClass == spirv::StorageClass::UniformConstant);
777  // SPIR-V spec: "By default, functions and global variables are private to a
778  // module and cannot be accessed by other modules. However, a module may be
779  // written to export or import functions and global (module scope)
780  // variables.". Therefore, map 'Private' storage class to private linkage,
781  // 'Input' and 'Output' to external linkage.
782  auto linkage = storageClass == spirv::StorageClass::Private
783  ? LLVM::Linkage::Private
784  : LLVM::Linkage::External;
785  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
786  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
787  /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
788 
789  // Attach location attribute if applicable
790  if (op.getLocationAttr())
791  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
792 
793  return success();
794  }
795 
796 private:
797  spirv::ClientAPI clientAPI;
798 };
799 
800 /// Converts SPIR-V cast ops that do not have straightforward LLVM
801 /// equivalent in LLVM dialect.
802 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
803 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
804 public:
806 
807  LogicalResult
808  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
809  ConversionPatternRewriter &rewriter) const override {
810 
811  Type fromType = op.getOperand().getType();
812  Type toType = op.getType();
813 
814  auto dstType = this->getTypeConverter()->convertType(toType);
815  if (!dstType)
816  return rewriter.notifyMatchFailure(op, "type conversion failed");
817 
818  if (getBitWidth(fromType) < getBitWidth(toType)) {
819  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
820  adaptor.getOperands());
821  return success();
822  }
823  if (getBitWidth(fromType) > getBitWidth(toType)) {
824  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
825  adaptor.getOperands());
826  return success();
827  }
828  return failure();
829  }
830 };
831 
832 class FunctionCallPattern
833  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
834 public:
836 
837  LogicalResult
838  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
839  ConversionPatternRewriter &rewriter) const override {
840  if (callOp.getNumResults() == 0) {
841  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
842  callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
843  newOp.getProperties().operandSegmentSizes = {
844  static_cast<int32_t>(adaptor.getOperands().size()), 0};
845  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
846  return success();
847  }
848 
849  // Function returns a single result.
850  auto dstType = getTypeConverter()->convertType(callOp.getType(0));
851  if (!dstType)
852  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
853  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
854  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
855  newOp.getProperties().operandSegmentSizes = {
856  static_cast<int32_t>(adaptor.getOperands().size()), 0};
857  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
858  return success();
859  }
860 };
861 
862 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
863 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
864 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
865 public:
867 
868  LogicalResult
869  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
870  ConversionPatternRewriter &rewriter) const override {
871 
872  auto dstType = this->getTypeConverter()->convertType(op.getType());
873  if (!dstType)
874  return rewriter.notifyMatchFailure(op, "type conversion failed");
875 
876  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
877  op, dstType, predicate, op.getOperand1(), op.getOperand2());
878  return success();
879  }
880 };
881 
882 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
883 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
884 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
885 public:
887 
888  LogicalResult
889  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
890  ConversionPatternRewriter &rewriter) const override {
891 
892  auto dstType = this->getTypeConverter()->convertType(op.getType());
893  if (!dstType)
894  return rewriter.notifyMatchFailure(op, "type conversion failed");
895 
896  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
897  op, dstType, predicate, op.getOperand1(), op.getOperand2());
898  return success();
899  }
900 };
901 
902 class InverseSqrtPattern
903  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
904 public:
906 
907  LogicalResult
908  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
909  ConversionPatternRewriter &rewriter) const override {
910  auto srcType = op.getType();
911  auto dstType = getTypeConverter()->convertType(srcType);
912  if (!dstType)
913  return rewriter.notifyMatchFailure(op, "type conversion failed");
914 
915  Location loc = op.getLoc();
916  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
917  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
918  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
919  return success();
920  }
921 };
922 
923 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
924 template <typename SPIRVOp>
925 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
926 public:
928 
929  LogicalResult
930  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
931  ConversionPatternRewriter &rewriter) const override {
932  if (!op.getMemoryAccess()) {
933  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
934  *this->getTypeConverter(), /*alignment=*/0,
935  /*isVolatile=*/false,
936  /*isNonTemporal=*/false);
937  }
938  auto memoryAccess = *op.getMemoryAccess();
939  switch (memoryAccess) {
940  case spirv::MemoryAccess::Aligned:
942  case spirv::MemoryAccess::Nontemporal:
943  case spirv::MemoryAccess::Volatile: {
944  unsigned alignment =
945  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
946  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
947  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
948  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
949  *this->getTypeConverter(), alignment,
950  isVolatile, isNonTemporal);
951  }
952  default:
953  // There is no support of other memory access attributes.
954  return failure();
955  }
956  }
957 };
958 
959 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
960 template <typename SPIRVOp>
961 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
962 public:
964 
965  LogicalResult
966  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
967  ConversionPatternRewriter &rewriter) const override {
968  auto srcType = notOp.getType();
969  auto dstType = this->getTypeConverter()->convertType(srcType);
970  if (!dstType)
971  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
972 
973  Location loc = notOp.getLoc();
974  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
975  auto mask =
976  isa<VectorType>(srcType)
977  ? rewriter.create<LLVM::ConstantOp>(
978  loc, dstType,
979  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
980  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
981  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
982  notOp.getOperand(), mask);
983  return success();
984  }
985 };
986 
987 /// A template pattern that erases the given `SPIRVOp`.
988 template <typename SPIRVOp>
989 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
990 public:
992 
993  LogicalResult
994  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
995  ConversionPatternRewriter &rewriter) const override {
996  rewriter.eraseOp(op);
997  return success();
998  }
999 };
1000 
1001 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1002 public:
1004 
1005  LogicalResult
1006  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1007  ConversionPatternRewriter &rewriter) const override {
1008  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1009  ArrayRef<Value>());
1010  return success();
1011  }
1012 };
1013 
1014 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1015 public:
1017 
1018  LogicalResult
1019  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1020  ConversionPatternRewriter &rewriter) const override {
1021  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1022  adaptor.getOperands());
1023  return success();
1024  }
1025 };
1026 
1027 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1028  StringRef name,
1029  ArrayRef<Type> paramTypes,
1030  Type resultType,
1031  bool convergent = true) {
1032  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1033  SymbolTable::lookupSymbolIn(symbolTable, name));
1034  if (func)
1035  return func;
1036 
1037  OpBuilder b(symbolTable->getRegion(0));
1038  func = b.create<LLVM::LLVMFuncOp>(
1039  symbolTable->getLoc(), name,
1040  LLVM::LLVMFunctionType::get(resultType, paramTypes));
1041  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1042  func.setConvergent(convergent);
1043  func.setNoUnwind(true);
1044  func.setWillReturn(true);
1045  return func;
1046 }
1047 
1048 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1049  LLVM::LLVMFuncOp func,
1050  ValueRange args) {
1051  auto call = builder.create<LLVM::CallOp>(loc, func, args);
1052  call.setCConv(func.getCConv());
1053  call.setConvergentAttr(func.getConvergentAttr());
1054  call.setNoUnwindAttr(func.getNoUnwindAttr());
1055  call.setWillReturnAttr(func.getWillReturnAttr());
1056  return call;
1057 }
1058 
1059 template <typename BarrierOpTy>
1060 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1061 public:
1062  using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1063 
1065 
1066  static constexpr StringRef getFuncName();
1067 
1068  LogicalResult
1069  matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1070  ConversionPatternRewriter &rewriter) const override {
1071  constexpr StringRef funcName = getFuncName();
1072  Operation *symbolTable =
1073  controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1074 
1075  Type i32 = rewriter.getI32Type();
1076 
1077  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1078  LLVM::LLVMFuncOp func =
1079  lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1080 
1081  Location loc = controlBarrierOp->getLoc();
1082  Value execution = rewriter.create<LLVM::ConstantOp>(
1083  loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1084  Value memory = rewriter.create<LLVM::ConstantOp>(
1085  loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1086  Value semantics = rewriter.create<LLVM::ConstantOp>(
1087  loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1088 
1089  auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1090  {execution, memory, semantics});
1091 
1092  rewriter.replaceOp(controlBarrierOp, call);
1093  return success();
1094  }
1095 };
1096 
1097 namespace {
1098 
1099 StringRef getTypeMangling(Type type, bool isSigned) {
1101  .Case<Float16Type>([](auto) { return "Dh"; })
1102  .Case<Float32Type>([](auto) { return "f"; })
1103  .Case<Float64Type>([](auto) { return "d"; })
1104  .Case<IntegerType>([isSigned](IntegerType intTy) {
1105  switch (intTy.getWidth()) {
1106  case 1:
1107  return "b";
1108  case 8:
1109  return (isSigned) ? "a" : "c";
1110  case 16:
1111  return (isSigned) ? "s" : "t";
1112  case 32:
1113  return (isSigned) ? "i" : "j";
1114  case 64:
1115  return (isSigned) ? "l" : "m";
1116  default:
1117  llvm_unreachable("Unsupported integer width");
1118  }
1119  })
1120  .Default([](auto) {
1121  llvm_unreachable("No mangling defined");
1122  return "";
1123  });
1124 }
1125 
1126 template <typename ReduceOp>
1127 constexpr StringLiteral getGroupFuncName();
1128 
1129 template <>
1130 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1131  return "_Z17__spirv_GroupIAddii";
1132 }
1133 template <>
1134 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1135  return "_Z17__spirv_GroupFAddii";
1136 }
1137 template <>
1138 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1139  return "_Z17__spirv_GroupSMinii";
1140 }
1141 template <>
1142 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1143  return "_Z17__spirv_GroupUMinii";
1144 }
1145 template <>
1146 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1147  return "_Z17__spirv_GroupFMinii";
1148 }
1149 template <>
1150 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1151  return "_Z17__spirv_GroupSMaxii";
1152 }
1153 template <>
1154 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1155  return "_Z17__spirv_GroupUMaxii";
1156 }
1157 template <>
1158 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1159  return "_Z17__spirv_GroupFMaxii";
1160 }
1161 template <>
1162 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1163  return "_Z27__spirv_GroupNonUniformIAddii";
1164 }
1165 template <>
1166 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1167  return "_Z27__spirv_GroupNonUniformFAddii";
1168 }
1169 template <>
1170 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1171  return "_Z27__spirv_GroupNonUniformIMulii";
1172 }
1173 template <>
1174 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1175  return "_Z27__spirv_GroupNonUniformFMulii";
1176 }
1177 template <>
1178 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1179  return "_Z27__spirv_GroupNonUniformSMinii";
1180 }
1181 template <>
1182 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1183  return "_Z27__spirv_GroupNonUniformUMinii";
1184 }
1185 template <>
1186 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1187  return "_Z27__spirv_GroupNonUniformFMinii";
1188 }
1189 template <>
1190 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1191  return "_Z27__spirv_GroupNonUniformSMaxii";
1192 }
1193 template <>
1194 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1195  return "_Z27__spirv_GroupNonUniformUMaxii";
1196 }
1197 template <>
1198 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1199  return "_Z27__spirv_GroupNonUniformFMaxii";
1200 }
1201 template <>
1202 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1203  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1204 }
1205 template <>
1206 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1207  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1208 }
1209 template <>
1210 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1211  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1212 }
1213 template <>
1214 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1215  return "_Z33__spirv_GroupNonUniformLogicalAndii";
1216 }
1217 template <>
1218 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1219  return "_Z32__spirv_GroupNonUniformLogicalOrii";
1220 }
1221 template <>
1222 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1223  return "_Z33__spirv_GroupNonUniformLogicalXorii";
1224 }
1225 } // namespace
1226 
1227 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1228 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1229 public:
1231 
1232  LogicalResult
1233  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1234  ConversionPatternRewriter &rewriter) const override {
1235 
1236  Type retTy = op.getResult().getType();
1237  if (!retTy.isIntOrFloat()) {
1238  return failure();
1239  }
1240  SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1241  funcName += getTypeMangling(retTy, false);
1242 
1243  Type i32Ty = rewriter.getI32Type();
1244  SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1245  if constexpr (NonUniform) {
1246  if (adaptor.getClusterSize()) {
1247  funcName += "j";
1248  paramTypes.push_back(i32Ty);
1249  }
1250  }
1251 
1252  Operation *symbolTable =
1253  op->template getParentWithTrait<OpTrait::SymbolTable>();
1254 
1255  LLVM::LLVMFuncOp func =
1256  lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1257 
1258  Location loc = op.getLoc();
1259  Value scope = rewriter.create<LLVM::ConstantOp>(
1260  loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1261  Value groupOp = rewriter.create<LLVM::ConstantOp>(
1262  loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1263  SmallVector<Value> operands{scope, groupOp};
1264  operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1265 
1266  auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1267  rewriter.replaceOp(op, call);
1268  return success();
1269  }
1270 };
1271 
1272 template <>
1273 constexpr StringRef
1274 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1275  return "_Z22__spirv_ControlBarrieriii";
1276 }
1277 
1278 template <>
1279 constexpr StringRef
1280 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1281  return "_Z33__spirv_ControlBarrierArriveINTELiii";
1282 }
1283 
1284 template <>
1285 constexpr StringRef
1286 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1287  return "_Z31__spirv_ControlBarrierWaitINTELiii";
1288 }
1289 
1290 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1291 /// should be reachable for conversion to succeed. The structure of the loop in
1292 /// LLVM dialect will be the following:
1293 ///
1294 /// +------------------------------------+
1295 /// | <code before spirv.mlir.loop> |
1296 /// | llvm.br ^header |
1297 /// +------------------------------------+
1298 /// |
1299 /// +----------------+ |
1300 /// | | |
1301 /// | V V
1302 /// | +------------------------------------+
1303 /// | | ^header: |
1304 /// | | <header code> |
1305 /// | | llvm.cond_br %cond, ^body, ^exit |
1306 /// | +------------------------------------+
1307 /// | |
1308 /// | |----------------------+
1309 /// | | |
1310 /// | V |
1311 /// | +------------------------------------+ |
1312 /// | | ^body: | |
1313 /// | | <body code> | |
1314 /// | | llvm.br ^continue | |
1315 /// | +------------------------------------+ |
1316 /// | | |
1317 /// | V |
1318 /// | +------------------------------------+ |
1319 /// | | ^continue: | |
1320 /// | | <continue code> | |
1321 /// | | llvm.br ^header | |
1322 /// | +------------------------------------+ |
1323 /// | | |
1324 /// +---------------+ +----------------------+
1325 /// |
1326 /// V
1327 /// +------------------------------------+
1328 /// | ^exit: |
1329 /// | llvm.br ^remaining |
1330 /// +------------------------------------+
1331 /// |
1332 /// V
1333 /// +------------------------------------+
1334 /// | ^remaining: |
1335 /// | <code after spirv.mlir.loop> |
1336 /// +------------------------------------+
1337 ///
1338 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1339 public:
1341 
1342  LogicalResult
1343  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1344  ConversionPatternRewriter &rewriter) const override {
1345  // There is no support of loop control at the moment.
1346  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1347  return failure();
1348 
1349  // `spirv.mlir.loop` with empty region is redundant and should be erased.
1350  if (loopOp.getBody().empty()) {
1351  rewriter.eraseOp(loopOp);
1352  return success();
1353  }
1354 
1355  Location loc = loopOp.getLoc();
1356 
1357  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1358  // be used in `endBlock`.
1359  Block *currentBlock = rewriter.getBlock();
1360  auto position = Block::iterator(loopOp);
1361  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1362 
1363  // Remove entry block and create a branch in the current block going to the
1364  // header block.
1365  Block *entryBlock = loopOp.getEntryBlock();
1366  assert(entryBlock->getOperations().size() == 1);
1367  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1368  if (!brOp)
1369  return failure();
1370  Block *headerBlock = loopOp.getHeaderBlock();
1371  rewriter.setInsertionPointToEnd(currentBlock);
1372  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1373  rewriter.eraseBlock(entryBlock);
1374 
1375  // Branch from merge block to end block.
1376  Block *mergeBlock = loopOp.getMergeBlock();
1377  Operation *terminator = mergeBlock->getTerminator();
1378  ValueRange terminatorOperands = terminator->getOperands();
1379  rewriter.setInsertionPointToEnd(mergeBlock);
1380  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1381 
1382  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1383  rewriter.replaceOp(loopOp, endBlock->getArguments());
1384  return success();
1385  }
1386 };
1387 
1388 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1389 /// block. All blocks within selection should be reachable for conversion to
1390 /// succeed.
1391 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1392 public:
1394 
1395  LogicalResult
1396  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1397  ConversionPatternRewriter &rewriter) const override {
1398  // There is no support for `Flatten` or `DontFlatten` selection control at
1399  // the moment. This are just compiler hints and can be performed during the
1400  // optimization passes.
1401  if (op.getSelectionControl() != spirv::SelectionControl::None)
1402  return failure();
1403 
1404  // `spirv.mlir.selection` should have at least two blocks: one selection
1405  // header block and one merge block. If no blocks are present, or control
1406  // flow branches straight to merge block (two blocks are present), the op is
1407  // redundant and it is erased.
1408  if (op.getBody().getBlocks().size() <= 2) {
1409  rewriter.eraseOp(op);
1410  return success();
1411  }
1412 
1413  Location loc = op.getLoc();
1414 
1415  // Split the current block after `spirv.mlir.selection`. The remaining ops
1416  // will be used in `continueBlock`.
1417  auto *currentBlock = rewriter.getInsertionBlock();
1418  rewriter.setInsertionPointAfter(op);
1419  auto position = rewriter.getInsertionPoint();
1420  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1421 
1422  // Extract conditional branch information from the header block. By SPIR-V
1423  // dialect spec, it should contain `spirv.BranchConditional` or
1424  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1425  // moment in the SPIR-V dialect. Remove this block when finished.
1426  auto *headerBlock = op.getHeaderBlock();
1427  assert(headerBlock->getOperations().size() == 1);
1428  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1429  headerBlock->getOperations().front());
1430  if (!condBrOp)
1431  return failure();
1432  rewriter.eraseBlock(headerBlock);
1433 
1434  // Branch from merge block to continue block.
1435  auto *mergeBlock = op.getMergeBlock();
1436  Operation *terminator = mergeBlock->getTerminator();
1437  ValueRange terminatorOperands = terminator->getOperands();
1438  rewriter.setInsertionPointToEnd(mergeBlock);
1439  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1440 
1441  // Link current block to `true` and `false` blocks within the selection.
1442  Block *trueBlock = condBrOp.getTrueBlock();
1443  Block *falseBlock = condBrOp.getFalseBlock();
1444  rewriter.setInsertionPointToEnd(currentBlock);
1445  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1446  condBrOp.getTrueTargetOperands(),
1447  falseBlock,
1448  condBrOp.getFalseTargetOperands());
1449 
1450  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1451  rewriter.replaceOp(op, continueBlock->getArguments());
1452  return success();
1453  }
1454 };
1455 
1456 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1457 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1458 /// `Shift` is zero or sign extended to match this specification. Cases when
1459 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1460 template <typename SPIRVOp, typename LLVMOp>
1461 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1462 public:
1464 
1465  LogicalResult
1466  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1467  ConversionPatternRewriter &rewriter) const override {
1468 
1469  auto dstType = this->getTypeConverter()->convertType(op.getType());
1470  if (!dstType)
1471  return rewriter.notifyMatchFailure(op, "type conversion failed");
1472 
1473  Type op1Type = op.getOperand1().getType();
1474  Type op2Type = op.getOperand2().getType();
1475 
1476  if (op1Type == op2Type) {
1477  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1478  adaptor.getOperands());
1479  return success();
1480  }
1481 
1482  std::optional<uint64_t> dstTypeWidth =
1484  std::optional<uint64_t> op2TypeWidth =
1486 
1487  if (!dstTypeWidth || !op2TypeWidth)
1488  return failure();
1489 
1490  Location loc = op.getLoc();
1491  Value extended;
1492  if (op2TypeWidth < dstTypeWidth) {
1493  if (isUnsignedIntegerOrVector(op2Type)) {
1494  extended = rewriter.template create<LLVM::ZExtOp>(
1495  loc, dstType, adaptor.getOperand2());
1496  } else {
1497  extended = rewriter.template create<LLVM::SExtOp>(
1498  loc, dstType, adaptor.getOperand2());
1499  }
1500  } else if (op2TypeWidth == dstTypeWidth) {
1501  extended = adaptor.getOperand2();
1502  } else {
1503  return failure();
1504  }
1505 
1506  Value result = rewriter.template create<LLVMOp>(
1507  loc, dstType, adaptor.getOperand1(), extended);
1508  rewriter.replaceOp(op, result);
1509  return success();
1510  }
1511 };
1512 
1513 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1514 public:
1516 
1517  LogicalResult
1518  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1519  ConversionPatternRewriter &rewriter) const override {
1520  auto dstType = getTypeConverter()->convertType(tanOp.getType());
1521  if (!dstType)
1522  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1523 
1524  Location loc = tanOp.getLoc();
1525  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1526  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1527  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1528  return success();
1529  }
1530 };
1531 
1532 /// Convert `spirv.Tanh` to
1533 ///
1534 /// exp(2x) - 1
1535 /// -----------
1536 /// exp(2x) + 1
1537 ///
1538 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1539 public:
1541 
1542  LogicalResult
1543  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1544  ConversionPatternRewriter &rewriter) const override {
1545  auto srcType = tanhOp.getType();
1546  auto dstType = getTypeConverter()->convertType(srcType);
1547  if (!dstType)
1548  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1549 
1550  Location loc = tanhOp.getLoc();
1551  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1552  Value multiplied =
1553  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1554  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1555  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1556  Value numerator =
1557  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1558  Value denominator =
1559  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1560  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1561  denominator);
1562  return success();
1563  }
1564 };
1565 
1566 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1567 public:
1569 
1570  LogicalResult
1571  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1572  ConversionPatternRewriter &rewriter) const override {
1573  auto srcType = varOp.getType();
1574  // Initialization is supported for scalars and vectors only.
1575  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1576  auto init = varOp.getInitializer();
1577  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1578  return failure();
1579 
1580  auto dstType = getTypeConverter()->convertType(srcType);
1581  if (!dstType)
1582  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1583 
1584  Location loc = varOp.getLoc();
1585  Value size = createI32ConstantOf(loc, rewriter, 1);
1586  if (!init) {
1587  auto elementType = getTypeConverter()->convertType(pointerTo);
1588  if (!elementType)
1589  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1590  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1591  size);
1592  return success();
1593  }
1594  auto elementType = getTypeConverter()->convertType(pointerTo);
1595  if (!elementType)
1596  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1597  Value allocated =
1598  rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1599  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1600  rewriter.replaceOp(varOp, allocated);
1601  return success();
1602  }
1603 };
1604 
1605 //===----------------------------------------------------------------------===//
1606 // BitcastOp conversion
1607 //===----------------------------------------------------------------------===//
1608 
1609 class BitcastConversionPattern
1610  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1611 public:
1613 
1614  LogicalResult
1615  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1616  ConversionPatternRewriter &rewriter) const override {
1617  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1618  if (!dstType)
1619  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1620 
1621  // LLVM's opaque pointers do not require bitcasts.
1622  if (isa<LLVM::LLVMPointerType>(dstType)) {
1623  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1624  return success();
1625  }
1626 
1627  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1628  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1629  return success();
1630  }
1631 };
1632 
1633 //===----------------------------------------------------------------------===//
1634 // FuncOp conversion
1635 //===----------------------------------------------------------------------===//
1636 
1637 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1638 public:
1640 
1641  LogicalResult
1642  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1643  ConversionPatternRewriter &rewriter) const override {
1644 
1645  // Convert function signature. At the moment LLVMType converter is enough
1646  // for currently supported types.
1647  auto funcType = funcOp.getFunctionType();
1648  TypeConverter::SignatureConversion signatureConverter(
1649  funcType.getNumInputs());
1650  auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1651  ->convertFunctionSignature(
1652  funcType, /*isVariadic=*/false,
1653  /*useBarePtrCallConv=*/false, signatureConverter);
1654  if (!llvmType)
1655  return failure();
1656 
1657  // Create a new `LLVMFuncOp`
1658  Location loc = funcOp.getLoc();
1659  StringRef name = funcOp.getName();
1660  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1661 
1662  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1663  MLIRContext *context = funcOp.getContext();
1664  switch (funcOp.getFunctionControl()) {
1665  case spirv::FunctionControl::Inline:
1666  newFuncOp.setAlwaysInline(true);
1667  break;
1668  case spirv::FunctionControl::DontInline:
1669  newFuncOp.setNoInline(true);
1670  break;
1671 
1672 #define DISPATCH(functionControl, llvmAttr) \
1673  case functionControl: \
1674  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1675  break;
1676 
1677  DISPATCH(spirv::FunctionControl::Pure,
1678  StringAttr::get(context, "readonly"));
1679  DISPATCH(spirv::FunctionControl::Const,
1680  StringAttr::get(context, "readnone"));
1681 
1682 #undef DISPATCH
1683 
1684  // Default: if `spirv::FunctionControl::None`, then no attributes are
1685  // needed.
1686  default:
1687  break;
1688  }
1689 
1690  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1691  newFuncOp.end());
1692  if (failed(rewriter.convertRegionTypes(
1693  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1694  return failure();
1695  }
1696  rewriter.eraseOp(funcOp);
1697  return success();
1698  }
1699 };
1700 
1701 //===----------------------------------------------------------------------===//
1702 // ModuleOp conversion
1703 //===----------------------------------------------------------------------===//
1704 
1705 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1706 public:
1708 
1709  LogicalResult
1710  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1711  ConversionPatternRewriter &rewriter) const override {
1712 
1713  auto newModuleOp =
1714  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1715  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1716 
1717  // Remove the terminator block that was automatically added by builder
1718  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1719  rewriter.eraseOp(spvModuleOp);
1720  return success();
1721  }
1722 };
1723 
1724 //===----------------------------------------------------------------------===//
1725 // VectorShuffleOp conversion
1726 //===----------------------------------------------------------------------===//
1727 
1728 class VectorShufflePattern
1729  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1730 public:
1732  LogicalResult
1733  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1734  ConversionPatternRewriter &rewriter) const override {
1735  Location loc = op.getLoc();
1736  auto components = adaptor.getComponents();
1737  auto vector1 = adaptor.getVector1();
1738  auto vector2 = adaptor.getVector2();
1739  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1740  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1741  if (vector1Size == vector2Size) {
1742  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1743  op, vector1, vector2,
1744  LLVM::convertArrayToIndices<int32_t>(components));
1745  return success();
1746  }
1747 
1748  auto dstType = getTypeConverter()->convertType(op.getType());
1749  if (!dstType)
1750  return rewriter.notifyMatchFailure(op, "type conversion failed");
1751  auto scalarType = cast<VectorType>(dstType).getElementType();
1752  auto componentsArray = components.getValue();
1753  auto *context = rewriter.getContext();
1754  auto llvmI32Type = IntegerType::get(context, 32);
1755  Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType);
1756  for (unsigned i = 0; i < componentsArray.size(); i++) {
1757  if (!isa<IntegerAttr>(componentsArray[i]))
1758  return op.emitError("unable to support non-constant component");
1759 
1760  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1761  if (indexVal == -1)
1762  continue;
1763 
1764  int offsetVal = 0;
1765  Value baseVector = vector1;
1766  if (indexVal >= vector1Size) {
1767  offsetVal = vector1Size;
1768  baseVector = vector2;
1769  }
1770 
1771  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1772  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1773  Value index = rewriter.create<LLVM::ConstantOp>(
1774  loc, llvmI32Type,
1775  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1776 
1777  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1778  loc, scalarType, baseVector, index);
1779  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1780  extractOp, dstIndex);
1781  }
1782  rewriter.replaceOp(op, targetOp);
1783  return success();
1784  }
1785 };
1786 } // namespace
1787 
1788 //===----------------------------------------------------------------------===//
1789 // Pattern population
1790 //===----------------------------------------------------------------------===//
1791 
1793  spirv::ClientAPI clientAPI) {
1794  typeConverter.addConversion([&](spirv::ArrayType type) {
1795  return convertArrayType(type, typeConverter);
1796  });
1797  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1798  return convertPointerType(type, typeConverter, clientAPI);
1799  });
1800  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1801  return convertRuntimeArrayType(type, typeConverter);
1802  });
1803  typeConverter.addConversion([&](spirv::StructType type) {
1804  return convertStructType(type, typeConverter);
1805  });
1806 }
1807 
1809  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1810  spirv::ClientAPI clientAPI) {
1811  patterns.add<
1812  // Arithmetic ops
1813  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1814  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1815  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1816  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1817  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1818  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1819  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1820  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1821  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1822  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1823  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1824  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1825  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1826 
1827  // Bitwise ops
1828  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1829  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1830  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1831  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1832  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1833  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1834  NotPattern<spirv::NotOp>,
1835 
1836  // Cast ops
1837  BitcastConversionPattern,
1838  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1839  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1840  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1841  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1842  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1843  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1844  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1845 
1846  // Comparison ops
1847  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1848  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1849  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1850  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1851  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1852  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1853  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1854  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1855  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1856  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1857  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1858  LLVM::FCmpPredicate::uge>,
1859  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1860  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1861  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1862  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1863  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1864  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1865  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1866  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1867  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1868  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1869  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1870 
1871  // Constant op
1872  ConstantScalarAndVectorPattern,
1873 
1874  // Control Flow ops
1875  BranchConversionPattern, BranchConditionalConversionPattern,
1876  FunctionCallPattern, LoopPattern, SelectionPattern,
1877  ErasePattern<spirv::MergeOp>,
1878 
1879  // Entry points and execution mode are handled separately.
1880  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1881 
1882  // GLSL extended instruction set ops
1883  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1884  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1885  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1886  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1887  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1888  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1889  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1890  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1891  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1892  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1893  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1894  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1895  InverseSqrtPattern, TanPattern, TanhPattern,
1896 
1897  // Logical ops
1898  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1899  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1900  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1901  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1902  NotPattern<spirv::LogicalNotOp>,
1903 
1904  // Memory ops
1905  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1906  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1907 
1908  // Miscellaneous ops
1909  CompositeExtractPattern, CompositeInsertPattern,
1910  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1911  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1912  VectorShufflePattern,
1913 
1914  // Shift ops
1915  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1916  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1917  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1918 
1919  // Return ops
1920  ReturnPattern, ReturnValuePattern,
1921 
1922  // Barrier ops
1923  ControlBarrierPattern<spirv::ControlBarrierOp>,
1924  ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1925  ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1926 
1927  // Group reduction operations
1928  GroupReducePattern<spirv::GroupIAddOp>,
1929  GroupReducePattern<spirv::GroupFAddOp>,
1930  GroupReducePattern<spirv::GroupFMinOp>,
1931  GroupReducePattern<spirv::GroupUMinOp>,
1932  GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1933  GroupReducePattern<spirv::GroupFMaxOp>,
1934  GroupReducePattern<spirv::GroupUMaxOp>,
1935  GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1936  GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1937  /*NonUniform=*/true>,
1938  GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1939  /*NonUniform=*/true>,
1940  GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1941  /*NonUniform=*/true>,
1942  GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1943  /*NonUniform=*/true>,
1944  GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1945  /*NonUniform=*/true>,
1946  GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1947  /*NonUniform=*/true>,
1948  GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1949  /*NonUniform=*/true>,
1950  GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1951  /*NonUniform=*/true>,
1952  GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1953  /*NonUniform=*/true>,
1954  GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1955  /*NonUniform=*/true>,
1956  GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1957  /*NonUniform=*/true>,
1958  GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1959  /*NonUniform=*/true>,
1960  GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1961  /*NonUniform=*/true>,
1962  GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1963  /*NonUniform=*/true>,
1964  GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1965  /*NonUniform=*/true>,
1966  GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1967  /*NonUniform=*/true>>(patterns.getContext(),
1968  typeConverter);
1969 
1970  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1971  typeConverter);
1972 }
1973 
1975  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1976  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1977 }
1978 
1980  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1981  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1982 }
1983 
1984 //===----------------------------------------------------------------------===//
1985 // Pre-conversion hooks
1986 //===----------------------------------------------------------------------===//
1987 
1988 /// Hook for descriptor set and binding number encoding.
1989 static constexpr StringRef kBinding = "binding";
1990 static constexpr StringRef kDescriptorSet = "descriptor_set";
1991 void mlir::encodeBindAttribute(ModuleOp module) {
1992  auto spvModules = module.getOps<spirv::ModuleOp>();
1993  for (auto spvModule : spvModules) {
1994  spvModule.walk([&](spirv::GlobalVariableOp op) {
1995  IntegerAttr descriptorSet =
1996  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1997  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1998  // For every global variable in the module, get the ones with descriptor
1999  // set and binding numbers.
2000  if (descriptorSet && binding) {
2001  // Encode these numbers into the variable's symbolic name. If the
2002  // SPIR-V module has a name, add it at the beginning.
2003  auto moduleAndName =
2004  spvModule.getName().has_value()
2005  ? spvModule.getName()->str() + "_" + op.getSymName().str()
2006  : op.getSymName().str();
2007  std::string name =
2008  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2009  std::to_string(descriptorSet.getInt()),
2010  std::to_string(binding.getInt()));
2011  auto nameAttr = StringAttr::get(op->getContext(), name);
2012 
2013  // Replace all symbol uses and set the new symbol name. Finally, remove
2014  // descriptor set and binding attributes.
2015  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2016  op.emitError("unable to replace all symbol uses for ") << name;
2017  SymbolTable::setSymbolName(op, nameAttr);
2018  op->removeAttr(kDescriptorSet);
2019  op->removeAttr(kBinding);
2020  }
2021  });
2022  }
2023 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:67
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:38
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:47
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:87
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:97
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:80
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:57
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:446
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:412
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:473
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:793
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:229
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.