MLIR  18.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Constants
34 //===----------------------------------------------------------------------===//
35 
/// Address space used whenever no client-API-specific storage class mapping
/// applies (the fallback of the address-space mapping helpers).
static constexpr unsigned defaultAddressSpace{0};
37 
38 //===----------------------------------------------------------------------===//
39 // Utility functions
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given type is a signed integer or vector type.
43 static bool isSignedIntegerOrVector(Type type) {
44  if (type.isSignedInteger())
45  return true;
46  if (auto vecType = dyn_cast<VectorType>(type))
47  return vecType.getElementType().isSignedInteger();
48  return false;
49 }
50 
51 /// Returns true if the given type is an unsigned integer or vector type
52 static bool isUnsignedIntegerOrVector(Type type) {
53  if (type.isUnsignedInteger())
54  return true;
55  if (auto vecType = dyn_cast<VectorType>(type))
56  return vecType.getElementType().isUnsignedInteger();
57  return false;
58 }
59 
60 /// Returns the width of an integer or of the element type of an integer vector,
61 /// if applicable.
62 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
63  if (auto intType = dyn_cast<IntegerType>(type))
64  return intType.getWidth();
65  if (auto vecType = dyn_cast<VectorType>(type))
66  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
67  return intType.getWidth();
68  return std::nullopt;
69 }
70 
71 /// Returns the bit width of integer, float or vector of float or integer values
72 static unsigned getBitWidth(Type type) {
73  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
74  "bitwidth is not supported for this type");
75  if (type.isIntOrFloat())
76  return type.getIntOrFloatBitWidth();
77  auto vecType = dyn_cast<VectorType>(type);
78  auto elementType = vecType.getElementType();
79  assert(elementType.isIntOrFloat() &&
80  "only integers and floats have a bitwidth");
81  return elementType.getIntOrFloatBitWidth();
82 }
83 
84 /// Returns the bit width of LLVMType integer or vector.
/// The selected type is `cast` to `IntegerType`, so a non-integer (or a vector
/// of non-integers) trips the cast assertion. For LLVM-compatible vector types
/// the width presumably comes from the element type — NOTE(review): confirm
/// against the vector branch of the conditional.
85 static unsigned getLLVMTypeBitWidth(Type type) {
86  return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
88  : type))
89  .getWidth();
90 }
91 
92 /// Creates `IntegerAttribute` with all bits set for given type
93 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
94  if (auto vecType = dyn_cast<VectorType>(type)) {
95  auto integerType = cast<IntegerType>(vecType.getElementType());
96  return builder.getIntegerAttr(integerType, -1);
97  }
98  auto integerType = cast<IntegerType>(type);
99  return builder.getIntegerAttr(integerType, -1);
100 }
101 
102 /// Creates `llvm.mlir.constant` with all bits set for the given type.
103 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
104  PatternRewriter &rewriter) {
105  if (isa<VectorType>(srcType)) {
106  return rewriter.create<LLVM::ConstantOp>(
107  loc, dstType,
108  SplatElementsAttr::get(cast<ShapedType>(srcType),
109  minusOneIntegerAttribute(srcType, rewriter)));
110  }
111  return rewriter.create<LLVM::ConstantOp>(
112  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
113 }
114 
115 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
116 static Value createFPConstant(Location loc, Type srcType, Type dstType,
117  PatternRewriter &rewriter, double value) {
118  if (auto vecType = dyn_cast<VectorType>(srcType)) {
119  auto floatType = cast<FloatType>(vecType.getElementType());
120  return rewriter.create<LLVM::ConstantOp>(
121  loc, dstType,
122  SplatElementsAttr::get(vecType,
123  rewriter.getFloatAttr(floatType, value)));
124  }
125  auto floatType = cast<FloatType>(srcType);
126  return rewriter.create<LLVM::ConstantOp>(
127  loc, dstType, rewriter.getFloatAttr(floatType, value));
128 }
129 
130 /// Utility function for bitfield ops:
131 /// - `BitFieldInsert`
132 /// - `BitFieldSExtract`
133 /// - `BitFieldUExtract`
134 /// Truncates or extends the value. If the bitwidth of the value is the same as
135 /// `llvmType` bitwidth, the value remains unchanged.
/// Widening emits `llvm.zext`, narrowing emits `llvm.trunc`. The value's width
/// is measured with `getLLVMTypeBitWidth` when its type is already
/// LLVM-compatible, and with `getBitWidth` otherwise.
137  Type llvmType,
138  PatternRewriter &rewriter) {
139  auto srcType = value.getType();
140  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
141  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
142  ? getLLVMTypeBitWidth(srcType)
143  : getBitWidth(srcType);
144 
 // Count/Offset are treated as unsigned quantities: widen by zero-extension.
145  if (valueBitWidth < targetBitWidth)
146  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
147  // If the bit widths of `Count` and `Offset` are greater than the bit width
148  // of the target type, they are truncated. Truncation is safe since `Count`
149  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
150  // both values can be expressed in 8 bits.
151  if (valueBitWidth > targetBitWidth)
152  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
153  return value;
154 }
155 
156 /// Broadcasts the value to vector with `numElements` number of elements.
157 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
158  LLVMTypeConverter &typeConverter,
159  ConversionPatternRewriter &rewriter) {
160  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
161  auto llvmVectorType = typeConverter.convertType(vectorType);
162  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
163  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
164  for (unsigned i = 0; i < numElements; ++i) {
165  auto index = rewriter.create<LLVM::ConstantOp>(
166  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
167  broadcasted = rewriter.create<LLVM::InsertElementOp>(
168  loc, llvmVectorType, broadcasted, toBroadcast, index);
169  }
170  return broadcasted;
171 }
172 
173 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
174 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
175  LLVMTypeConverter &typeConverter,
176  ConversionPatternRewriter &rewriter) {
177  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
178  unsigned numElements = vectorType.getNumElements();
179  return broadcast(loc, value, numElements, typeConverter, rewriter);
180  }
181  return value;
182 }
183 
184 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
185 /// `BitFieldUExtract`.
186 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
187 /// a vector type, construct a vector that has:
188 /// - same number of elements as `Base`
189 /// - each element has the type that is the same as the type of `Offset` or
190 /// `Count`
191 /// - each element has the same value as `Offset` or `Count`
192 /// Then cast `Offset` and `Count` if their bit width is different
193 /// from `Base` bit width.
194 static Value processCountOrOffset(Location loc, Value value, Type srcType,
195  Type dstType, LLVMTypeConverter &converter,
196  ConversionPatternRewriter &rewriter) {
197  Value broadcasted =
198  optionallyBroadcast(loc, value, srcType, converter, rewriter);
199  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
200 }
201 
202 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
203 /// offset to LLVM struct. Otherwise, the conversion is not supported.
/// Returns a non-packed literal LLVM struct of the converted member types, or
/// a null Type when the struct's layout differs from the one
/// `VulkanLayoutUtils::decorateType` would produce, or when any member type
/// fails to convert.
205  LLVMTypeConverter &converter) {
 // Only the "natural" Vulkan layout can be expressed by a non-packed struct.
206  if (type != VulkanLayoutUtils::decorateType(type))
207  return nullptr;
208 
209  SmallVector<Type> elementsVector;
210  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
211  return nullptr;
212  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
213  /*isPacked=*/false);
214 }
215 
216 /// Converts SPIR-V struct with no offset to packed LLVM struct.
/// Returns a packed literal LLVM struct of the converted member types, or a
/// null Type when any member type fails to convert.
218  LLVMTypeConverter &converter) {
219  SmallVector<Type> elementsVector;
220  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
221  return nullptr;
222  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
223  /*isPacked=*/true);
224 }
225 
226 /// Creates an `llvm.mlir.constant` of type i32 holding `value`.
/// Used, for example, to materialize vector lane indices.
228  unsigned value) {
229  return rewriter.create<LLVM::ConstantOp>(
230  loc, IntegerType::get(rewriter.getContext(), 32),
231  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
232 }
233 
234 /// Utility for `spirv.Load` and `spirv.Store` conversion.
/// If `op` is a `spirv.Load`, it is replaced by `llvm.load` of the converted
/// result type (failing if that type does not convert). Otherwise `op` must be
/// a `spirv.Store` (enforced by `cast`) and is replaced by `llvm.store`.
/// Pointer and value operands are read from `operands` (the already-converted
/// operands) via the op adaptors; `alignment`, `isVolatile` and
/// `isNonTemporal` are forwarded unchanged to the LLVM op.
236  ConversionPatternRewriter &rewriter,
237  LLVMTypeConverter &typeConverter,
238  unsigned alignment, bool isVolatile,
239  bool isNonTemporal) {
240  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
241  auto dstType = typeConverter.convertType(loadOp.getType());
242  if (!dstType)
243  return failure();
244  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
245  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
246  isVolatile, isNonTemporal);
247  return success();
248  }
249  auto storeOp = cast<spirv::StoreOp>(op);
250  spirv::StoreOpAdaptor adaptor(operands);
251  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
252  adaptor.getPtr(), alignment,
253  isVolatile, isNonTemporal);
254  return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // Type conversion
259 //===----------------------------------------------------------------------===//
260 
261 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
262 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
263 /// when converting ops that manipulate array types.
264 static std::optional<Type> convertArrayType(spirv::ArrayType type,
265  TypeConverter &converter) {
266  unsigned stride = type.getArrayStride();
267  Type elementType = type.getElementType();
268  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
269  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
270  return std::nullopt;
271 
272  auto llvmElementType = converter.convertType(elementType);
273  unsigned numElements = type.getNumElements();
274  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
275 }
276 
277 static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) {
278  // Based on
279  // https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form
280  // and clang/lib/Basic/Targets/SPIR.h.
281  switch (storageClass) {
282 #define STORAGE_SPACE_MAP(storage, space) \
283  case spirv::StorageClass::storage: \
284  return space;
285  STORAGE_SPACE_MAP(Function, 0)
286  STORAGE_SPACE_MAP(CrossWorkgroup, 1)
287  STORAGE_SPACE_MAP(Input, 1)
288  STORAGE_SPACE_MAP(UniformConstant, 2)
289  STORAGE_SPACE_MAP(Workgroup, 3)
290  STORAGE_SPACE_MAP(Generic, 4)
291  STORAGE_SPACE_MAP(DeviceOnlyINTEL, 5)
292  STORAGE_SPACE_MAP(HostOnlyINTEL, 6)
293 #undef STORAGE_SPACE_MAP
294  default:
295  return defaultAddressSpace;
296  }
297 }
298 
299 static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI,
300  spirv::StorageClass storageClass) {
301  switch (clientAPI) {
302 #define CLIENT_MAP(client, storage) \
303  case spirv::ClientAPI::client: \
304  return mapTo##client##AddressSpace(storage);
305  CLIENT_MAP(OpenCL, storageClass)
306 #undef CLIENT_MAP
307  default:
308  return defaultAddressSpace;
309  }
310 }
311 
312 /// Converts SPIR-V pointer type to an opaque LLVM pointer. The pointer's
313 /// storage class is mapped to an LLVM address space via `mapToAddressSpace`
/// for the given `clientAPI`; the pointee type is not encoded in the result,
/// and `converter` is not used by the visible body.
315  LLVMTypeConverter &converter,
316  spirv::ClientAPI clientAPI) {
317  unsigned addressSpace = mapToAddressSpace(clientAPI, type.getStorageClass());
318  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
319 }
320 
321 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
322 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
323 /// no modelling of array stride at the moment.
324 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
325  TypeConverter &converter) {
326  if (type.getArrayStride() != 0)
327  return std::nullopt;
328  auto elementType = converter.convertType(type.getElementType());
329  return LLVM::LLVMArrayType::get(elementType, 0);
330 }
331 
332 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
333 /// member decorations. Also, only natural offset is supported.
/// Returns a null Type for structs carrying member decorations. Structs with
/// offsets go through `convertStructTypeWithOffset` (non-packed, natural
/// layout only); structs without offsets become packed LLVM structs.
335  LLVMTypeConverter &converter) {
337  type.getMemberDecorations(memberDecorations);
338  if (!memberDecorations.empty())
339  return nullptr;
340  if (type.hasOffset())
341  return convertStructTypeWithOffset(type, converter);
342  return convertStructTypePacked(type, converter);
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // Operation conversion
347 //===----------------------------------------------------------------------===//
348 
349 namespace {
350 
351 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
352 public:
354 
356  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
357  ConversionPatternRewriter &rewriter) const override {
358  auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
359  if (!dstType)
360  return failure();
361  // To use GEP we need to add a first 0 index to go through the pointer.
362  auto indices = llvm::to_vector<4>(adaptor.getIndices());
363  Type indexType = op.getIndices().front().getType();
364  auto llvmIndexType = typeConverter.convertType(indexType);
365  if (!llvmIndexType)
366  return failure();
367  Value zero = rewriter.create<LLVM::ConstantOp>(
368  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
369  indices.insert(indices.begin(), zero);
370  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
371  op, dstType,
372  typeConverter.convertType(
373  cast<spirv::PointerType>(op.getBasePtr().getType())
374  .getPointeeType()),
375  adaptor.getBasePtr(), indices);
376  return success();
377  }
378 };
379 
/// Converts `spirv.mlir.addressof` to `llvm.mlir.addressof`. The result refers
/// to the same symbol (`op.getVariable()`), and its type is the converted
/// SPIR-V pointer type. Fails if that type cannot be converted.
380 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
381 public:
383 
385  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
386  ConversionPatternRewriter &rewriter) const override {
387  auto dstType = typeConverter.convertType(op.getPointer().getType());
388  if (!dstType)
389  return failure();
390  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
391  op.getVariable());
392  return success();
393  }
394 };
395 
396 class BitFieldInsertPattern
397  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
398 public:
400 
402  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
403  ConversionPatternRewriter &rewriter) const override {
404  auto srcType = op.getType();
405  auto dstType = typeConverter.convertType(srcType);
406  if (!dstType)
407  return failure();
408  Location loc = op.getLoc();
409 
410  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
411  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
412  typeConverter, rewriter);
413  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
414  typeConverter, rewriter);
415 
416  // Create a mask with bits set outside [Offset, Offset + Count - 1].
417  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
418  Value maskShiftedByCount =
419  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
420  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
421  maskShiftedByCount, minusOne);
422  Value maskShiftedByCountAndOffset =
423  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
424  Value mask = rewriter.create<LLVM::XOrOp>(
425  loc, dstType, maskShiftedByCountAndOffset, minusOne);
426 
427  // Extract unchanged bits from the `Base` that are outside of
428  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
429  Value baseAndMask =
430  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
431  Value insertShiftedByOffset =
432  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
433  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
434  insertShiftedByOffset);
435  return success();
436  }
437 };
438 
439 /// Converts SPIR-V ConstantOp with scalar or vector type.
/// Integer/float scalars and vectors become `llvm.mlir.constant`. Signed and
/// unsigned integer constants are re-typed as signless integers of the same
/// width; all other accepted constants reuse the original attributes as-is.
440 class ConstantScalarAndVectorPattern
441  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
442 public:
444 
446  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
447  ConversionPatternRewriter &rewriter) const override {
448  auto srcType = constOp.getType();
 // Composite (struct/array) constants are not handled by this pattern.
449  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
450  return failure();
451 
452  auto dstType = typeConverter.convertType(srcType);
453  if (!dstType)
454  return failure();
455 
456  // SPIR-V constant can be a signed/unsigned integer, which has to be
457  // casted to signless integer when converting to LLVM dialect. Removing the
458  // sign bit may have unexpected behaviour. However, it is better to handle
459  // it case-by-case, given that the purpose of the conversion is not to
460  // cover all possible corner cases.
 // Note: the APInt bit pattern is copied unchanged; only the signedness
 // annotation of the type is dropped, no bits are removed.
461  if (isSignedIntegerOrVector(srcType) ||
462  isUnsignedIntegerOrVector(srcType)) {
463  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
464 
465  if (isa<VectorType>(srcType)) {
 // Re-type every element: identity on the APInt value, new element type.
466  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
467  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
468  constOp, dstType,
469  dstElementsAttr.mapValues(
470  signlessType, [&](const APInt &value) { return value; }));
471  return success();
472  }
473  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
474  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
475  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
476  return success();
477  }
 // Signless integers and floats: forward the original attributes.
478  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
479  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
480  return success();
481  }
482 };
483 
484 class BitFieldSExtractPattern
485  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
486 public:
488 
490  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
491  ConversionPatternRewriter &rewriter) const override {
492  auto srcType = op.getType();
493  auto dstType = typeConverter.convertType(srcType);
494  if (!dstType)
495  return failure();
496  Location loc = op.getLoc();
497 
498  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
499  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
500  typeConverter, rewriter);
501  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
502  typeConverter, rewriter);
503 
504  // Create a constant that holds the size of the `Base`.
505  IntegerType integerType;
506  if (auto vecType = dyn_cast<VectorType>(srcType))
507  integerType = cast<IntegerType>(vecType.getElementType());
508  else
509  integerType = cast<IntegerType>(srcType);
510 
511  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
512  Value size =
513  isa<VectorType>(srcType)
514  ? rewriter.create<LLVM::ConstantOp>(
515  loc, dstType,
516  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
517  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
518 
519  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
520  // at Offset + Count - 1 is the most significant bit now.
521  Value countPlusOffset =
522  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
523  Value amountToShiftLeft =
524  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
525  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
526  loc, dstType, op.getBase(), amountToShiftLeft);
527 
528  // Shift the result right, filling the bits with the sign bit.
529  Value amountToShiftRight =
530  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
531  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
532  amountToShiftRight);
533  return success();
534  }
535 };
536 
537 class BitFieldUExtractPattern
538  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
539 public:
541 
543  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
544  ConversionPatternRewriter &rewriter) const override {
545  auto srcType = op.getType();
546  auto dstType = typeConverter.convertType(srcType);
547  if (!dstType)
548  return failure();
549  Location loc = op.getLoc();
550 
551  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
552  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
553  typeConverter, rewriter);
554  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
555  typeConverter, rewriter);
556 
557  // Create a mask with bits set at [0, Count - 1].
558  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
559  Value maskShiftedByCount =
560  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
561  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
562  minusOne);
563 
564  // Shift `Base` by `Offset` and apply the mask on it.
565  Value shiftedBase =
566  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
567  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
568  return success();
569  }
570 };
571 
/// Converts `spirv.Branch` to `llvm.br` targeting the same successor block,
/// forwarding the (already converted) block arguments from the adaptor.
572 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
573 public:
575 
577  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
578  ConversionPatternRewriter &rewriter) const override {
579  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
580  branchOp.getTarget());
581  return success();
582  }
583 };
584 
585 class BranchConditionalConversionPattern
586  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
587 public:
588  using SPIRVToLLVMConversion<
589  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
590 
592  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
593  ConversionPatternRewriter &rewriter) const override {
594  // If branch weights exist, map them to 32-bit integer vector.
595  DenseI32ArrayAttr branchWeights = nullptr;
596  if (auto weights = op.getBranchWeights()) {
597  SmallVector<int32_t> weightValues;
598  for (auto weight : weights->getAsRange<IntegerAttr>())
599  weightValues.push_back(weight.getInt());
600  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
601  }
602 
603  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
604  op, op.getCondition(), op.getTrueBlockArguments(),
605  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
606  op.getFalseBlock());
607  return success();
608  }
609 };
610 
611 /// Converts `spirv.CompositeExtract` to `llvm.extractvalue` if the container
612 /// type is an aggregate type (struct or array). Otherwise, converts to
613 /// `llvm.extractelement` that operates on vectors.
/// For vector containers only the first index is used, materialized as an i32
/// constant.
614 class CompositeExtractPattern
615  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
616 public:
618 
620  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
621  ConversionPatternRewriter &rewriter) const override {
622  auto dstType = this->typeConverter.convertType(op.getType());
623  if (!dstType)
624  return failure();
625 
626  Type containerType = op.getComposite().getType();
627  if (isa<VectorType>(containerType)) {
628  Location loc = op.getLoc();
629  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
630  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
631  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
632  op, dstType, adaptor.getComposite(), index);
633  return success();
634  }
635 
 // Aggregates: the static SPIR-V index list becomes the extractvalue position.
636  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
637  op, adaptor.getComposite(),
638  LLVM::convertArrayToIndices(op.getIndices()));
639  return success();
640  }
641 };
642 
643 /// Converts `spirv.CompositeInsert` to `llvm.insertvalue` if the container
644 /// type is an aggregate type (struct or array). Otherwise, converts to
645 /// `llvm.insertelement` that operates on vectors.
/// For vector containers only the first index is used, materialized as an i32
/// constant.
646 class CompositeInsertPattern
647  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
648 public:
650 
652  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
653  ConversionPatternRewriter &rewriter) const override {
654  auto dstType = this->typeConverter.convertType(op.getType());
655  if (!dstType)
656  return failure();
657 
658  Type containerType = op.getComposite().getType();
659  if (isa<VectorType>(containerType)) {
660  Location loc = op.getLoc();
661  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
662  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
663  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
664  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
665  return success();
666  }
667 
 // Aggregates: the static SPIR-V index list becomes the insertvalue position.
668  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
669  op, adaptor.getComposite(), adaptor.getObject(),
670  LLVM::convertArrayToIndices(op.getIndices()));
671  return success();
672  }
673 };
674 
675 /// Converts SPIR-V operations that have straightforward LLVM equivalent
676 /// into LLVM dialect operations.
/// The replacement `LLVMOp` receives the converted result type, the adaptor's
/// (converted) operands and all of the original op's attributes unchanged.
/// Fails only when the result type cannot be converted.
677 template <typename SPIRVOp, typename LLVMOp>
678 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
679 public:
681 
683  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
684  ConversionPatternRewriter &rewriter) const override {
685  auto dstType = this->typeConverter.convertType(operation.getType());
686  if (!dstType)
687  return failure();
688  rewriter.template replaceOpWithNewOp<LLVMOp>(
689  operation, dstType, adaptor.getOperands(), operation->getAttrs());
690  return success();
691  }
692 };
693 
694 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
695 /// execution mode information.
/// The global is created at the start of the enclosing `ModuleOp` with
/// external linkage. Its initializer builds `{i32 mode, [N x i32] values}`,
/// where the array member is present only when the op has extra values. The
/// original op is then erased.
696 class ExecutionModePattern
697  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
698 public:
700 
702  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
703  ConversionPatternRewriter &rewriter) const override {
704  // First, create the global struct's name that would be associated with
705  // this entry point's execution mode. We set it to be:
706  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
 // For an unnamed module the name segment is empty, giving
 // __spv__{function name}_execution_mode_info_{mode}.
707  ModuleOp module = op->getParentOfType<ModuleOp>();
708  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
709  std::string moduleName;
710  if (module.getName().has_value())
711  moduleName = "_" + module.getName()->str();
712  else
713  moduleName = "";
714  std::string executionModeInfoName = llvm::formatv(
715  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
716  static_cast<uint32_t>(executionModeAttr.getValue()));
717 
718  MLIRContext *context = rewriter.getContext();
 // Restore the rewriter's insertion point when this function returns.
719  OpBuilder::InsertionGuard guard(rewriter);
720  rewriter.setInsertionPointToStart(module.getBody());
721 
722  // Create a struct type, corresponding to the C struct below.
723  // struct {
724  // int32_t executionMode;
725  // int32_t values[]; // optional values
726  // };
 // The LLVM array member is sized to values.size(), not left unsized.
727  auto llvmI32Type = IntegerType::get(context, 32);
728  SmallVector<Type, 2> fields;
729  fields.push_back(llvmI32Type);
730  ArrayAttr values = op.getValues();
731  if (!values.empty()) {
732  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
733  fields.push_back(arrayType);
734  }
735  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
736 
737  // Create `llvm.mlir.global` with initializer region containing one block.
738  auto global = rewriter.create<LLVM::GlobalOp>(
739  UnknownLoc::get(context), structType, /*isConstant=*/true,
740  LLVM::Linkage::External, executionModeInfoName, Attribute(),
741  /*alignment=*/0);
742  Location loc = global.getLoc();
743  Region &region = global.getInitializerRegion();
744  Block *block = rewriter.createBlock(&region);
745 
746  // Initialize the struct and set the execution mode value.
747  rewriter.setInsertionPoint(block, block->begin());
748  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
749  Value executionMode = rewriter.create<LLVM::ConstantOp>(
750  loc, llvmI32Type,
751  rewriter.getI32IntegerAttr(
752  static_cast<uint32_t>(executionModeAttr.getValue())));
753  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
754  executionMode, 0);
755 
756  // Insert extra operands if they exist into execution mode info struct.
 // Position {1, i} addresses element i of the array member (field 1).
757  for (unsigned i = 0, e = values.size(); i < e; ++i) {
758  auto attr = values.getValue()[i];
759  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
760  structValue = rewriter.create<LLVM::InsertValueOp>(
761  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
762  }
763  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
764  rewriter.eraseOp(op);
765  return success();
766  }
767 };
768 
769 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
770 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
771 /// value. This difference is handled by `spirv.mlir.addressof` and
772 /// `llvm.mlir.addressof` ops that both return a pointer.
/// The LLVM global's address space comes from `mapToAddressSpace` for the
/// `clientAPI` supplied at construction.
773 class GlobalVariablePattern
774  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
775 public:
 // `clientAPI` selects the storage-class -> address-space mapping; all other
 // arguments are forwarded to the base conversion pattern.
776  template <typename... Args>
777  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
778  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
779  std::forward<Args>(args)...),
780  clientAPI(clientAPI) {}
781 
783  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
784  ConversionPatternRewriter &rewriter) const override {
785  // Currently, there is no support of initialization with a constant value in
786  // SPIR-V dialect. Specialization constants are not considered as well.
787  if (op.getInitializer())
788  return failure();
789 
790  auto srcType = cast<spirv::PointerType>(op.getType());
791  auto dstType = typeConverter.convertType(srcType.getPointeeType());
792  if (!dstType)
793  return failure();
794 
795  // Limit conversion to the current invocation only or `StorageBuffer`
796  // required by SPIR-V runner.
797  // This is okay because multiple invocations are not supported yet.
798  auto storageClass = srcType.getStorageClass();
799  switch (storageClass) {
800  case spirv::StorageClass::Input:
801  case spirv::StorageClass::Private:
802  case spirv::StorageClass::Output:
803  case spirv::StorageClass::StorageBuffer:
804  case spirv::StorageClass::UniformConstant:
805  break;
806  default:
807  return failure();
808  }
809 
810  // LLVM dialect spec: "If the global value is a constant, storing into it is
811  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
812  // storage class that is read-only.
813  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
814  (storageClass == spirv::StorageClass::UniformConstant);
815  // SPIR-V spec: "By default, functions and global variables are private to a
816  // module and cannot be accessed by other modules. However, a module may be
817  // written to export or import functions and global (module scope)
818  // variables.". Therefore, map 'Private' storage class to private linkage,
819  // and every other accepted storage class ('Input', 'Output',
 // 'StorageBuffer', 'UniformConstant') to external linkage.
820  auto linkage = storageClass == spirv::StorageClass::Private
821  ? LLVM::Linkage::Private
822  : LLVM::Linkage::External;
823  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
824  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
825  /*alignment=*/0, mapToAddressSpace(clientAPI, storageClass));
826 
827  // Attach location attribute if applicable
 // (the SPIR-V `location` decoration is copied verbatim onto the new global).
828  if (op.getLocationAttr())
829  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
830 
831  return success();
832  }
833 
834 private:
835  spirv::ClientAPI clientAPI;
836 };
837 
838 /// Converts SPIR-V cast ops that do not have straightforward LLVM
839 /// equivalent in LLVM dialect.
840 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
841 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
842 public:
844 
846  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
847  ConversionPatternRewriter &rewriter) const override {
848 
849  Type fromType = operation.getOperand().getType();
850  Type toType = operation.getType();
851 
852  auto dstType = this->typeConverter.convertType(toType);
853  if (!dstType)
854  return failure();
855 
856  if (getBitWidth(fromType) < getBitWidth(toType)) {
857  rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
858  adaptor.getOperands());
859  return success();
860  }
861  if (getBitWidth(fromType) > getBitWidth(toType)) {
862  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
863  adaptor.getOperands());
864  return success();
865  }
866  return failure();
867  }
868 };
869 
870 class FunctionCallPattern
871  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
872 public:
874 
876  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
877  ConversionPatternRewriter &rewriter) const override {
878  if (callOp.getNumResults() == 0) {
879  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
880  callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
881  return success();
882  }
883 
884  // Function returns a single result.
885  auto dstType = typeConverter.convertType(callOp.getType(0));
886  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
887  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
888  return success();
889  }
890 };
891 
892 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
893 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
894 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
895 public:
897 
899  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
900  ConversionPatternRewriter &rewriter) const override {
901 
902  auto dstType = this->typeConverter.convertType(operation.getType());
903  if (!dstType)
904  return failure();
905 
906  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
907  operation, dstType, predicate, operation.getOperand1(),
908  operation.getOperand2());
909  return success();
910  }
911 };
912 
913 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
914 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
915 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
916 public:
918 
920  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
921  ConversionPatternRewriter &rewriter) const override {
922 
923  auto dstType = this->typeConverter.convertType(operation.getType());
924  if (!dstType)
925  return failure();
926 
927  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
928  operation, dstType, predicate, operation.getOperand1(),
929  operation.getOperand2());
930  return success();
931  }
932 };
933 
934 class InverseSqrtPattern
935  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
936 public:
938 
940  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
941  ConversionPatternRewriter &rewriter) const override {
942  auto srcType = op.getType();
943  auto dstType = typeConverter.convertType(srcType);
944  if (!dstType)
945  return failure();
946 
947  Location loc = op.getLoc();
948  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
949  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
950  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
951  return success();
952  }
953 };
954 
955 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
956 template <typename SPIRVOp>
957 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
958 public:
960 
962  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
963  ConversionPatternRewriter &rewriter) const override {
964  if (!op.getMemoryAccess()) {
965  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
966  this->typeConverter, /*alignment=*/0,
967  /*isVolatile=*/false,
968  /*isNonTemporal=*/false);
969  }
970  auto memoryAccess = *op.getMemoryAccess();
971  switch (memoryAccess) {
972  case spirv::MemoryAccess::Aligned:
974  case spirv::MemoryAccess::Nontemporal:
975  case spirv::MemoryAccess::Volatile: {
976  unsigned alignment =
977  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
978  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
979  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
980  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
981  this->typeConverter, alignment, isVolatile,
982  isNonTemporal);
983  }
984  default:
985  // There is no support of other memory access attributes.
986  return failure();
987  }
988  }
989 };
990 
991 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
992 template <typename SPIRVOp>
993 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
994 public:
996 
998  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
999  ConversionPatternRewriter &rewriter) const override {
1000  auto srcType = notOp.getType();
1001  auto dstType = this->typeConverter.convertType(srcType);
1002  if (!dstType)
1003  return failure();
1004 
1005  Location loc = notOp.getLoc();
1006  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
1007  auto mask =
1008  isa<VectorType>(srcType)
1009  ? rewriter.create<LLVM::ConstantOp>(
1010  loc, dstType,
1011  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
1012  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
1013  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
1014  notOp.getOperand(), mask);
1015  return success();
1016  }
1017 };
1018 
1019 /// A template pattern that erases the given `SPIRVOp`.
1020 template <typename SPIRVOp>
1021 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
1022 public:
1024 
1026  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1027  ConversionPatternRewriter &rewriter) const override {
1028  rewriter.eraseOp(op);
1029  return success();
1030  }
1031 };
1032 
1033 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1034 public:
1036 
1038  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1039  ConversionPatternRewriter &rewriter) const override {
1040  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1041  ArrayRef<Value>());
1042  return success();
1043  }
1044 };
1045 
1046 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1047 public:
1049 
1051  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1052  ConversionPatternRewriter &rewriter) const override {
1053  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1054  adaptor.getOperands());
1055  return success();
1056  }
1057 };
1058 
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
/// should be reachable for conversion to succeed. The structure of the loop in
1061 /// LLVM dialect will be the following:
1062 ///
1063 /// +------------------------------------+
1064 /// | <code before spirv.mlir.loop> |
1065 /// | llvm.br ^header |
1066 /// +------------------------------------+
1067 /// |
1068 /// +----------------+ |
1069 /// | | |
1070 /// | V V
1071 /// | +------------------------------------+
1072 /// | | ^header: |
1073 /// | | <header code> |
1074 /// | | llvm.cond_br %cond, ^body, ^exit |
1075 /// | +------------------------------------+
1076 /// | |
1077 /// | |----------------------+
1078 /// | | |
1079 /// | V |
1080 /// | +------------------------------------+ |
1081 /// | | ^body: | |
1082 /// | | <body code> | |
1083 /// | | llvm.br ^continue | |
1084 /// | +------------------------------------+ |
1085 /// | | |
1086 /// | V |
1087 /// | +------------------------------------+ |
1088 /// | | ^continue: | |
1089 /// | | <continue code> | |
1090 /// | | llvm.br ^header | |
1091 /// | +------------------------------------+ |
1092 /// | | |
1093 /// +---------------+ +----------------------+
1094 /// |
1095 /// V
1096 /// +------------------------------------+
1097 /// | ^exit: |
1098 /// | llvm.br ^remaining |
1099 /// +------------------------------------+
1100 /// |
1101 /// V
1102 /// +------------------------------------+
1103 /// | ^remaining: |
1104 /// | <code after spirv.mlir.loop> |
1105 /// +------------------------------------+
1106 ///
1107 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1108 public:
1110 
1112  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1113  ConversionPatternRewriter &rewriter) const override {
1114  // There is no support of loop control at the moment.
1115  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1116  return failure();
1117 
1118  Location loc = loopOp.getLoc();
1119 
1120  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1121  // be used in `endBlock`.
1122  Block *currentBlock = rewriter.getBlock();
1123  auto position = Block::iterator(loopOp);
1124  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1125 
1126  // Remove entry block and create a branch in the current block going to the
1127  // header block.
1128  Block *entryBlock = loopOp.getEntryBlock();
1129  assert(entryBlock->getOperations().size() == 1);
1130  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1131  if (!brOp)
1132  return failure();
1133  Block *headerBlock = loopOp.getHeaderBlock();
1134  rewriter.setInsertionPointToEnd(currentBlock);
1135  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1136  rewriter.eraseBlock(entryBlock);
1137 
1138  // Branch from merge block to end block.
1139  Block *mergeBlock = loopOp.getMergeBlock();
1140  Operation *terminator = mergeBlock->getTerminator();
1141  ValueRange terminatorOperands = terminator->getOperands();
1142  rewriter.setInsertionPointToEnd(mergeBlock);
1143  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1144 
1145  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1146  rewriter.replaceOp(loopOp, endBlock->getArguments());
1147  return success();
1148  }
1149 };
1150 
1151 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1152 /// block. All blocks within selection should be reachable for conversion to
1153 /// succeed.
1154 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1155 public:
1157 
1159  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1160  ConversionPatternRewriter &rewriter) const override {
1161  // There is no support for `Flatten` or `DontFlatten` selection control at
1162  // the moment. This are just compiler hints and can be performed during the
1163  // optimization passes.
1164  if (op.getSelectionControl() != spirv::SelectionControl::None)
1165  return failure();
1166 
1167  // `spirv.mlir.selection` should have at least two blocks: one selection
1168  // header block and one merge block. If no blocks are present, or control
1169  // flow branches straight to merge block (two blocks are present), the op is
1170  // redundant and it is erased.
1171  if (op.getBody().getBlocks().size() <= 2) {
1172  rewriter.eraseOp(op);
1173  return success();
1174  }
1175 
1176  Location loc = op.getLoc();
1177 
1178  // Split the current block after `spirv.mlir.selection`. The remaining ops
1179  // will be used in `continueBlock`.
1180  auto *currentBlock = rewriter.getInsertionBlock();
1181  rewriter.setInsertionPointAfter(op);
1182  auto position = rewriter.getInsertionPoint();
1183  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1184 
1185  // Extract conditional branch information from the header block. By SPIR-V
1186  // dialect spec, it should contain `spirv.BranchConditional` or
1187  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1188  // moment in the SPIR-V dialect. Remove this block when finished.
1189  auto *headerBlock = op.getHeaderBlock();
1190  assert(headerBlock->getOperations().size() == 1);
1191  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1192  headerBlock->getOperations().front());
1193  if (!condBrOp)
1194  return failure();
1195  rewriter.eraseBlock(headerBlock);
1196 
1197  // Branch from merge block to continue block.
1198  auto *mergeBlock = op.getMergeBlock();
1199  Operation *terminator = mergeBlock->getTerminator();
1200  ValueRange terminatorOperands = terminator->getOperands();
1201  rewriter.setInsertionPointToEnd(mergeBlock);
1202  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1203 
1204  // Link current block to `true` and `false` blocks within the selection.
1205  Block *trueBlock = condBrOp.getTrueBlock();
1206  Block *falseBlock = condBrOp.getFalseBlock();
1207  rewriter.setInsertionPointToEnd(currentBlock);
1208  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1209  condBrOp.getTrueTargetOperands(),
1210  falseBlock,
1211  condBrOp.getFalseTargetOperands());
1212 
1213  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1214  rewriter.replaceOp(op, continueBlock->getArguments());
1215  return success();
1216  }
1217 };
1218 
1219 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1220 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1221 /// `Shift` is zero or sign extended to match this specification. Cases when
1222 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1223 template <typename SPIRVOp, typename LLVMOp>
1224 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1225 public:
1227 
1229  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1230  ConversionPatternRewriter &rewriter) const override {
1231 
1232  auto dstType = this->typeConverter.convertType(operation.getType());
1233  if (!dstType)
1234  return failure();
1235 
1236  Type op1Type = operation.getOperand1().getType();
1237  Type op2Type = operation.getOperand2().getType();
1238 
1239  if (op1Type == op2Type) {
1240  rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1241  adaptor.getOperands());
1242  return success();
1243  }
1244 
1245  std::optional<uint64_t> dstTypeWidth =
1247  std::optional<uint64_t> op2TypeWidth =
1249 
1250  if (!dstTypeWidth || !op2TypeWidth)
1251  return failure();
1252 
1253  Location loc = operation.getLoc();
1254  Value extended;
1255  if (op2TypeWidth < dstTypeWidth) {
1256  if (isUnsignedIntegerOrVector(op2Type)) {
1257  extended = rewriter.template create<LLVM::ZExtOp>(
1258  loc, dstType, adaptor.getOperand2());
1259  } else {
1260  extended = rewriter.template create<LLVM::SExtOp>(
1261  loc, dstType, adaptor.getOperand2());
1262  }
1263  } else if (op2TypeWidth == dstTypeWidth) {
1264  extended = adaptor.getOperand2();
1265  } else {
1266  return failure();
1267  }
1268 
1269  Value result = rewriter.template create<LLVMOp>(
1270  loc, dstType, adaptor.getOperand1(), extended);
1271  rewriter.replaceOp(operation, result);
1272  return success();
1273  }
1274 };
1275 
1276 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1277 public:
1279 
1281  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1282  ConversionPatternRewriter &rewriter) const override {
1283  auto dstType = typeConverter.convertType(tanOp.getType());
1284  if (!dstType)
1285  return failure();
1286 
1287  Location loc = tanOp.getLoc();
1288  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1289  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1290  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1291  return success();
1292  }
1293 };
1294 
1295 /// Convert `spirv.Tanh` to
1296 ///
1297 /// exp(2x) - 1
1298 /// -----------
1299 /// exp(2x) + 1
1300 ///
1301 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1302 public:
1304 
1306  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1307  ConversionPatternRewriter &rewriter) const override {
1308  auto srcType = tanhOp.getType();
1309  auto dstType = typeConverter.convertType(srcType);
1310  if (!dstType)
1311  return failure();
1312 
1313  Location loc = tanhOp.getLoc();
1314  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1315  Value multiplied =
1316  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1317  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1318  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1319  Value numerator =
1320  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1321  Value denominator =
1322  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1323  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1324  denominator);
1325  return success();
1326  }
1327 };
1328 
1329 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1330 public:
1332 
1334  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1335  ConversionPatternRewriter &rewriter) const override {
1336  auto srcType = varOp.getType();
1337  // Initialization is supported for scalars and vectors only.
1338  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1339  auto init = varOp.getInitializer();
1340  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1341  return failure();
1342 
1343  auto dstType = typeConverter.convertType(srcType);
1344  if (!dstType)
1345  return failure();
1346 
1347  Location loc = varOp.getLoc();
1348  Value size = createI32ConstantOf(loc, rewriter, 1);
1349  if (!init) {
1350  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
1351  varOp, dstType, typeConverter.convertType(pointerTo), size);
1352  return success();
1353  }
1354  Value allocated = rewriter.create<LLVM::AllocaOp>(
1355  loc, dstType, typeConverter.convertType(pointerTo), size);
1356  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1357  rewriter.replaceOp(varOp, allocated);
1358  return success();
1359  }
1360 };
1361 
1362 //===----------------------------------------------------------------------===//
1363 // BitcastOp conversion
1364 //===----------------------------------------------------------------------===//
1365 
1366 class BitcastConversionPattern
1367  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1368 public:
1370 
1372  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1373  ConversionPatternRewriter &rewriter) const override {
1374  auto dstType = typeConverter.convertType(bitcastOp.getType());
1375  if (!dstType)
1376  return failure();
1377 
1378  // LLVM's opaque pointers do not require bitcasts.
1379  if (isa<LLVM::LLVMPointerType>(dstType)) {
1380  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1381  return success();
1382  }
1383 
1384  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1385  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1386  return success();
1387  }
1388 };
1389 
1390 //===----------------------------------------------------------------------===//
1391 // FuncOp conversion
1392 //===----------------------------------------------------------------------===//
1393 
1394 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1395 public:
1397 
1399  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1400  ConversionPatternRewriter &rewriter) const override {
1401 
1402  // Convert function signature. At the moment LLVMType converter is enough
1403  // for currently supported types.
1404  auto funcType = funcOp.getFunctionType();
1405  TypeConverter::SignatureConversion signatureConverter(
1406  funcType.getNumInputs());
1407  auto llvmType = typeConverter.convertFunctionSignature(
1408  funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
1409  signatureConverter);
1410  if (!llvmType)
1411  return failure();
1412 
1413  // Create a new `LLVMFuncOp`
1414  Location loc = funcOp.getLoc();
1415  StringRef name = funcOp.getName();
1416  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1417 
1418  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1419  MLIRContext *context = funcOp.getContext();
1420  switch (funcOp.getFunctionControl()) {
1421 #define DISPATCH(functionControl, llvmAttr) \
1422  case functionControl: \
1423  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1424  break;
1425 
1426  DISPATCH(spirv::FunctionControl::Inline,
1427  StringAttr::get(context, "alwaysinline"));
1428  DISPATCH(spirv::FunctionControl::DontInline,
1429  StringAttr::get(context, "noinline"));
1430  DISPATCH(spirv::FunctionControl::Pure,
1431  StringAttr::get(context, "readonly"));
1432  DISPATCH(spirv::FunctionControl::Const,
1433  StringAttr::get(context, "readnone"));
1434 
1435 #undef DISPATCH
1436 
1437  // Default: if `spirv::FunctionControl::None`, then no attributes are
1438  // needed.
1439  default:
1440  break;
1441  }
1442 
1443  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1444  newFuncOp.end());
1445  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1446  &signatureConverter))) {
1447  return failure();
1448  }
1449  rewriter.eraseOp(funcOp);
1450  return success();
1451  }
1452 };
1453 
1454 //===----------------------------------------------------------------------===//
1455 // ModuleOp conversion
1456 //===----------------------------------------------------------------------===//
1457 
1458 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1459 public:
1461 
1463  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1464  ConversionPatternRewriter &rewriter) const override {
1465 
1466  auto newModuleOp =
1467  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1468  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1469 
1470  // Remove the terminator block that was automatically added by builder
1471  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1472  rewriter.eraseOp(spvModuleOp);
1473  return success();
1474  }
1475 };
1476 
1477 //===----------------------------------------------------------------------===//
1478 // VectorShuffleOp conversion
1479 //===----------------------------------------------------------------------===//
1480 
1481 class VectorShufflePattern
1482  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1483 public:
1486  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1487  ConversionPatternRewriter &rewriter) const override {
1488  Location loc = op.getLoc();
1489  auto components = adaptor.getComponents();
1490  auto vector1 = adaptor.getVector1();
1491  auto vector2 = adaptor.getVector2();
1492  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1493  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1494  if (vector1Size == vector2Size) {
1495  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1496  op, vector1, vector2,
1497  LLVM::convertArrayToIndices<int32_t>(components));
1498  return success();
1499  }
1500 
1501  auto dstType = typeConverter.convertType(op.getType());
1502  auto scalarType = cast<VectorType>(dstType).getElementType();
1503  auto componentsArray = components.getValue();
1504  auto *context = rewriter.getContext();
1505  auto llvmI32Type = IntegerType::get(context, 32);
1506  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1507  for (unsigned i = 0; i < componentsArray.size(); i++) {
1508  if (!isa<IntegerAttr>(componentsArray[i]))
1509  return op.emitError("unable to support non-constant component");
1510 
1511  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1512  if (indexVal == -1)
1513  continue;
1514 
1515  int offsetVal = 0;
1516  Value baseVector = vector1;
1517  if (indexVal >= vector1Size) {
1518  offsetVal = vector1Size;
1519  baseVector = vector2;
1520  }
1521 
1522  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1523  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1524  Value index = rewriter.create<LLVM::ConstantOp>(
1525  loc, llvmI32Type,
1526  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1527 
1528  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1529  loc, scalarType, baseVector, index);
1530  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1531  extractOp, dstIndex);
1532  }
1533  rewriter.replaceOp(op, targetOp);
1534  return success();
1535  }
1536 };
1537 } // namespace
1538 
1539 //===----------------------------------------------------------------------===//
1540 // Pattern population
1541 //===----------------------------------------------------------------------===//
1542 
1544  spirv::ClientAPI clientAPI) {
1545  typeConverter.addConversion([&](spirv::ArrayType type) {
1546  return convertArrayType(type, typeConverter);
1547  });
1548  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1549  return convertPointerType(type, typeConverter, clientAPI);
1550  });
1551  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1552  return convertRuntimeArrayType(type, typeConverter);
1553  });
1554  typeConverter.addConversion([&](spirv::StructType type) {
1555  return convertStructType(type, typeConverter);
1556  });
1557 }
1558 
1560  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1561  spirv::ClientAPI clientAPI) {
1562  patterns.add<
1563  // Arithmetic ops
1564  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1565  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1566  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1567  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1568  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1569  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1570  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1571  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1572  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1573  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1574  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1575  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1576  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1577 
1578  // Bitwise ops
1579  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1580  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1581  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1582  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1583  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1584  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1585  NotPattern<spirv::NotOp>,
1586 
1587  // Cast ops
1588  BitcastConversionPattern,
1589  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1590  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1591  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1592  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1593  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1594  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1595  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1596 
1597  // Comparison ops
1598  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1599  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1600  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1601  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1602  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1603  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1604  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1605  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1606  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1607  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1608  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1609  LLVM::FCmpPredicate::uge>,
1610  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1611  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1612  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1613  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1614  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1615  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1616  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1617  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1618  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1619  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1620  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1621 
1622  // Constant op
1623  ConstantScalarAndVectorPattern,
1624 
1625  // Control Flow ops
1626  BranchConversionPattern, BranchConditionalConversionPattern,
1627  FunctionCallPattern, LoopPattern, SelectionPattern,
1628  ErasePattern<spirv::MergeOp>,
1629 
1630  // Entry points and execution mode are handled separately.
1631  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1632 
1633  // GLSL extended instruction set ops
1634  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1635  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1636  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1637  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1638  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1639  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1640  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1641  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1642  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1643  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1644  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1645  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1646  InverseSqrtPattern, TanPattern, TanhPattern,
1647 
1648  // Logical ops
1649  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1650  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1651  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1652  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1653  NotPattern<spirv::LogicalNotOp>,
1654 
1655  // Memory ops
1656  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1657  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1658 
1659  // Miscellaneous ops
1660  CompositeExtractPattern, CompositeInsertPattern,
1661  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1662  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1663  VectorShufflePattern,
1664 
1665  // Shift ops
1666  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1667  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1668  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1669 
1670  // Return ops
1671  ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1672 
1673  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1674  typeConverter);
1675 }
1676 
1678  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1679  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1680 }
1681 
1683  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1684  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1685 }
1686 
1687 //===----------------------------------------------------------------------===//
1688 // Pre-conversion hooks
1689 //===----------------------------------------------------------------------===//
1690 
1691 /// Hook for descriptor set and binding number encoding.
1692 static constexpr StringRef kBinding = "binding";
1693 static constexpr StringRef kDescriptorSet = "descriptor_set";
1694 void mlir::encodeBindAttribute(ModuleOp module) {
1695  auto spvModules = module.getOps<spirv::ModuleOp>();
1696  for (auto spvModule : spvModules) {
1697  spvModule.walk([&](spirv::GlobalVariableOp op) {
1698  IntegerAttr descriptorSet =
1699  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1700  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1701  // For every global variable in the module, get the ones with descriptor
1702  // set and binding numbers.
1703  if (descriptorSet && binding) {
1704  // Encode these numbers into the variable's symbolic name. If the
1705  // SPIR-V module has a name, add it at the beginning.
1706  auto moduleAndName =
1707  spvModule.getName().has_value()
1708  ? spvModule.getName()->str() + "_" + op.getSymName().str()
1709  : op.getSymName().str();
1710  std::string name =
1711  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1712  std::to_string(descriptorSet.getInt()),
1713  std::to_string(binding.getInt()));
1714  auto nameAttr = StringAttr::get(op->getContext(), name);
1715 
1716  // Replace all symbol uses and set the new symbol name. Finally, remove
1717  // descriptor set and binding attributes.
1718  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1719  op.emitError("unable to replace all symbol uses for ") << name;
1720  SymbolTable::setSymbolName(op, nameAttr);
1722  op->removeAttr(kBinding);
1723  }
1724  });
1725  }
1726 }
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass)
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:72
constexpr unsigned defaultAddressSpace
Definition: SPIRVToLLVM.cpp:36
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
#define STORAGE_SPACE_MAP(storage, space)
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:43
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:52
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
static Type convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:93
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:85
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:62
#define CLIENT_MAP(client, storage)
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
OpListType & getOperations()
Definition: Block.h:130
BlockArgListType getArguments()
Definition: Block.h:80
iterator begin()
Definition: Block.h:136
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:436
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:427
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:528
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:578
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:77
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:89
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:552
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:613
SPIR-V struct type.
Definition: SPIRVTypes.h:294
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:863
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:845
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:221
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:879
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26