MLIR  19.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Constants
34 //===----------------------------------------------------------------------===//
35 
/// LLVM address space (0) used whenever no client-API-specific mapping exists
/// for a SPIR-V storage class.
static constexpr unsigned defaultAddressSpace{0};
37 
38 //===----------------------------------------------------------------------===//
39 // Utility functions
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given type is a signed integer or vector type.
43 static bool isSignedIntegerOrVector(Type type) {
44  if (type.isSignedInteger())
45  return true;
46  if (auto vecType = dyn_cast<VectorType>(type))
47  return vecType.getElementType().isSignedInteger();
48  return false;
49 }
50 
51 /// Returns true if the given type is an unsigned integer or vector type
52 static bool isUnsignedIntegerOrVector(Type type) {
53  if (type.isUnsignedInteger())
54  return true;
55  if (auto vecType = dyn_cast<VectorType>(type))
56  return vecType.getElementType().isUnsignedInteger();
57  return false;
58 }
59 
60 /// Returns the width of an integer or of the element type of an integer vector,
61 /// if applicable.
62 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
63  if (auto intType = dyn_cast<IntegerType>(type))
64  return intType.getWidth();
65  if (auto vecType = dyn_cast<VectorType>(type))
66  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
67  return intType.getWidth();
68  return std::nullopt;
69 }
70 
71 /// Returns the bit width of integer, float or vector of float or integer values
72 static unsigned getBitWidth(Type type) {
73  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
74  "bitwidth is not supported for this type");
75  if (type.isIntOrFloat())
76  return type.getIntOrFloatBitWidth();
77  auto vecType = dyn_cast<VectorType>(type);
78  auto elementType = vecType.getElementType();
79  assert(elementType.isIntOrFloat() &&
80  "only integers and floats have a bitwidth");
81  return elementType.getIntOrFloatBitWidth();
82 }
83 
84 /// Returns the bit width of an LLVM-compatible integer type, or of an
/// LLVM-compatible vector type via the integer type chosen by the ternary
/// below. The selected type is `cast` to `IntegerType`, so only integer (or
/// integer-vector) types are accepted.
/// NOTE(review): assumes the vector branch selects the vector's element
/// type — confirm.
85 static unsigned getLLVMTypeBitWidth(Type type) {
86  return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
88  : type))
89  .getWidth();
90 }
91 
92 /// Creates `IntegerAttribute` with all bits set for given type
93 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
94  if (auto vecType = dyn_cast<VectorType>(type)) {
95  auto integerType = cast<IntegerType>(vecType.getElementType());
96  return builder.getIntegerAttr(integerType, -1);
97  }
98  auto integerType = cast<IntegerType>(type);
99  return builder.getIntegerAttr(integerType, -1);
100 }
101 
102 /// Creates `llvm.mlir.constant` with all bits set for the given type.
103 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
104  PatternRewriter &rewriter) {
105  if (isa<VectorType>(srcType)) {
106  return rewriter.create<LLVM::ConstantOp>(
107  loc, dstType,
108  SplatElementsAttr::get(cast<ShapedType>(srcType),
109  minusOneIntegerAttribute(srcType, rewriter)));
110  }
111  return rewriter.create<LLVM::ConstantOp>(
112  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
113 }
114 
115 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
116 static Value createFPConstant(Location loc, Type srcType, Type dstType,
117  PatternRewriter &rewriter, double value) {
118  if (auto vecType = dyn_cast<VectorType>(srcType)) {
119  auto floatType = cast<FloatType>(vecType.getElementType());
120  return rewriter.create<LLVM::ConstantOp>(
121  loc, dstType,
122  SplatElementsAttr::get(vecType,
123  rewriter.getFloatAttr(floatType, value)));
124  }
125  auto floatType = cast<FloatType>(srcType);
126  return rewriter.create<LLVM::ConstantOp>(
127  loc, dstType, rewriter.getFloatAttr(floatType, value));
128 }
129 
130 /// Utility function for bitfield ops:
131 /// - `BitFieldInsert`
132 /// - `BitFieldSExtract`
133 /// - `BitFieldUExtract`
134 /// Truncates or extends the value. If the bitwidth of the value is the same as
135 /// `llvmType` bitwidth, the value remains unchanged.
/// A narrower value is widened with `llvm.zext`. A wider value is narrowed
/// with `llvm.trunc`. For vectors, the widths compared are element widths
/// (via getLLVMTypeBitWidth / getBitWidth).
137  Type llvmType,
138  PatternRewriter &rewriter) {
139  auto srcType = value.getType();
140  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
// Values whose type is already LLVM-compatible are measured with the LLVM
// helper; otherwise the builtin/SPIR-V type helper is used.
141  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
142  ? getLLVMTypeBitWidth(srcType)
143  : getBitWidth(srcType);
144 
// Zero-extension preserves the numeric value of the non-negative
// `Count`/`Offset` operands this helper is used for.
145  if (valueBitWidth < targetBitWidth)
146  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
147  // If the bit widths of `Count` and `Offset` are greater than the bit width
148  // of the target type, they are truncated. Truncation is safe since `Count`
149  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
150  // both values can be expressed in 8 bits.
151  if (valueBitWidth > targetBitWidth)
152  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
153  return value;
154 }
155 
156 /// Broadcasts the value to vector with `numElements` number of elements.
157 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
158  LLVMTypeConverter &typeConverter,
159  ConversionPatternRewriter &rewriter) {
160  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
161  auto llvmVectorType = typeConverter.convertType(vectorType);
162  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
163  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
164  for (unsigned i = 0; i < numElements; ++i) {
165  auto index = rewriter.create<LLVM::ConstantOp>(
166  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
167  broadcasted = rewriter.create<LLVM::InsertElementOp>(
168  loc, llvmVectorType, broadcasted, toBroadcast, index);
169  }
170  return broadcasted;
171 }
172 
173 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
174 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
175  LLVMTypeConverter &typeConverter,
176  ConversionPatternRewriter &rewriter) {
177  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
178  unsigned numElements = vectorType.getNumElements();
179  return broadcast(loc, value, numElements, typeConverter, rewriter);
180  }
181  return value;
182 }
183 
184 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
185 /// `BitFieldUExtract`.
186 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
187 /// a vector type, construct a vector that has:
188 /// - same number of elements as `Base`
189 /// - each element has the type that is the same as the type of `Offset` or
190 /// `Count`
191 /// - each element has the same value as `Offset` or `Count`
192 /// Then cast `Offset` and `Count` if their bit width is different
193 /// from `Base` bit width.
194 static Value processCountOrOffset(Location loc, Value value, Type srcType,
195  Type dstType, LLVMTypeConverter &converter,
196  ConversionPatternRewriter &rewriter) {
197  Value broadcasted =
198  optionallyBroadcast(loc, value, srcType, converter, rewriter);
199  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
200 }
201 
202 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
203 /// offset to LLVM struct. Otherwise, the conversion is not supported.
/// The result is a non-packed literal LLVM struct of the converted member
/// types. Returns nullptr when the struct's layout differs from the
/// `VulkanLayoutUtils` one, or when any member type fails to convert.
205  LLVMTypeConverter &converter) {
// Re-decorating must be a no-op, i.e. the existing offsets already match the
// layout VulkanLayoutUtils would assign.
206  if (type != VulkanLayoutUtils::decorateType(type))
207  return nullptr;
208 
209  SmallVector<Type> elementsVector;
210  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
211  return nullptr;
212  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
213  /*isPacked=*/false);
214 }
215 
216 /// Converts SPIR-V struct with no offset to packed LLVM struct.
/// The members are converted with `converter`. Returns nullptr if any member
/// type fails to convert.
218  LLVMTypeConverter &converter) {
219  SmallVector<Type> elementsVector;
220  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
221  return nullptr;
222  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
223  /*isPacked=*/true);
224 }
225 
226 /// Creates LLVM dialect constant with the given value.
/// The constant is an `llvm.mlir.constant` of type i32 whose attribute is an
/// i32 `IntegerAttr` built from `value`.
228  unsigned value) {
229  return rewriter.create<LLVM::ConstantOp>(
230  loc, IntegerType::get(rewriter.getContext(), 32),
231  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
232 }
233 
234 /// Utility for `spirv.Load` and `spirv.Store` conversion.
/// A `spirv.Load` becomes `llvm.load` with the converted result type. Any
/// other op is `cast` to `spirv.StoreOp` and becomes `llvm.store`. In both
/// cases the converted `operands` are used, and the given alignment,
/// volatile and nontemporal flags are forwarded.
236  ConversionPatternRewriter &rewriter,
237  LLVMTypeConverter &typeConverter,
238  unsigned alignment, bool isVolatile,
239  bool isNonTemporal) {
240  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
241  auto dstType = typeConverter.convertType(loadOp.getType());
242  if (!dstType)
243  return rewriter.notifyMatchFailure(op, "type conversion failed");
244  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
245  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
246  isVolatile, isNonTemporal);
247  return success();
248  }
// Not a load, so this must be a store; `cast` asserts that invariant.
249  auto storeOp = cast<spirv::StoreOp>(op);
250  spirv::StoreOpAdaptor adaptor(operands);
251  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
252  adaptor.getPtr(), alignment,
253  isVolatile, isNonTemporal);
254  return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // Type conversion
259 //===----------------------------------------------------------------------===//
260 
261 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
262 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
263 /// when converting ops that manipulate array types.
264 static std::optional<Type> convertArrayType(spirv::ArrayType type,
265  TypeConverter &converter) {
266  unsigned stride = type.getArrayStride();
267  Type elementType = type.getElementType();
268  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
269  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
270  return std::nullopt;
271 
272  auto llvmElementType = converter.convertType(elementType);
273  unsigned numElements = type.getNumElements();
274  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
275 }
276 
277 static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) {
278  // Based on
279  // https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form
280  // and clang/lib/Basic/Targets/SPIR.h.
281  switch (storageClass) {
282 #define STORAGE_SPACE_MAP(storage, space) \
283  case spirv::StorageClass::storage: \
284  return space;
285  STORAGE_SPACE_MAP(Function, 0)
286  STORAGE_SPACE_MAP(CrossWorkgroup, 1)
287  STORAGE_SPACE_MAP(Input, 1)
288  STORAGE_SPACE_MAP(UniformConstant, 2)
289  STORAGE_SPACE_MAP(Workgroup, 3)
290  STORAGE_SPACE_MAP(Generic, 4)
291  STORAGE_SPACE_MAP(DeviceOnlyINTEL, 5)
292  STORAGE_SPACE_MAP(HostOnlyINTEL, 6)
293 #undef STORAGE_SPACE_MAP
294  default:
295  return defaultAddressSpace;
296  }
297 }
298 
299 static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI,
300  spirv::StorageClass storageClass) {
301  switch (clientAPI) {
302 #define CLIENT_MAP(client, storage) \
303  case spirv::ClientAPI::client: \
304  return mapTo##client##AddressSpace(storage);
305  CLIENT_MAP(OpenCL, storageClass)
306 #undef CLIENT_MAP
307  default:
308  return defaultAddressSpace;
309  }
310 }
311 
312 /// Converts a SPIR-V pointer type to an LLVM pointer that carries no pointee
313 /// type. The storage class is mapped to an LLVM address space via
/// `mapToAddressSpace` for the given client API. Client APIs without a
/// mapping get `defaultAddressSpace`. `converter` is not used here.
315  LLVMTypeConverter &converter,
316  spirv::ClientAPI clientAPI) {
317  unsigned addressSpace = mapToAddressSpace(clientAPI, type.getStorageClass());
318  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
319 }
320 
321 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
322 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
323 /// no modelling of array stride at the moment.
324 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
325  TypeConverter &converter) {
326  if (type.getArrayStride() != 0)
327  return std::nullopt;
328  auto elementType = converter.convertType(type.getElementType());
329  return LLVM::LLVMArrayType::get(elementType, 0);
330 }
331 
332 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
333 /// member decorations. Also, only natural offset is supported.
/// Structs with any member decoration yield nullptr. Structs with offsets are
/// delegated to convertStructTypeWithOffset. All others become packed LLVM
/// structs via convertStructTypePacked.
335  LLVMTypeConverter &converter) {
337  type.getMemberDecorations(memberDecorations);
338  if (!memberDecorations.empty())
339  return nullptr;
340  if (type.hasOffset())
341  return convertStructTypeWithOffset(type, converter);
342  return convertStructTypePacked(type, converter);
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // Operation conversion
347 //===----------------------------------------------------------------------===//
348 
349 namespace {
350 
/// Lowers `spirv.AccessChain` to `llvm.getelementptr`.
/// A constant 0 index is prepended to the converted access-chain indices so
/// the GEP first steps through the base pointer itself. The 0 is built with
/// the type of the first SPIR-V index. The GEP's element type is the
/// converted pointee type of the SPIR-V base pointer.
351 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
352 public:
354 
356  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
357  ConversionPatternRewriter &rewriter) const override {
358  auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
359  if (!dstType)
360  return rewriter.notifyMatchFailure(op, "type conversion failed");
361  // To use GEP we need to add a first 0 index to go through the pointer.
362  auto indices = llvm::to_vector<4>(adaptor.getIndices());
// NOTE(review): `front()` assumes the access chain has at least one index;
// confirm the op verifier rejects an empty index list.
363  Type indexType = op.getIndices().front().getType();
364  auto llvmIndexType = typeConverter.convertType(indexType);
365  if (!llvmIndexType)
366  return rewriter.notifyMatchFailure(op, "type conversion failed");
367  Value zero = rewriter.create<LLVM::ConstantOp>(
368  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
369  indices.insert(indices.begin(), zero);
370 
371  auto elementType = typeConverter.convertType(
372  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
373  if (!elementType)
374  return rewriter.notifyMatchFailure(op, "type conversion failed");
375  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
376  adaptor.getBasePtr(), indices);
377  return success();
378  }
379 };
380 
/// Lowers `spirv.mlir.addressof` to `llvm.mlir.addressof`. The new op refers
/// to the same symbol (`op.getVariable()`) and produces the converted pointer
/// type.
381 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
382 public:
384 
386  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
387  ConversionPatternRewriter &rewriter) const override {
388  auto dstType = typeConverter.convertType(op.getPointer().getType());
389  if (!dstType)
390  return rewriter.notifyMatchFailure(op, "type conversion failed");
391  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
392  op.getVariable());
393  return success();
394  }
395 };
396 
/// Lowers `spirv.BitFieldInsert` to LLVM shift/xor/and/or arithmetic:
///   mask   = ~((~(-1 << Count)) << Offset)  // zeros in the inserted field
///   result = (Base & mask) | (Insert << Offset)
/// `Offset` and `Count` are first broadcast (for vectors) and resized to the
/// result's bit width.
/// NOTE(review): `Insert << Offset` is not masked to Count bits, so bits of
/// `Insert` at positions >= Count can overwrite `Base` bits above
/// Offset + Count - 1. The SPIR-V spec only inserts bits [0, Count-1] of
/// `Insert`. Confirm, and consider and-ing the shifted insert with ~mask.
/// NOTE(review): `Base`, `Insert`, `Offset` and `Count` are read from `op`
/// rather than `adaptor` — confirm this is intended under dialect conversion.
397 class BitFieldInsertPattern
398  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
399 public:
401 
403  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
404  ConversionPatternRewriter &rewriter) const override {
405  auto srcType = op.getType();
406  auto dstType = typeConverter.convertType(srcType);
407  if (!dstType)
408  return rewriter.notifyMatchFailure(op, "type conversion failed");
409  Location loc = op.getLoc();
410 
411  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
412  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
413  typeConverter, rewriter);
414  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
415  typeConverter, rewriter);
416 
417  // Create a mask with bits set outside [Offset, Offset + Count - 1].
418  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
419  Value maskShiftedByCount =
420  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
// Ones in bits [0, Count - 1].
421  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
422  maskShiftedByCount, minusOne);
// Ones in bits [Offset, Offset + Count - 1].
423  Value maskShiftedByCountAndOffset =
424  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
425  Value mask = rewriter.create<LLVM::XOrOp>(
426  loc, dstType, maskShiftedByCountAndOffset, minusOne);
427 
428  // Extract unchanged bits from the `Base` that are outside of
429  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
430  Value baseAndMask =
431  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
432  Value insertShiftedByOffset =
433  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
434  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
435  insertShiftedByOffset);
436  return success();
437  }
438 };
439 
440 /// Converts SPIR-V ConstantOp with scalar or vector type.
/// Signed/unsigned integer constants are rebuilt with a signless integer type
/// of the same bit width; the APInt bit patterns are copied unchanged. Other
/// int/float scalar or vector constants are recreated with the op's original
/// attributes. Constants of any other type are not matched (`failure()`).
441 class ConstantScalarAndVectorPattern
442  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
443 public:
445 
447  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
448  ConversionPatternRewriter &rewriter) const override {
449  auto srcType = constOp.getType();
450  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
451  return failure();
452 
453  auto dstType = typeConverter.convertType(srcType);
454  if (!dstType)
455  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
456 
457  // SPIR-V constant can be a signed/unsigned integer, which has to be
458  // casted to signless integer when converting to LLVM dialect. Removing the
459  // sign bit may have unexpected behaviour. However, it is better to handle
460  // it case-by-case, given that the purpose of the conversion is not to
461  // cover all possible corner cases.
462  if (isSignedIntegerOrVector(srcType) ||
463  isUnsignedIntegerOrVector(srcType)) {
464  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
465 
466  if (isa<VectorType>(srcType)) {
// Identity mapping: only the element type of the dense attribute changes.
467  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
468  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
469  constOp, dstType,
470  dstElementsAttr.mapValues(
471  signlessType, [&](const APInt &value) { return value; }));
472  return success();
473  }
474  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
475  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
476  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
477  return success();
478  }
479  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
480  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
481  return success();
482  }
483 };
484 
/// Lowers `spirv.BitFieldSExtract` to a shift pair:
///   tmp    = Base << (size - (Count + Offset))   // field's top bit -> MSB
///   result = tmp >>s (Offset + size - (Count + Offset)) // arithmetic shift
/// This sign-extends bits [Offset, Offset + Count - 1] of `Base`. `size` is
/// the (element) bit width of the result type.
/// NOTE(review): `Base`, `Offset` and `Count` are read from `op` rather than
/// `adaptor` — confirm this is intended under dialect conversion.
485 class BitFieldSExtractPattern
486  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
487 public:
489 
491  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
492  ConversionPatternRewriter &rewriter) const override {
493  auto srcType = op.getType();
494  auto dstType = typeConverter.convertType(srcType);
495  if (!dstType)
496  return rewriter.notifyMatchFailure(op, "type conversion failed");
497  Location loc = op.getLoc();
498 
499  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
500  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
501  typeConverter, rewriter);
502  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
503  typeConverter, rewriter);
504 
505  // Create a constant that holds the size of the `Base`.
506  IntegerType integerType;
507  if (auto vecType = dyn_cast<VectorType>(srcType))
508  integerType = cast<IntegerType>(vecType.getElementType());
509  else
510  integerType = cast<IntegerType>(srcType);
511 
512  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
// For vectors the size is splatted so it can be combined lane-wise.
513  Value size =
514  isa<VectorType>(srcType)
515  ? rewriter.create<LLVM::ConstantOp>(
516  loc, dstType,
517  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
518  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
519 
520  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
521  // at Offset + Count - 1 is the most significant bit now.
522  Value countPlusOffset =
523  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
524  Value amountToShiftLeft =
525  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
526  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
527  loc, dstType, op.getBase(), amountToShiftLeft);
528 
529  // Shift the result right, filling the bits with the sign bit.
530  Value amountToShiftRight =
531  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
532  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
533  amountToShiftRight);
534  return success();
535  }
536 };
537 
/// Lowers `spirv.BitFieldUExtract` to:
///   result = (Base >>u Offset) & ~(-1 << Count)
/// This zero-extends bits [Offset, Offset + Count - 1] of `Base`.
/// NOTE(review): `Base`, `Offset` and `Count` are read from `op` rather than
/// `adaptor` — confirm this is intended under dialect conversion.
538 class BitFieldUExtractPattern
539  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
540 public:
542 
544  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
545  ConversionPatternRewriter &rewriter) const override {
546  auto srcType = op.getType();
547  auto dstType = typeConverter.convertType(srcType);
548  if (!dstType)
549  return rewriter.notifyMatchFailure(op, "type conversion failed");
550  Location loc = op.getLoc();
551 
552  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
553  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
554  typeConverter, rewriter);
555  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
556  typeConverter, rewriter);
557 
558  // Create a mask with bits set at [0, Count - 1].
559  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
560  Value maskShiftedByCount =
561  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
562  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
563  minusOne);
564 
565  // Shift `Base` by `Offset` and apply the mask on it.
566  Value shiftedBase =
567  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
568  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
569  return success();
570  }
571 };
572 
/// Lowers `spirv.Branch` to `llvm.br`. The new branch targets the same block
/// and forwards the converted block operands.
573 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
574 public:
576 
578  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
579  ConversionPatternRewriter &rewriter) const override {
580  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
581  branchOp.getTarget());
582  return success();
583  }
584 };
585 
/// Lowers `spirv.BranchConditional` to `llvm.cond_br` with the same
/// true/false successor blocks. Optional branch weights are copied into a
/// `DenseI32ArrayAttr`.
/// NOTE(review): the condition and the successor operands are taken from
/// `op` rather than `adaptor` — confirm they need no type conversion here.
586 class BranchConditionalConversionPattern
587  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
588 public:
589  using SPIRVToLLVMConversion<
590  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
591 
593  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
594  ConversionPatternRewriter &rewriter) const override {
595  // If branch weights exist, map them to 32-bit integer vector.
596  DenseI32ArrayAttr branchWeights = nullptr;
597  if (auto weights = op.getBranchWeights()) {
598  SmallVector<int32_t> weightValues;
599  for (auto weight : weights->getAsRange<IntegerAttr>())
600  weightValues.push_back(weight.getInt());
601  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
602  }
603 
604  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
605  op, op.getCondition(), op.getTrueBlockArguments(),
606  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
607  op.getFalseBlock());
608  return success();
609  }
610 };
611 
612 /// Converts `spirv.CompositeExtract` to `llvm.extractvalue` if the container
613 /// type is not a vector (e.g. a struct or array). Otherwise, converts to
614 /// `llvm.extractelement` that operates on vectors, using only the first
/// index as an i32 lane constant.
615 class CompositeExtractPattern
616  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
617 public:
619 
621  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
622  ConversionPatternRewriter &rewriter) const override {
623  auto dstType = this->typeConverter.convertType(op.getType());
624  if (!dstType)
625  return rewriter.notifyMatchFailure(op, "type conversion failed");
626 
627  Type containerType = op.getComposite().getType();
628  if (isa<VectorType>(containerType)) {
629  Location loc = op.getLoc();
630  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
631  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
632  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
633  op, dstType, adaptor.getComposite(), index);
634  return success();
635  }
636 
// Aggregate case: the result type is left to `llvm.extractvalue` to infer.
637  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
638  op, adaptor.getComposite(),
639  LLVM::convertArrayToIndices(op.getIndices()));
640  return success();
641  }
642 };
643 
644 /// Converts `spirv.CompositeInsert` to `llvm.insertvalue` if the container
645 /// type is not a vector (e.g. a struct or array). Otherwise, converts to
646 /// `llvm.insertelement` that operates on vectors, using only the first
/// index as an i32 lane constant.
647 class CompositeInsertPattern
648  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
649 public:
651 
653  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
654  ConversionPatternRewriter &rewriter) const override {
655  auto dstType = this->typeConverter.convertType(op.getType());
656  if (!dstType)
657  return rewriter.notifyMatchFailure(op, "type conversion failed");
658 
659  Type containerType = op.getComposite().getType();
660  if (isa<VectorType>(containerType)) {
661  Location loc = op.getLoc();
662  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
663  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
664  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
665  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
666  return success();
667  }
668 
669  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
670  op, adaptor.getComposite(), adaptor.getObject(),
671  LLVM::convertArrayToIndices(op.getIndices()));
672  return success();
673  }
674 };
675 
676 /// Converts SPIR-V operations that have straightforward LLVM equivalent
677 /// into LLVM dialect operations.
/// The new `LLVMOp` gets the converted result type and the converted operands.
/// All attributes of the SPIR-V op are copied over verbatim.
678 template <typename SPIRVOp, typename LLVMOp>
679 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
680 public:
682 
684  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
685  ConversionPatternRewriter &rewriter) const override {
686  auto dstType = this->typeConverter.convertType(op.getType());
687  if (!dstType)
688  return rewriter.notifyMatchFailure(op, "type conversion failed");
689  rewriter.template replaceOpWithNewOp<LLVMOp>(
690  op, dstType, adaptor.getOperands(), op->getAttrs());
691  return success();
692  }
693 };
694 
695 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
696 /// execution mode information.
/// The global is an externally linked constant `llvm.mlir.global`, created at
/// the start of the enclosing builtin `ModuleOp`'s body with an unknown
/// location. Its initializer fills in the numeric execution mode followed by
/// the optional extra values. The original op is erased.
697 class ExecutionModePattern
698  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
699 public:
701 
703  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
704  ConversionPatternRewriter &rewriter) const override {
705  // First, create the global struct's name that would be associated with
706  // this entry point's execution mode. We set it to be:
707  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
// where {mode} is the numeric value of the execution mode. If the module has
// no name, the "{SPIR-V module name}_" part is omitted, giving
// __spv__{function name}_execution_mode_info_{mode}.
// NOTE(review): assumes a builtin ModuleOp encloses `op`; `module` is not
// null-checked before use.
708  ModuleOp module = op->getParentOfType<ModuleOp>();
709  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
710  std::string moduleName;
711  if (module.getName().has_value())
712  moduleName = "_" + module.getName()->str();
713  else
714  moduleName = "";
715  std::string executionModeInfoName = llvm::formatv(
716  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
717  static_cast<uint32_t>(executionModeAttr.getValue()));
718 
719  MLIRContext *context = rewriter.getContext();
720  OpBuilder::InsertionGuard guard(rewriter);
721  rewriter.setInsertionPointToStart(module.getBody());
722 
723  // Create a struct type, corresponding to the C struct below.
724  // struct {
725  // int32_t executionMode;
726  // int32_t values[]; // optional values
727  // };
728  auto llvmI32Type = IntegerType::get(context, 32);
729  SmallVector<Type, 2> fields;
730  fields.push_back(llvmI32Type);
731  ArrayAttr values = op.getValues();
732  if (!values.empty()) {
733  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
734  fields.push_back(arrayType);
735  }
736  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
737 
738  // Create `llvm.mlir.global` with initializer region containing one block.
739  auto global = rewriter.create<LLVM::GlobalOp>(
740  UnknownLoc::get(context), structType, /*isConstant=*/true,
741  LLVM::Linkage::External, executionModeInfoName, Attribute(),
742  /*alignment=*/0);
743  Location loc = global.getLoc();
744  Region &region = global.getInitializerRegion();
745  Block *block = rewriter.createBlock(&region);
746 
747  // Initialize the struct and set the execution mode value.
748  rewriter.setInsertionPoint(block, block->begin());
749  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
750  Value executionMode = rewriter.create<LLVM::ConstantOp>(
751  loc, llvmI32Type,
752  rewriter.getI32IntegerAttr(
753  static_cast<uint32_t>(executionModeAttr.getValue())));
754  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
755  executionMode, 0);
756 
757  // Insert extra operands if they exist into execution mode info struct.
// Each value goes to position {1, i}: element i of the array field.
758  for (unsigned i = 0, e = values.size(); i < e; ++i) {
759  auto attr = values.getValue()[i];
760  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
761  structValue = rewriter.create<LLVM::InsertValueOp>(
762  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
763  }
764  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
765  rewriter.eraseOp(op);
766  return success();
767  }
768 };
769 
770 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
771 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
772 /// value. This difference is handled by `spirv.mlir.addressof` and
773 /// `llvm.mlir.addressof` ops that both return a pointer.
/// Globals with an initializer, or with a storage class other than Input,
/// Private, Output, StorageBuffer or UniformConstant, are not converted. The
/// LLVM address space comes from `mapToAddressSpace` for the pattern's client
/// API.
774 class GlobalVariablePattern
775  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
776 public:
777  template <typename... Args>
778  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
779  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
780  std::forward<Args>(args)...),
781  clientAPI(clientAPI) {}
782 
784  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
785  ConversionPatternRewriter &rewriter) const override {
786  // Currently, there is no support of initialization with a constant value in
787  // SPIR-V dialect. Specialization constants are not considered as well.
788  if (op.getInitializer())
789  return failure();
790 
791  auto srcType = cast<spirv::PointerType>(op.getType());
792  auto dstType = typeConverter.convertType(srcType.getPointeeType());
793  if (!dstType)
794  return rewriter.notifyMatchFailure(op, "type conversion failed");
795 
796  // Limit conversion to the current invocation only or `StorageBuffer`
797  // required by SPIR-V runner.
798  // This is okay because multiple invocations are not supported yet.
799  auto storageClass = srcType.getStorageClass();
800  switch (storageClass) {
801  case spirv::StorageClass::Input:
802  case spirv::StorageClass::Private:
803  case spirv::StorageClass::Output:
804  case spirv::StorageClass::StorageBuffer:
805  case spirv::StorageClass::UniformConstant:
806  break;
807  default:
808  return failure();
809  }
810 
811  // LLVM dialect spec: "If the global value is a constant, storing into it is
812  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
813  // storage class that is read-only.
814  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
815  (storageClass == spirv::StorageClass::UniformConstant);
816  // SPIR-V spec: "By default, functions and global variables are private to a
817  // module and cannot be accessed by other modules. However, a module may be
818  // written to export or import functions and global (module scope)
819  // variables.". Therefore, map 'Private' storage class to private linkage,
820  // and every other accepted storage class ('Input', 'Output',
  // 'StorageBuffer', 'UniformConstant') to external linkage.
821  auto linkage = storageClass == spirv::StorageClass::Private
822  ? LLVM::Linkage::Private
823  : LLVM::Linkage::External;
824  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
825  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
826  /*alignment=*/0, mapToAddressSpace(clientAPI, storageClass));
827 
828  // Attach location attribute if applicable
829  if (op.getLocationAttr())
830  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
831 
832  return success();
833  }
834 
835 private:
// Client API used to choose the LLVM address space of converted globals.
836  spirv::ClientAPI clientAPI;
837 };
838 
839 /// Converts SPIR-V cast ops that do not have straightforward LLVM
840 /// equivalent in LLVM dialect.
841 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
842 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
843 public:
845 
847  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
848  ConversionPatternRewriter &rewriter) const override {
849 
850  Type fromType = op.getOperand().getType();
851  Type toType = op.getType();
852 
853  auto dstType = this->typeConverter.convertType(toType);
854  if (!dstType)
855  return rewriter.notifyMatchFailure(op, "type conversion failed");
856 
857  if (getBitWidth(fromType) < getBitWidth(toType)) {
858  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
859  adaptor.getOperands());
860  return success();
861  }
862  if (getBitWidth(fromType) > getBitWidth(toType)) {
863  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
864  adaptor.getOperands());
865  return success();
866  }
867  return failure();
868  }
869 };
870 
871 class FunctionCallPattern
872  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
873 public:
875 
877  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
878  ConversionPatternRewriter &rewriter) const override {
879  if (callOp.getNumResults() == 0) {
880  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
881  callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
882  return success();
883  }
884 
885  // Function returns a single result.
886  auto dstType = typeConverter.convertType(callOp.getType(0));
887  if (!dstType)
888  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
889  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
890  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
891  return success();
892  }
893 };
894 
895 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
896 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
897 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
898 public:
900 
902  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
903  ConversionPatternRewriter &rewriter) const override {
904 
905  auto dstType = this->typeConverter.convertType(op.getType());
906  if (!dstType)
907  return rewriter.notifyMatchFailure(op, "type conversion failed");
908 
909  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
910  op, dstType, predicate, op.getOperand1(), op.getOperand2());
911  return success();
912  }
913 };
914 
915 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
916 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
917 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
918 public:
920 
922  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
923  ConversionPatternRewriter &rewriter) const override {
924 
925  auto dstType = this->typeConverter.convertType(op.getType());
926  if (!dstType)
927  return rewriter.notifyMatchFailure(op, "type conversion failed");
928 
929  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
930  op, dstType, predicate, op.getOperand1(), op.getOperand2());
931  return success();
932  }
933 };
934 
935 class InverseSqrtPattern
936  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
937 public:
939 
941  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
942  ConversionPatternRewriter &rewriter) const override {
943  auto srcType = op.getType();
944  auto dstType = typeConverter.convertType(srcType);
945  if (!dstType)
946  return rewriter.notifyMatchFailure(op, "type conversion failed");
947 
948  Location loc = op.getLoc();
949  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
950  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
951  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
952  return success();
953  }
954 };
955 
956 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
957 template <typename SPIRVOp>
958 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
959 public:
961 
963  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
964  ConversionPatternRewriter &rewriter) const override {
965  if (!op.getMemoryAccess()) {
966  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
967  this->typeConverter, /*alignment=*/0,
968  /*isVolatile=*/false,
969  /*isNonTemporal=*/false);
970  }
971  auto memoryAccess = *op.getMemoryAccess();
972  switch (memoryAccess) {
973  case spirv::MemoryAccess::Aligned:
975  case spirv::MemoryAccess::Nontemporal:
976  case spirv::MemoryAccess::Volatile: {
977  unsigned alignment =
978  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
979  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
980  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
981  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
982  this->typeConverter, alignment, isVolatile,
983  isNonTemporal);
984  }
985  default:
986  // There is no support of other memory access attributes.
987  return failure();
988  }
989  }
990 };
991 
992 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
993 template <typename SPIRVOp>
994 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
995 public:
997 
999  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
1000  ConversionPatternRewriter &rewriter) const override {
1001  auto srcType = notOp.getType();
1002  auto dstType = this->typeConverter.convertType(srcType);
1003  if (!dstType)
1004  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
1005 
1006  Location loc = notOp.getLoc();
1007  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
1008  auto mask =
1009  isa<VectorType>(srcType)
1010  ? rewriter.create<LLVM::ConstantOp>(
1011  loc, dstType,
1012  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
1013  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
1014  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
1015  notOp.getOperand(), mask);
1016  return success();
1017  }
1018 };
1019 
1020 /// A template pattern that erases the given `SPIRVOp`.
1021 template <typename SPIRVOp>
1022 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
1023 public:
1025 
1027  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1028  ConversionPatternRewriter &rewriter) const override {
1029  rewriter.eraseOp(op);
1030  return success();
1031  }
1032 };
1033 
1034 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1035 public:
1037 
1039  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1040  ConversionPatternRewriter &rewriter) const override {
1041  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1042  ArrayRef<Value>());
1043  return success();
1044  }
1045 };
1046 
1047 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1048 public:
1050 
1052  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1053  ConversionPatternRewriter &rewriter) const override {
1054  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1055  adaptor.getOperands());
1056  return success();
1057  }
1058 };
1059 
1060 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1061 /// should be reachable for conversion to succeed. The structure of the loop in
1062 /// LLVM dialect will be the following:
1063 ///
1064 /// +------------------------------------+
1065 /// | <code before spirv.mlir.loop> |
1066 /// | llvm.br ^header |
1067 /// +------------------------------------+
1068 /// |
1069 /// +----------------+ |
1070 /// | | |
1071 /// | V V
1072 /// | +------------------------------------+
1073 /// | | ^header: |
1074 /// | | <header code> |
1075 /// | | llvm.cond_br %cond, ^body, ^exit |
1076 /// | +------------------------------------+
1077 /// | |
1078 /// | |----------------------+
1079 /// | | |
1080 /// | V |
1081 /// | +------------------------------------+ |
1082 /// | | ^body: | |
1083 /// | | <body code> | |
1084 /// | | llvm.br ^continue | |
1085 /// | +------------------------------------+ |
1086 /// | | |
1087 /// | V |
1088 /// | +------------------------------------+ |
1089 /// | | ^continue: | |
1090 /// | | <continue code> | |
1091 /// | | llvm.br ^header | |
1092 /// | +------------------------------------+ |
1093 /// | | |
1094 /// +---------------+ +----------------------+
1095 /// |
1096 /// V
1097 /// +------------------------------------+
1098 /// | ^exit: |
1099 /// | llvm.br ^remaining |
1100 /// +------------------------------------+
1101 /// |
1102 /// V
1103 /// +------------------------------------+
1104 /// | ^remaining: |
1105 /// | <code after spirv.mlir.loop> |
1106 /// +------------------------------------+
1107 ///
1108 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1109 public:
1111 
1113  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1114  ConversionPatternRewriter &rewriter) const override {
1115  // There is no support of loop control at the moment.
1116  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1117  return failure();
1118 
1119  Location loc = loopOp.getLoc();
1120 
1121  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1122  // be used in `endBlock`.
1123  Block *currentBlock = rewriter.getBlock();
1124  auto position = Block::iterator(loopOp);
1125  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1126 
1127  // Remove entry block and create a branch in the current block going to the
1128  // header block.
1129  Block *entryBlock = loopOp.getEntryBlock();
1130  assert(entryBlock->getOperations().size() == 1);
1131  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1132  if (!brOp)
1133  return failure();
1134  Block *headerBlock = loopOp.getHeaderBlock();
1135  rewriter.setInsertionPointToEnd(currentBlock);
1136  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1137  rewriter.eraseBlock(entryBlock);
1138 
1139  // Branch from merge block to end block.
1140  Block *mergeBlock = loopOp.getMergeBlock();
1141  Operation *terminator = mergeBlock->getTerminator();
1142  ValueRange terminatorOperands = terminator->getOperands();
1143  rewriter.setInsertionPointToEnd(mergeBlock);
1144  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1145 
1146  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1147  rewriter.replaceOp(loopOp, endBlock->getArguments());
1148  return success();
1149  }
1150 };
1151 
1152 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1153 /// block. All blocks within selection should be reachable for conversion to
1154 /// succeed.
1155 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1156 public:
1158 
1160  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1161  ConversionPatternRewriter &rewriter) const override {
1162  // There is no support for `Flatten` or `DontFlatten` selection control at
1163  // the moment. This are just compiler hints and can be performed during the
1164  // optimization passes.
1165  if (op.getSelectionControl() != spirv::SelectionControl::None)
1166  return failure();
1167 
1168  // `spirv.mlir.selection` should have at least two blocks: one selection
1169  // header block and one merge block. If no blocks are present, or control
1170  // flow branches straight to merge block (two blocks are present), the op is
1171  // redundant and it is erased.
1172  if (op.getBody().getBlocks().size() <= 2) {
1173  rewriter.eraseOp(op);
1174  return success();
1175  }
1176 
1177  Location loc = op.getLoc();
1178 
1179  // Split the current block after `spirv.mlir.selection`. The remaining ops
1180  // will be used in `continueBlock`.
1181  auto *currentBlock = rewriter.getInsertionBlock();
1182  rewriter.setInsertionPointAfter(op);
1183  auto position = rewriter.getInsertionPoint();
1184  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1185 
1186  // Extract conditional branch information from the header block. By SPIR-V
1187  // dialect spec, it should contain `spirv.BranchConditional` or
1188  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1189  // moment in the SPIR-V dialect. Remove this block when finished.
1190  auto *headerBlock = op.getHeaderBlock();
1191  assert(headerBlock->getOperations().size() == 1);
1192  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1193  headerBlock->getOperations().front());
1194  if (!condBrOp)
1195  return failure();
1196  rewriter.eraseBlock(headerBlock);
1197 
1198  // Branch from merge block to continue block.
1199  auto *mergeBlock = op.getMergeBlock();
1200  Operation *terminator = mergeBlock->getTerminator();
1201  ValueRange terminatorOperands = terminator->getOperands();
1202  rewriter.setInsertionPointToEnd(mergeBlock);
1203  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1204 
1205  // Link current block to `true` and `false` blocks within the selection.
1206  Block *trueBlock = condBrOp.getTrueBlock();
1207  Block *falseBlock = condBrOp.getFalseBlock();
1208  rewriter.setInsertionPointToEnd(currentBlock);
1209  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1210  condBrOp.getTrueTargetOperands(),
1211  falseBlock,
1212  condBrOp.getFalseTargetOperands());
1213 
1214  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1215  rewriter.replaceOp(op, continueBlock->getArguments());
1216  return success();
1217  }
1218 };
1219 
1220 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1221 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1222 /// `Shift` is zero or sign extended to match this specification. Cases when
1223 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1224 template <typename SPIRVOp, typename LLVMOp>
1225 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1226 public:
1228 
1230  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1231  ConversionPatternRewriter &rewriter) const override {
1232 
1233  auto dstType = this->typeConverter.convertType(op.getType());
1234  if (!dstType)
1235  return rewriter.notifyMatchFailure(op, "type conversion failed");
1236 
1237  Type op1Type = op.getOperand1().getType();
1238  Type op2Type = op.getOperand2().getType();
1239 
1240  if (op1Type == op2Type) {
1241  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1242  adaptor.getOperands());
1243  return success();
1244  }
1245 
1246  std::optional<uint64_t> dstTypeWidth =
1248  std::optional<uint64_t> op2TypeWidth =
1250 
1251  if (!dstTypeWidth || !op2TypeWidth)
1252  return failure();
1253 
1254  Location loc = op.getLoc();
1255  Value extended;
1256  if (op2TypeWidth < dstTypeWidth) {
1257  if (isUnsignedIntegerOrVector(op2Type)) {
1258  extended = rewriter.template create<LLVM::ZExtOp>(
1259  loc, dstType, adaptor.getOperand2());
1260  } else {
1261  extended = rewriter.template create<LLVM::SExtOp>(
1262  loc, dstType, adaptor.getOperand2());
1263  }
1264  } else if (op2TypeWidth == dstTypeWidth) {
1265  extended = adaptor.getOperand2();
1266  } else {
1267  return failure();
1268  }
1269 
1270  Value result = rewriter.template create<LLVMOp>(
1271  loc, dstType, adaptor.getOperand1(), extended);
1272  rewriter.replaceOp(op, result);
1273  return success();
1274  }
1275 };
1276 
1277 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1278 public:
1280 
1282  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1283  ConversionPatternRewriter &rewriter) const override {
1284  auto dstType = typeConverter.convertType(tanOp.getType());
1285  if (!dstType)
1286  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1287 
1288  Location loc = tanOp.getLoc();
1289  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1290  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1291  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1292  return success();
1293  }
1294 };
1295 
1296 /// Convert `spirv.Tanh` to
1297 ///
1298 /// exp(2x) - 1
1299 /// -----------
1300 /// exp(2x) + 1
1301 ///
1302 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1303 public:
1305 
1307  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1308  ConversionPatternRewriter &rewriter) const override {
1309  auto srcType = tanhOp.getType();
1310  auto dstType = typeConverter.convertType(srcType);
1311  if (!dstType)
1312  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1313 
1314  Location loc = tanhOp.getLoc();
1315  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1316  Value multiplied =
1317  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1318  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1319  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1320  Value numerator =
1321  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1322  Value denominator =
1323  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1324  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1325  denominator);
1326  return success();
1327  }
1328 };
1329 
1330 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1331 public:
1333 
1335  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1336  ConversionPatternRewriter &rewriter) const override {
1337  auto srcType = varOp.getType();
1338  // Initialization is supported for scalars and vectors only.
1339  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1340  auto init = varOp.getInitializer();
1341  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1342  return failure();
1343 
1344  auto dstType = typeConverter.convertType(srcType);
1345  if (!dstType)
1346  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1347 
1348  Location loc = varOp.getLoc();
1349  Value size = createI32ConstantOf(loc, rewriter, 1);
1350  if (!init) {
1351  auto elementType = typeConverter.convertType(pointerTo);
1352  if (!elementType)
1353  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1354  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1355  size);
1356  return success();
1357  }
1358  auto elementType = typeConverter.convertType(pointerTo);
1359  if (!elementType)
1360  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1361  Value allocated =
1362  rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1363  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1364  rewriter.replaceOp(varOp, allocated);
1365  return success();
1366  }
1367 };
1368 
1369 //===----------------------------------------------------------------------===//
1370 // BitcastOp conversion
1371 //===----------------------------------------------------------------------===//
1372 
1373 class BitcastConversionPattern
1374  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1375 public:
1377 
1379  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1380  ConversionPatternRewriter &rewriter) const override {
1381  auto dstType = typeConverter.convertType(bitcastOp.getType());
1382  if (!dstType)
1383  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1384 
1385  // LLVM's opaque pointers do not require bitcasts.
1386  if (isa<LLVM::LLVMPointerType>(dstType)) {
1387  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1388  return success();
1389  }
1390 
1391  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1392  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1393  return success();
1394  }
1395 };
1396 
1397 //===----------------------------------------------------------------------===//
1398 // FuncOp conversion
1399 //===----------------------------------------------------------------------===//
1400 
1401 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1402 public:
1404 
1406  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1407  ConversionPatternRewriter &rewriter) const override {
1408 
1409  // Convert function signature. At the moment LLVMType converter is enough
1410  // for currently supported types.
1411  auto funcType = funcOp.getFunctionType();
1412  TypeConverter::SignatureConversion signatureConverter(
1413  funcType.getNumInputs());
1414  auto llvmType = typeConverter.convertFunctionSignature(
1415  funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
1416  signatureConverter);
1417  if (!llvmType)
1418  return failure();
1419 
1420  // Create a new `LLVMFuncOp`
1421  Location loc = funcOp.getLoc();
1422  StringRef name = funcOp.getName();
1423  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1424 
1425  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1426  MLIRContext *context = funcOp.getContext();
1427  switch (funcOp.getFunctionControl()) {
1428 #define DISPATCH(functionControl, llvmAttr) \
1429  case functionControl: \
1430  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1431  break;
1432 
1433  DISPATCH(spirv::FunctionControl::Inline,
1434  StringAttr::get(context, "alwaysinline"));
1435  DISPATCH(spirv::FunctionControl::DontInline,
1436  StringAttr::get(context, "noinline"));
1437  DISPATCH(spirv::FunctionControl::Pure,
1438  StringAttr::get(context, "readonly"));
1439  DISPATCH(spirv::FunctionControl::Const,
1440  StringAttr::get(context, "readnone"));
1441 
1442 #undef DISPATCH
1443 
1444  // Default: if `spirv::FunctionControl::None`, then no attributes are
1445  // needed.
1446  default:
1447  break;
1448  }
1449 
1450  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1451  newFuncOp.end());
1452  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1453  &signatureConverter))) {
1454  return failure();
1455  }
1456  rewriter.eraseOp(funcOp);
1457  return success();
1458  }
1459 };
1460 
1461 //===----------------------------------------------------------------------===//
1462 // ModuleOp conversion
1463 //===----------------------------------------------------------------------===//
1464 
1465 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1466 public:
1468 
1470  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1471  ConversionPatternRewriter &rewriter) const override {
1472 
1473  auto newModuleOp =
1474  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1475  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1476 
1477  // Remove the terminator block that was automatically added by builder
1478  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1479  rewriter.eraseOp(spvModuleOp);
1480  return success();
1481  }
1482 };
1483 
1484 //===----------------------------------------------------------------------===//
1485 // VectorShuffleOp conversion
1486 //===----------------------------------------------------------------------===//
1487 
1488 class VectorShufflePattern
1489  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1490 public:
1493  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1494  ConversionPatternRewriter &rewriter) const override {
1495  Location loc = op.getLoc();
1496  auto components = adaptor.getComponents();
1497  auto vector1 = adaptor.getVector1();
1498  auto vector2 = adaptor.getVector2();
1499  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1500  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1501  if (vector1Size == vector2Size) {
1502  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1503  op, vector1, vector2,
1504  LLVM::convertArrayToIndices<int32_t>(components));
1505  return success();
1506  }
1507 
1508  auto dstType = typeConverter.convertType(op.getType());
1509  if (!dstType)
1510  return rewriter.notifyMatchFailure(op, "type conversion failed");
1511  auto scalarType = cast<VectorType>(dstType).getElementType();
1512  auto componentsArray = components.getValue();
1513  auto *context = rewriter.getContext();
1514  auto llvmI32Type = IntegerType::get(context, 32);
1515  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1516  for (unsigned i = 0; i < componentsArray.size(); i++) {
1517  if (!isa<IntegerAttr>(componentsArray[i]))
1518  return op.emitError("unable to support non-constant component");
1519 
1520  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1521  if (indexVal == -1)
1522  continue;
1523 
1524  int offsetVal = 0;
1525  Value baseVector = vector1;
1526  if (indexVal >= vector1Size) {
1527  offsetVal = vector1Size;
1528  baseVector = vector2;
1529  }
1530 
1531  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1532  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1533  Value index = rewriter.create<LLVM::ConstantOp>(
1534  loc, llvmI32Type,
1535  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1536 
1537  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1538  loc, scalarType, baseVector, index);
1539  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1540  extractOp, dstIndex);
1541  }
1542  rewriter.replaceOp(op, targetOp);
1543  return success();
1544  }
1545 };
1546 } // namespace
1547 
1548 //===----------------------------------------------------------------------===//
1549 // Pattern population
1550 //===----------------------------------------------------------------------===//
1551 
1553  spirv::ClientAPI clientAPI) {
1554  typeConverter.addConversion([&](spirv::ArrayType type) {
1555  return convertArrayType(type, typeConverter);
1556  });
1557  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1558  return convertPointerType(type, typeConverter, clientAPI);
1559  });
1560  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1561  return convertRuntimeArrayType(type, typeConverter);
1562  });
1563  typeConverter.addConversion([&](spirv::StructType type) {
1564  return convertStructType(type, typeConverter);
1565  });
1566 }
1567 
1569  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1570  spirv::ClientAPI clientAPI) {
1571  patterns.add<
1572  // Arithmetic ops
1573  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1574  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1575  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1576  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1577  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1578  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1579  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1580  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1581  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1582  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1583  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1584  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1585  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1586 
1587  // Bitwise ops
1588  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1589  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1590  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1591  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1592  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1593  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1594  NotPattern<spirv::NotOp>,
1595 
1596  // Cast ops
1597  BitcastConversionPattern,
1598  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1599  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1600  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1601  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1602  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1603  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1604  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1605 
1606  // Comparison ops
1607  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1608  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1609  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1610  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1611  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1612  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1613  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1614  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1615  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1616  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1617  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1618  LLVM::FCmpPredicate::uge>,
1619  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1620  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1621  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1622  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1623  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1624  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1625  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1626  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1627  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1628  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1629  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1630 
1631  // Constant op
1632  ConstantScalarAndVectorPattern,
1633 
1634  // Control Flow ops
1635  BranchConversionPattern, BranchConditionalConversionPattern,
1636  FunctionCallPattern, LoopPattern, SelectionPattern,
1637  ErasePattern<spirv::MergeOp>,
1638 
1639  // Entry points and execution mode are handled separately.
1640  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1641 
1642  // GLSL extended instruction set ops
1643  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1644  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1645  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1646  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1647  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1648  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1649  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1650  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1651  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1652  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1653  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1654  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1655  InverseSqrtPattern, TanPattern, TanhPattern,
1656 
1657  // Logical ops
1658  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1659  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1660  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1661  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1662  NotPattern<spirv::LogicalNotOp>,
1663 
1664  // Memory ops
1665  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1666  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1667 
1668  // Miscellaneous ops
1669  CompositeExtractPattern, CompositeInsertPattern,
1670  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1671  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1672  VectorShufflePattern,
1673 
1674  // Shift ops
1675  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1676  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1677  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1678 
1679  // Return ops
1680  ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1681 
1682  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1683  typeConverter);
1684 }
1685 
1687  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1688  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1689 }
1690 
1692  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1693  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1694 }
1695 
1696 //===----------------------------------------------------------------------===//
1697 // Pre-conversion hooks
1698 //===----------------------------------------------------------------------===//
1699 
1700 /// Hook for descriptor set and binding number encoding.
1701 static constexpr StringRef kBinding = "binding";
1702 static constexpr StringRef kDescriptorSet = "descriptor_set";
1703 void mlir::encodeBindAttribute(ModuleOp module) {
1704  auto spvModules = module.getOps<spirv::ModuleOp>();
1705  for (auto spvModule : spvModules) {
1706  spvModule.walk([&](spirv::GlobalVariableOp op) {
1707  IntegerAttr descriptorSet =
1708  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1709  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1710  // For every global variable in the module, get the ones with descriptor
1711  // set and binding numbers.
1712  if (descriptorSet && binding) {
1713  // Encode these numbers into the variable's symbolic name. If the
1714  // SPIR-V module has a name, add it at the beginning.
1715  auto moduleAndName =
1716  spvModule.getName().has_value()
1717  ? spvModule.getName()->str() + "_" + op.getSymName().str()
1718  : op.getSymName().str();
1719  std::string name =
1720  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1721  std::to_string(descriptorSet.getInt()),
1722  std::to_string(binding.getInt()));
1723  auto nameAttr = StringAttr::get(op->getContext(), name);
1724 
1725  // Replace all symbol uses and set the new symbol name. Finally, remove
1726  // descriptor set and binding attributes.
1727  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1728  op.emitError("unable to replace all symbol uses for ") << name;
1729  SymbolTable::setSymbolName(op, nameAttr);
1731  op->removeAttr(kBinding);
1732  }
1733  });
1734  }
1735 }
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass)
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:72
constexpr unsigned defaultAddressSpace
Definition: SPIRVToLLVM.cpp:36
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
#define STORAGE_SPACE_MAP(storage, space)
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:43
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:52
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
static Type convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:93
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:85
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:62
#define CLIENT_MAP(client, storage)
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
iterator begin()
Definition: Block.h:140
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:450
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:595
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:79
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:91
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:487
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:548
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:876
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:220
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:892
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26